The tokenizer, Byte-Pair Encoding in this instance, translates each token in the input text into a corresponding token ID. Then, GPT-2 uses these token IDs as input and tries to predict the next most likely token. Finally, the model generates logits, which are converted into probabilities using a softmax function.
For example, the model assigns a probability of 17% to the token for "of" being the next token after "I have a dream". This output essentially represents a ranked list of potential next tokens in the sequence. More formally, we denote this probability as P(of | I have a dream) = 17%.
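To make the logits-to-probabilities step concrete, here is a minimal sketch of the softmax over a toy four-token vocabulary (the logit values are invented for illustration):

```python
import numpy as np

def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / exp_logits.sum()

# Made-up logits for a tiny 4-token vocabulary
logits = np.array([2.0, 1.0, 0.5, -1.0])
probs = softmax(logits)

# The probabilities are positive, sum to 1, and preserve the ranking of the logits
print(probs)
```

The token with the highest logit always receives the highest probability; softmax only rescales the ranking into a valid distribution.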
Autoregressive models like GPT predict the next token in a sequence based on the preceding tokens. Consider a sequence of tokens w = (w₁, w₂, …, wₜ). The joint probability of this sequence P(w) can be broken down as:

P(w) = P(w₁, w₂, …, wₜ) = ∏ᵢ P(wᵢ | w₁, w₂, …, wᵢ₋₁)

For each token wᵢ in the sequence, P(wᵢ | w₁, w₂, …, wᵢ₋₁) represents the conditional probability of wᵢ given all the preceding tokens (w₁, w₂, …, wᵢ₋₁). GPT-2 calculates this conditional probability for each of the 50,257 tokens in its vocabulary.
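This chain-rule decomposition can be sketched numerically. The conditional probabilities below are made-up values, not real GPT-2 outputs; the point is that the joint probability is their product, and that in practice we sum log probabilities to avoid numerical underflow:

```python
import math

# Hypothetical conditionals P(w_i | w_1, ..., w_{i-1}) for a 4-token sequence
conditionals = [0.4, 0.25, 0.1, 0.17]

# Chain rule: the joint probability is the product of the conditionals
joint = math.prod(conditionals)

# Equivalent log-space computation, which is what implementations actually use
log_joint = sum(math.log(p) for p in conditionals)
```

Exponentiating the summed log probabilities recovers the same joint probability, which is why the code later in this article works with log probabilities throughout.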
This leads to the question: how do we use these probabilities to generate text? This is where decoding strategies, such as greedy search and beam search, come into play.
Greedy search is a decoding method that takes the most probable token at each step as the next token in the sequence. To put it simply, it only retains the most likely token at each stage, discarding all other candidates. Using our example:
- Step 1: Input: "I have a dream" → Most likely token: " of"
- Step 2: Input: "I have a dream of" → Most likely token: " being"
- Step 3: Input: "I have a dream of being" → Most likely token: " a"
- Step 4: Input: "I have a dream of being a" → Most likely token: " doctor"
- Step 5: Input: "I have a dream of being a doctor" → Most likely token: "."
While this approach might sound intuitive, it's important to note that greedy search is short-sighted: it only considers the most probable token at each step without considering the overall effect on the sequence. This property makes it fast and efficient, as it doesn't need to keep track of multiple sequences, but it also means that it can miss out on better sequences that might have appeared with slightly less probable next tokens.
Next, let's illustrate the greedy search implementation using graphviz and networkx. We select the ID with the highest score, compute its log probability (we take the log to simplify calculations), and add it to the tree. We'll repeat this process for five tokens.
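This short-sightedness is easy to demonstrate on a toy two-step example. The distributions below are invented so that the locally best first token ("A") does not lead to the globally best sequence:

```python
# Invented two-step distributions: greedy takes "A" first (0.6),
# but the best full sequence starts with the less probable "B".
first = {"A": 0.6, "B": 0.4}              # P(w1)
second = {"A": {"x": 0.5, "y": 0.5},      # P(w2 | w1)
          "B": {"x": 0.9, "y": 0.1}}

# Greedy: take the argmax at each step
w1 = max(first, key=first.get)
w2 = max(second[w1], key=second[w1].get)
greedy_prob = first[w1] * second[w1][w2]  # 0.6 * 0.5

# Exhaustive: score every two-token sequence and keep the best one
best_seq, best_prob = max(
    (((a, b), first[a] * second[a][b]) for a in first for b in second[a]),
    key=lambda pair: pair[1],
)
# Greedy misses the higher-probability sequence ("B", "x")
print(greedy_prob, best_seq, best_prob)
```

Beam search, covered later, mitigates exactly this failure mode by keeping several candidate sequences alive at each step instead of just one.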
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import time
import torch

def get_log_prob(logits, token_id):
    # Compute the softmax of the logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    log_probabilities = torch.log(probabilities)

    # Get the log probability of the token
    token_log_probability = log_probabilities[token_id].item()
    return token_log_probability
def greedy_search(input_ids, node, length=5):
    if length == 0:
        return input_ids

    outputs = model(input_ids)
    predictions = outputs.logits

    # Get the predicted next sub-word (here we use greedy search)
    logits = predictions[0, -1, :]
    token_id = torch.argmax(logits).unsqueeze(0)

    # Compute the score of the predicted token
    token_score = get_log_prob(logits, token_id)

    # Add the predicted token to the list of input ids
    new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)

    # Add node and edge to graph
    next_token = tokenizer.decode(token_id, skip_special_tokens=True)
    current_node = list(graph.successors(node))[0]
    graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
    graph.nodes[current_node]['token'] = next_token + f"_{length}"

    # Recursive call
    input_ids = greedy_search(new_input_ids, current_node, length - 1)

    return input_ids
# Parameters
length = 5
beams = 1

# Create a balanced tree with height 'length'
graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())

# Add 'tokenscore', 'cumscore', and 'token' attributes to each node
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['token'] = text
# Start generating text
output_ids = greedy_search(input_ids, 0, length=length)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")
Generated text: I have a dream of being a doctor.
Our greedy search generates the same text as the one from the transformers library: "I have a dream of being a doctor." Let's visualize the tree we created.
import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap

def plot_graph(graph, length, beams, score):
    fig, ax = plt.subplots(figsize=(3 + 1.2 * beams**length, max(5, 2 + length)), dpi=300, facecolor='white')

    # Create positions for each node
    pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")

    # Normalize the colors along the range of token scores
    if score == 'token':
        scores = [data['tokenscore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    elif score == 'sequence':
        scores = [data['sequencescore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    vmin = min(scores)
    vmax = max(scores)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap = LinearSegmentedColormap.from_list('rg', ["r", "y", "g"], N=256)

    # Draw the nodes
    nx.draw_networkx_nodes(graph, pos, node_size=2000, node_shape='o', alpha=1, linewidths=4,
                           node_color=scores, cmap=cmap)

    # Draw the edges
    nx.draw_networkx_edges(graph, pos)

    # Draw the labels
    if score == 'token':
        labels = {node: data['token'].split('_')[0] + f"\n{data['tokenscore']:.2f}%" for node, data in graph.nodes(data=True) if data['token'] is not None}
    elif score == 'sequence':
        labels = {node: data['token'].split('_')[0] + f"\n{data['sequencescore']:.2f}" for node, data in graph.nodes(data=True) if data['token'] is not None}
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=10)
    plt.box(False)

    # Add a colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    if score == 'token':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Token probability (%)')
    elif score == 'sequence':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Sequence score')
    plt.show()
# Plot graph
plot_graph(graph, length, 1.5, 'token')