Decoding Strategies in Large Language Models
πŸ“š Background
πŸƒβ€β™‚οΈ Greedy Search
βš–οΈ Beam Search
🎲 Top-k sampling
πŸ”¬ Nucleus sampling
Conclusion

The tokenizer, Byte-Pair Encoding in this instance, translates each token in the input text into a corresponding token ID. Then, GPT-2 uses these token IDs as input and tries to predict the next most likely token. Finally, the model generates logits, which are converted into probabilities using a softmax function.
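The logits-to-probabilities step can be sketched with plain numpy (the logit values and candidate tokens below are made up for illustration, not GPT-2's actual outputs):

```python
import numpy as np

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    exp = np.exp(logits - np.max(logits))
    return exp / exp.sum()

# Hypothetical logits for four candidate next tokens after "I have a dream"
tokens = [" of", " that", " to", " about"]
logits = np.array([2.0, 1.5, 1.0, 0.5])
probs = softmax(logits)

for token, p in zip(tokens, probs):
    print(f"{token!r}: {p:.2%}")
```

The probabilities sum to 1, and the token with the highest logit receives the highest probability, which is exactly the ranked list of next tokens described above.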

For instance, the model assigns a probability of 17% to the token for β€œof” being the next token after β€œI have a dream”. This output essentially represents a ranked list of potential next tokens in the sequence. More formally, we denote this probability as P(of | I have a dream) = 17%.

Autoregressive models like GPT predict the next token in a sequence based on the preceding tokens. Consider a sequence of tokens w = (w₁, wβ‚‚, …, wβ‚œ). The joint probability of this sequence P(w) can be broken down as:

P(w) = P(w₁) Β· P(wβ‚‚ | w₁) Β· P(w₃ | w₁, wβ‚‚) Β· … = ∏ᡒ₌₁ᡗ P(wα΅’ | w₁, …, wᡒ₋₁)

For each token wα΅’ in the sequence, P(wα΅’ | w₁, wβ‚‚, …, wᡒ₋₁) represents the conditional probability of wα΅’ given all the preceding tokens (w₁, wβ‚‚, …, wᡒ₋₁). GPT-2 calculates this conditional probability for each of the 50,257 tokens in its vocabulary.
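This decomposition can be checked numerically on a toy example (the per-step conditional probabilities below are invented, not GPT-2's):

```python
import math

# Hypothetical conditional probabilities P(w_i | w_1, ..., w_{i-1})
# for the five tokens of "I have a dream of"
conditionals = [0.20, 0.15, 0.30, 0.10, 0.17]

# The joint probability is the product of the conditionals
joint = 1.0
for p in conditionals:
    joint *= p

print(f"P(w) = {joint:.6f}")

# In practice we sum log probabilities instead of multiplying,
# which avoids numerical underflow for long sequences
log_joint = sum(math.log(p) for p in conditionals)
assert abs(math.exp(log_joint) - joint) < 1e-12
```

This is also why the code below works with log probabilities: sums of logs are numerically better behaved than products of small probabilities.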

This leads to the question: how do we use these probabilities to generate text? This is where decoding strategies, such as greedy search and beam search, come into play.

Greedy search is a decoding method that takes the most probable token at each step as the next token in the sequence. To put it simply, it only retains the most likely token at each stage, discarding all other potential options. Using our example:

  • Step 1: Input: β€œI even have a dream” β†’ Most certainly token: β€œ of”
  • Step 2: Input: β€œI even have a dream of” β†’ Most certainly token: β€œ being”
  • Step 3: Input: β€œI even have a dream of being” β†’ Most certainly token: β€œ a”
  • Step 4: Input: β€œI even have a dream of being a” β†’ Most certainly token: β€œ doctor”
  • Step 5: Input: β€œI even have a dream of being a health care provider” β†’ Most certainly token: β€œ.”

While this approach might sound intuitive, it's important to note that greedy search is short-sighted: it only considers the most probable token at each step without considering the overall effect on the sequence. This property makes it fast and efficient as it doesn't need to keep track of multiple sequences, but it also means that it can miss out on better sequences that might have appeared with slightly less probable next tokens.
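A toy example makes this short-sightedness concrete. With the made-up two-step probability tables below, the locally best first token leads to a worse full sequence than a slightly less probable one:

```python
# Made-up probability tables: P(first token) and P(second token | first token)
p_first = {"of": 0.40, "that": 0.35}
p_second = {
    "of": {"being": 0.30},
    "that": {"one": 0.90},
}

# Greedy: pick the most probable token at each step
first = max(p_first, key=p_first.get)                     # picks "of" (0.40 > 0.35)
second = max(p_second[first], key=p_second[first].get)    # then "being"
greedy_prob = p_first[first] * p_second[first][second]    # 0.40 * 0.30

# Exhaustive search over both steps finds a better complete sequence
best = max(
    ((f, s, p_first[f] * p2) for f, table in p_second.items() for s, p2 in table.items()),
    key=lambda seq: seq[2],
)
print("greedy:", first, second, greedy_prob)
print("best:  ", best)    # "that one" with 0.35 * 0.90, beating the greedy choice
```

Greedy search commits to β€œof” because it is the single most probable first token, yet the sequence starting with β€œthat” has a higher joint probability. Beam search, discussed next, addresses exactly this failure mode.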

Next, let's illustrate the greedy search implementation using graphviz and networkx. We select the ID with the highest score, compute its log probability (we take the log to simplify calculations), and add it to the tree. We'll repeat this process for five tokens.

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch

def get_log_prob(logits, token_id):
    # Compute the softmax of the logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    log_probabilities = torch.log(probabilities)

    # Get the log probability of the token
    token_log_probability = log_probabilities[token_id].item()
    return token_log_probability

def greedy_search(input_ids, node, length=5):
    if length == 0:
        return input_ids

    outputs = model(input_ids)
    predictions = outputs.logits

    # Get the predicted next sub-word (greedy: take the argmax of the logits)
    logits = predictions[0, -1, :]
    token_id = torch.argmax(logits).unsqueeze(0)

    # Compute the score of the predicted token
    token_score = get_log_prob(logits, token_id)

    # Add the predicted token to the list of input ids
    new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)

    # Add node and edge to graph
    next_token = tokenizer.decode(token_id, skip_special_tokens=True)
    current_node = list(graph.successors(node))[0]
    graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
    graph.nodes[current_node]['token'] = next_token + f"_{length}"

    # Recursive call
    input_ids = greedy_search(new_input_ids, current_node, length-1)

    return input_ids

# Parameters
length = 5
beams = 1

# Create a balanced tree with height 'length'
graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())

# Add 'tokenscore', 'cumscore', and 'token' attributes to each node
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['token'] = text

# Start generating text
output_ids = greedy_search(input_ids, 0, length=length)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")

Generated text: I have a dream of being a doctor.

Our greedy search generates the same text as the one from the transformers library: β€œI have a dream of being a doctor.” Let's visualize the tree we created.

import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap

def plot_graph(graph, length, beams, score):
    fig, ax = plt.subplots(figsize=(3+1.2*beams**length, max(5, 2+length)), dpi=300, facecolor='white')

    # Create positions for each node
    pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")

    # Normalize the colors across the range of token scores
    if score == 'token':
        scores = [data['tokenscore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    elif score == 'sequence':
        scores = [data['sequencescore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    vmin = min(scores)
    vmax = max(scores)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap = LinearSegmentedColormap.from_list('rg', ["r", "y", "g"], N=256)

    # Draw the nodes
    nx.draw_networkx_nodes(graph, pos, node_size=2000, node_shape='o', alpha=1, linewidths=4,
                           node_color=scores, cmap=cmap)

    # Draw the edges
    nx.draw_networkx_edges(graph, pos)

    # Draw the labels
    if score == 'token':
        labels = {node: data['token'].split('_')[0] + f"\n{data['tokenscore']:.2f}%" for node, data in graph.nodes(data=True) if data['token'] is not None}
    elif score == 'sequence':
        labels = {node: data['token'].split('_')[0] + f"\n{data['sequencescore']:.2f}" for node, data in graph.nodes(data=True) if data['token'] is not None}
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=10)
    plt.box(False)

    # Add a colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    if score == 'token':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Token probability (%)')
    elif score == 'sequence':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Sequence score')
    plt.show()

# Plot graph
plot_graph(graph, length, 1.5, 'token')
