The Current State of Continual Learning in AI
Intro
Continual Learning within the pre-training vs fine-tuning stages of language models
The 5 sub-categories of continual learning techniques
1. Regularisation-based approaches
Soft-masking applied to continual pre-training in a language model
Initial Pre-training Phase
Further Pre-training on a New Domain
2. Optimisation-based approaches
3. Representation-based approach
Training process step-by-step
4. Replay-based approach
5. Architecture-based approach
Conclusion

The excellent survey paper on continual learning states that training strategies for continual learning can be divided into 5 sub-categories:

  1. Regularisation-based approach: this approach adds constraints or penalties to the learning process during training.
  2. Optimisation-based approach: this technique focuses on modifying the optimisation algorithm.
  3. Representation-based approach: this aims to learn a shared feature representation across different tasks, helping the model generalise better to new but related tasks.
  4. Replay-based approach: this involves storing some data or learned features from previous tasks and replaying them during training on new tasks to maintain performance on earlier learned tasks. In other words, mixing both the old and new datasets when training on new tasks.
  5. Architecture-based approach: in this approach, the network architecture is dynamically adjusted, often by growing or partitioning, delegating different parts of the network to different tasks.

Soft Masking of Parameters

The following soft-masking techniques mask and adjust the gradients of each parameter during the training process. The optimisation-based approaches coming up also use the gradients for continual learning. Remember, gradients aren’t just temporary numbers that appear and disappear during training; they’re signals that guide the evolution of the weights.

SPG

This paper proposes a technique named SPG (Soft-masking of Parameter-level Gradient flow) which aims to:

  1. Train the model on each task until convergence.
  2. After training, calculate the “importance” of each parameter for the task.
  3. Soft-mask parameters based on their accumulated importance, making important parameters less likely to change during the learning of new tasks.

Let’s break the approach down step-by-step:

1. Training the First Task

Train the model on the first task’s dataset as normal.

2. Calculate Parameter Importance for the First Task

After training on the first task is complete, we calculate the importance of each model parameter. The intuition here is simple: we use the gradients of each parameter to compute its importance. A larger gradient implies that a small change in that parameter will lead to a larger change in the loss, meaning the model’s performance could vary more significantly, hence that parameter is important.

The gradients are also normalised, because gradients in the first layer could be small while those in the last layer could be large. If you were calculating importance based on these raw gradient values, parameters in the last layer would appear more important due to the scale of their gradients, not necessarily because they’re genuinely more crucial for the task.

Equations for calculating the importance of the model parameters in SPG (section 3.1 of the paper)

Let’s translate this calculation to PyTorch-like pseudocode:

import torch

def compute_final_importance(model, loss_function, data_loader):
    # Get a single batch from the data loader
    inputs, labels = next(iter(data_loader))

    # Forward and backward pass to calculate the gradients for all parameters
    outputs = model(inputs)
    loss = loss_function(outputs, labels)
    loss.backward()

    importances = []

    # Calculate importance based on the gradients
    for param in model.parameters():
        if param.grad is not None:  # Gradients may be None for some unused parameters
            normalized_grad = (param.grad - torch.mean(param.grad)) / torch.std(param.grad)
            importance = torch.tanh(normalized_grad)
            importances.append(importance)

    # One importance tensor per parameter tensor, matching its shape
    return importances

3. Accumulating Importance Across Tasks

The accumulated importance of each parameter across tasks is simply calculated by taking the max value at any stage.
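This max-accumulation step can be sketched as follows; the importance values below are made up for illustration and are not from the paper:

```python
import torch

# Hypothetical per-task importance scores for one parameter tensor,
# e.g. produced by the tanh-of-normalised-gradient computation above
importance_task_1 = torch.tensor([0.1, 0.8, 0.3])
importance_task_2 = torch.tensor([0.5, 0.2, 0.4])

# Accumulate by taking the element-wise max at every stage, so a parameter
# stays protected once any task has marked it as important
accumulated = torch.zeros(3)
for task_importance in [importance_task_1, importance_task_2]:
    accumulated = torch.maximum(accumulated, task_importance)

print(accumulated)  # tensor([0.5000, 0.8000, 0.4000])
```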

4. Training Subsequent Tasks: Combined Loss and the Soft-Masking Mechanism

When training on new tasks, the researchers use a combined loss function consisting of two parts. One is the standard loss function, used as normal on the new task and data; the second is an additional loss function which involves putting the new data through the old model (the converged model checkpoint after the previous task) and summing up the logits produced. In classification networks, the logits are usually the raw unnormalised predictions generated by the model in one of the last layers, before going through something like a softmax function. This sum of logits serves as a form of loss. The rationale is that if the summed logits are significantly affected when the model parameters change, those parameters are crucial for the performance of the previously learned task.

The gradients generated from this extra loss function act as a guide during backpropagation, nudging the shared parameters to change in a direction that’s less likely to harm performance on the first task. It therefore acts as a kind of penalty term, enforcing that any updates made to the model don’t result in a significant loss of knowledge related to previous tasks.

Train the model on the next task. Use a standard training loop, but modify the gradients during backpropagation based on their accumulated importance. This is the soft-masking mechanism:

import torch
import torch.nn as nn

accumulated_importance = ...  # list of importance tensors (one per parameter), updated at the end of each task

for epoch in range(num_epochs):
    for x, y in train_loader:

        # Forward pass: calculate the loss for the current task using the proper loss function
        logits = new_model(x)
        loss_current_task = nn.CrossEntropyLoss()(logits, y)

        # Forward pass: calculate the additional losses for previous tasks (CHI mechanism)
        loss_previous_tasks = 0
        for prev_task_id in range(task_id):
            logits_prev = old_model(x, prev_task_id)
            loss_previous_tasks += logits_prev.sum()

        # Combine the losses
        combined_loss = loss_current_task + loss_previous_tasks

        # Backward pass
        optimizer.zero_grad()
        combined_loss.backward()

        # Update the accumulated importance
        for i, param in enumerate(new_model.parameters()):
            accumulated_importance[i] = torch.max(accumulated_importance[i], torch.abs(param.grad))

        # Soft-mask the gradients before taking an optimisation step
        for param, imp in zip(new_model.parameters(), accumulated_importance):
            param.grad *= (1 - imp)

        optimizer.step()

5. Soft-Masking Special Cases

  • Feature Extractor: Gradients of parameters in the shared feature extractor are modified based on their specific accumulated importance.
  • Classification Head: For the classification head, gradients are modified based on the average importance of the feature extractor.
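A minimal sketch of these two rules, assuming a hypothetical model split into a `feature_extractor` and a `head` module; the names, shapes, and random importance values are made up for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical two-part model: a shared feature extractor and a classification head
feature_extractor = nn.Linear(4, 8)
head = nn.Linear(8, 2)

# Hypothetical accumulated importances for the feature extractor's parameters
fe_importances = [torch.rand_like(p) for p in feature_extractor.parameters()]

# Dummy forward/backward pass so every parameter has a gradient
x, y = torch.randn(3, 4), torch.randint(0, 2, (3,))
loss = nn.CrossEntropyLoss()(head(feature_extractor(x)), y)
loss.backward()

# Feature extractor: mask each gradient by its own accumulated importance
for p, imp in zip(feature_extractor.parameters(), fe_importances):
    p.grad *= (1 - imp)

# Classification head: mask by the average importance of the feature extractor
avg_importance = torch.mean(torch.cat([imp.flatten() for imp in fe_importances]))
for p in head.parameters():
    p.grad *= (1 - avg_importance)
```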

Applying this to LLMs

Keep in mind, this paper doesn’t experiment with a language model, but I assume in a language model you could consider the transformer layers as analogous to the “feature extractor,” and the final classification layer (which predicts the next word or token in the sequence) as the “classification head.”

Next we’ll go into a paper which applies similar soft-masking to the pre-training stage in language modelling.

This paper introduces a technique called DAS (Continual DA-pre-training of LMs with Soft-masking) for continual learning in the pre-training stage of a large language model. It applies a soft-masking technique similar to the one just discussed, together with a couple of other techniques, in an attempt to continue pre-training of an LLM without running into catastrophic forgetting.

Let’s break it down step-by-step:

Pre-train the LLM like normal.

Prepare New Domain Data:

A new dataset from a different domain is prepared.

Calculating the importance of each neuron

SPG used gradients to determine the importance of each parameter, and then applied the calculated importance value to mask the gradient adjustments of parameters during training. This paper tries to determine the importance of each unit/neuron, rather than each parameter, and then uses this in the same way by masking the gradient during training.

This paper uses two different methods to calculate the importance of neurons, depending on the task at hand. One, a gradient-based importance detection method (originally outlined in this paper), and two, a custom “proxy loss function”.

The first method isn’t used at all during continual learning on the first new domain. Why? It needs data from the original training dataset to work, and the authors state that users “don’t have access to the huge original pre-training dataset”, which is a fair assumption.

They propose a Proxy Loss Function:

I found this term confusing at first, but it’s called this because the original gradient-based importance detection method is defined as a loss function itself, which you can then run the network’s outputs through to get the gradients of each neuron, which can then be used to derive importance, just like the SPG technique.

According to the paper, the importance is calculated for each “unit” in the network, where a unit could be a neuron or an attention head.

Proxy loss function (“Proxy KL-divergence loss”):

  • Take a subset of the new domain we want to train on and feed it twice through the model to get two different representations. These representations will differ slightly due to the dropout masks present in the Transformer architecture.
  • Compute the KL-divergence between these two representations.
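A rough sketch of this proxy loss, using a toy feed-forward stack with dropout in place of a real Transformer; the model, shapes, and dropout rate are assumptions for illustration, not from the paper:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in for a Transformer block: dropout makes two forward passes differ
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.1), nn.Linear(16, 16))
model.train()  # keep dropout active so the two passes produce different outputs

batch = torch.randn(8, 16)  # a subset of the new-domain data

# Two forward passes over the same batch give two slightly different representations
rep_a = model(batch)
rep_b = model(batch)

# Proxy KL-divergence loss between the two representations (treated as distributions)
proxy_loss = F.kl_div(
    F.log_softmax(rep_a, dim=-1),
    F.softmax(rep_b, dim=-1),
    reduction="batchmean",
)
proxy_loss.backward()  # the importance scores are then read from these gradients
```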

Modified Backpropagation Flow with Proxy and Combined Loss

  1. Forward Pass: Data goes through a forward pass in the neural network.
  2. Backpropagation:

Apply Proxy Loss for Gradient Adjustment: The proxy loss function’s unit-level importance is used to soft-mask the original gradients. This is expressed as:

adjusted_grad *= (1 − unit_level_importance)

Calculate Combined Loss (MLM + Contrastive Loss): Compute the combined loss using both the MLM and contrastive losses.
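The `adjusted_grad *= (1 − unit_level_importance)` step can be sketched at the level of a single linear layer, assuming one importance score per output neuron; the shapes and values here are made up:

```python
import torch

# Hypothetical gradient of a linear layer's weight matrix: 4 output units, 6 inputs
grad = torch.ones(4, 6)

# Unit-level importance: one score per output neuron rather than per parameter
unit_level_importance = torch.tensor([0.0, 0.25, 0.5, 1.0])

# Soft-mask: scale each unit's entire gradient row by (1 - importance), so a fully
# important unit (1.0) is frozen and an unimportant one (0.0) is left untouched
adjusted_grad = grad * (1 - unit_level_importance).unsqueeze(1)

print(adjusted_grad[:, 0])  # tensor([1.0000, 0.7500, 0.5000, 0.0000])
```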

Further Pre-training on More Domains

  1. Direct Importance Calculation: For each new domain, the importance of each unit can now be directly calculated using the data from the new domain via the gradient-based method outlined in equation 3, eliminating the need for the proxy loss function, which is only used once, after the initial pre-training.
  2. The importance of neurons is updated incrementally as each new task is learned. This update is done using an element-wise max. The “element-wise maximum (EMax) operation” refers to comparing two vectors element by element and taking the maximum value for each corresponding element to create a new vector. E.g.: if you have two vectors A and B of the same length, the element-wise maximum will result in a new vector C where each element C[i] is the maximum of A[i] and B[i].

We’ll refer to the two techniques outlined in section 3.1 of the survey paper.

Gradient Direction Preservation

The paper talks about manipulating the gradient-based optimisation process to make the gradient directions of new training samples close to those from old training samples. The formula

⟨ ∇θ Lₖ(θ; Dₖ), ∇θ Lₖ(θ; Mₜ) ⟩ ≥ 0

enforces that learning the new task shouldn’t increase the loss for the old tasks. Essentially, the gradients of the new task and the old tasks are encouraged to align.

Breaking down the formula, we require that the dot product of the gradient of the loss from the new task (∇θ Lₖ(θ; Dₖ)) and the gradient of the loss from the old task (∇θ Lₖ(θ; Mₜ)) be non-negative. In this context, a positive dot product implies that the gradients for the old task and the new task are generally pointing in the same direction, with the angle between these two vectors less than or equal to 90 degrees.

Forward/Backward Passes:

Forward Pass:

You’d run your input data Dₖ for the new task and Mₜ for the old task through the same model to calculate the loss for each.

Backward Pass:

  1. Compute the gradients of the loss with respect to the network parameters for both the old and new tasks.
  2. Alignment Check: Compute the dot product of the two gradients. You’d then use this information to modify the gradients for the new task in such a way that the dot product is non-negative.
  3. Update Weights: Update the model parameters using these “aligned” gradients.

import torch

# Forward pass for the new task
output_k = model(D_k)
loss_k = criterion(output_k, y_k)

# Forward pass for the old task
output_t = model(M_t)
loss_t = criterion(output_t, y_t)

# Compute gradients for both tasks
loss_k.backward(retain_graph=True)  # Compute gradients for the new task but keep the computation graph
grad_k = torch.cat([p.grad.view(-1) for p in model.parameters()])

optimizer.zero_grad()

loss_t.backward()  # Compute gradients for the old task
grad_t = torch.cat([p.grad.view(-1) for p in model.parameters()])

# Compute dot product and modify gradients if they don't align
dot_product = torch.dot(grad_k, grad_t)
if dot_product < 0:
    # The paper doesn't seem to specify how to modify the gradients here; a common
    # choice (used by GEM-style methods) is to project grad_k onto the plane
    # orthogonal to grad_t so the conflicting component is removed
    grad_k = grad_k - (dot_product / torch.dot(grad_t, grad_t)) * grad_t

# Use the modified gradient to update model parameters
index = 0
for p in model.parameters():
    num_params = p.numel()
    # Update using modified gradients
    p.grad = grad_k[index: index + num_params].view(p.shape)
    index += num_params

optimizer.step()

Gradient Direction Preservation without needing old training samples

The text also highlights that gradient projection can be performed even without storing old samples. NCL (Natural continual learning, paper link) is the technique summarised here. Note, this can be categorised as both a regularisation- and optimisation-based approach.

Training process step-by-step:

Forward Pass:

You’d run your new data through the network and calculate the loss as usual.

Backward Pass:

Objective: The aim is to minimise the task-specific loss ℓₖ(θ) while adhering to a distance constraint d(θ, θ+δ) ≤ r.

Algorithm step-by-step:

  1. As normal, compute the gradient of the loss with respect to the model parameters, ∇θ ℓₖ(θ).
  2. The δ is calculated using the update rule. This gives you the “suggested” changes to the model parameters θ based on the new task’s requirements.
  3. Then, you plug this δ into the distance constraint formula: d(θ, θ+δ) = √(δᵀ Λ₍ₖ₋₁₎ δ). The constraint acts like a boundary around the current parameters θ, defined by the distance metric d(θ, θ+δ) and the radius r. I struggled to see why they called it a “radius”, and not just a “constraint number” or something. I think it’s because the researchers are visualising the gradients and training process in a high-dimensional space. When you apply a constraint based on the distance metric, you’re essentially defining a “sphere” around your current parameter values in that high-dimensional space. The “radius” r of this sphere sets a limit on how much the parameters can move while learning a new task.
  4. If the proposed δ would move θ too far based on this distance metric, i.e., beyond this boundary, you scale it down so that it stays within the allowable region defined by the radius r.

Let’s look at each part in more depth:

Update Rule: The update rule provides a direction in which θ should move.

NCL update rule from section 3.1 of the survey paper

Breaking it down:

  • ∇θ ℓₖ(θ) represents the gradients for all parameters (θ) calculated by the loss function.
  • Parameter importance calculation (Λ₍ₖ₋₁₎⁻¹): This term is the inverse of a precision matrix, and it’s yet another way to calculate the importance of parameters in the network (more details below).
  • Regularisation term (θ − μ₍ₖ₋₁₎): This term pulls the updated parameters closer to the optimal parameters μ₍ₖ₋₁₎ from the previous task. Like the previous techniques, it acts as a regulariser to avoid deviation from what was already learned.
  • Learning rate (λ)

Distance Constraint: Before applying this update, you’d first check whether this change δ would violate the distance constraint d(θ, θ+δ) ≤ r. If it does, you’d typically scale down δ so that it satisfies the constraint.

Precision matrix explanation: earlier, in the soft-masking methods, we saw the calculation of importance via the outputs of neurons or their gradients. In this method a precision matrix is used. This is a bit more complex, so I’ll try to explain it:

We first calculate the covariance matrix of the network’s parameter gradients. In the context of neural networks, the columns in the gradient matrix G correspond to the parameters (weights and biases) of the model. Each row in G represents the gradient vector for a single training example, with respect to all of those parameters.

So, if you have a neural network with P parameters (this includes all the weights and biases from all layers), then each gradient vector will have P elements, one for each parameter. Therefore, G will be a matrix of shape N × P, with N representing the number of batches and each row representing the average gradient vector across all the training examples in a given batch.

When you calculate the covariance matrix Σ from G, the resulting matrix will have dimensions P × P. The diagonal entries Σᵢᵢ will indicate the variance of the gradient with respect to the ith parameter, and the off-diagonal entries Σᵢⱼ will indicate the covariance between the gradients with respect to the ith and jth parameters. This gives you an idea of how these parameters interact, or co-vary, during the training process. The inverse of this matrix is the precision matrix, which is what we use to determine importance.

Why the precision matrix over the covariance matrix? While the covariance matrix Σ does capture how parameters interact with each other during training, it doesn’t specifically indicate how crucial each parameter is to the task at hand when all other parameters are considered. In contrast, the precision matrix allows us to assess the conditional independence (a concept from probability theory worth looking up) of parameters. Large values in the precision matrix indicate that knowing one parameter is highly informative about another, given all the other parameters. I’m not going to go into examples of how this works, so get ChatGPT to generate some examples using a very small neural network to see how the values can be interpreted.
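As a rough illustration of the mechanics (not the paper’s exact procedure), here is how a precision matrix could be computed from a toy gradient matrix G; the small ridge term added before inversion is my own numerical-stability assumption:

```python
import torch

torch.manual_seed(0)

# Hypothetical gradient matrix G: one row per training example (or batch),
# one column per model parameter. Here: 50 examples, 3 parameters.
G = torch.randn(50, 3)

# Covariance of the gradients across examples (P x P)
G_centered = G - G.mean(dim=0, keepdim=True)
cov = G_centered.T @ G_centered / (G.shape[0] - 1)

# The precision matrix is the inverse of the covariance matrix; the tiny
# ridge term keeps the inversion numerically stable (my own assumption)
precision = torch.inverse(cov + 1e-6 * torch.eye(3))

# Diagonal entries reflect how tightly determined each parameter's gradient is
# given the others; off-diagonals capture conditional dependencies
print(precision.shape)  # torch.Size([3, 3])
```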

The previous methods we saw that calculate importance focus on individual neurons or parameters, ignoring the relationships between them. The precision matrix, on the other hand, can capture these relationships. Like everything in deep learning, whether this is a better way to calculate the importance of a network is going to be empirical, and will differ depending on the task and scale of the network.

Algorithm step-by-step in PyTorch:

import torch

# Constraint radius
radius = 0.1

for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()

        # Forward pass
        output = model(data)
        loss = loss_function(output, target)

        # Backward pass to get gradients for params
        loss.backward()
        model_grad = torch.cat([p.grad.data.view(-1) for p in model.parameters()])

        # Compute δ using the NCL method
        # δ = Λ^(-1) * grad - (θ - µ)
        delta = torch.matmul(torch.inverse(covarianceMatrix), model_grad) - (
            torch.cat([p.data.view(-1) for p in model.parameters()]) - parametersForPrevTask
        )

        # Check the distance constraint and scale δ down if it is violated
        if torch.norm(delta) > radius:
            delta = radius * delta / torch.norm(delta)

        # Update model parameters (θ) using δ
        idx = 0
        for p in model.parameters():
            length = p.data.numel()
            p.data += delta[idx: idx + length].view(p.data.shape)
            idx += length

        # Update Λ and µ for the next task; this is likely task-specific and non-trivial
