
Clone the Abilities of Powerful LLMs into Small Local Models Using Knowledge Distillation


Boost the performance of local LLMs using supervision from a larger one

Towards Data Science
Photo by Matthew Feeney on Unsplash

In Natural Language Processing (NLP), cutting-edge Large Language Models (LLMs) offer remarkable few-shot learning and reasoning capabilities. However, the computational demands and latency associated with these models can make them impractical for certain applications. If your goal, for example, is to develop a translation service, you most likely don’t need your back-end LLM to be able to crack jokes or explain quantum physics to a kindergartner. This highlights the need for specialized, smaller-scale models.

A viable solution to this challenge is to build tailored LLMs that cater precisely to your specific use case. This involves annotating significant volumes of data and then fine-tuning a more compact model like Tiny-Llama to fit your requirements. Such an approach not only ensures that the model aligns closely with your needs but also mitigates the computational and deployment expenses associated with larger LLMs. However, this method has a downside: the process of data annotation is often laborious and time-consuming.

To address this bottleneck, an alternative emerges in the form of knowledge distillation. Instead of relying solely on manual labeling, this approach leverages the capabilities of a very large language model together with targeted prompting to generate labeled data automatically. A smaller model can then be fine-tuned on this distilled knowledge, streamlining the model development process while maintaining performance.

In this post, we’ll work through this exact scenario applied to building a model for multi-language grammatical error correction.

The Task:

Our goal is to detect and correct grammatical errors within a sentence. For instance:

  • Corrupted sentence: “It is extremely hard to do away with bad habit.”
  • Corrected sentence: “It is extremely hard to do away with bad habits.”

The Distillation Workflow:

Here’s how we’re going to distill the knowledge from our teacher model to our student model:

  1. First, acquire unlabeled in-domain data.
  2. Second, craft a prompt to extract pseudo-labels from the teacher model by leveraging Anyscale’s API.
  3. Finally, fine-tune the student model on these pseudo-labels using LoRA + PEFT, as sketched right after this list.
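At a high level, the whole workflow fits in a few lines of pseudo-code. The function names here are placeholders for the pieces we implement in the following sections:

# High-level sketch of the distillation loop; the helper functions are
# placeholders for the steps detailed below.
unlabeled_sentences = load_in_domain_sentences()             # step 1: unlabeled data

pseudo_labeled = [
    (sentence, teacher_correct(sentence))                    # step 2: teacher pseudo-labels
    for sentence in unlabeled_sentences
]

student_model = fine_tune_student_with_lora(pseudo_labeled)  # step 3: LoRA fine-tuning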

The Data:

The data we use is the Hugging Face dataset `juancavallotti/multilingual-gec`, where we only use the labels for evaluation and not for training. [Licensed under Apache 2]

This data can be loaded as follows:

from datasets import load_dataset

data = load_dataset("juancavallotti/multilingual-gec", split="train")
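Since the reference corrections are only used for evaluation, all we need from each row for distillation is the corrupted sentence. Here is a minimal sketch of that split; the column names below are assumptions, so check `data.column_names` against the actual schema:

# Inspect the schema and one example row to confirm the field names.
print(data.column_names)
print(data[0])

# Keep the corrupted sentences for pseudo-labeling and hold out the reference
# corrections for evaluation only. "sentence" and "corrected" are assumed
# column names; adjust them to whatever the print above shows.
subset = data.select(range(5000))
unlabeled_sentences = [row["sentence"] for row in subset]
reference_corrections = [row["corrected"] for row in subset]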

The Teacher Model:

We are employing LLama 2–70B as our teacher model. The teacher model is what will produce the pseudo-labels used for training. This powerful LLM is hosted on AnyScale’s pay-per-use API. AnyScale offers a $10 credit, allowing you to explore and use the model without incurring any costs initially. Alternatively, you can also use OpenAI’s or Anthropic’s API.

We generate pseudo-labels for around 5,000 samples, which costs about $1.20.

You can call this API like this:

import os

from openai import OpenAI

BASE_URL = "https://api.endpoints.anyscale.com/v1"
BASE_MODEL = "meta-llama/Llama-2-70b-chat-hf"

# Anyscale API key, here read from an environment variable.
API_KEY = os.environ["ANYSCALE_API_KEY"]

BASE_CLIENT = OpenAI(base_url=BASE_URL, api_key=API_KEY)


def process_call(prompt):
    # Query the teacher model deterministically (temperature=0) and return
    # only the generated text.
    completion = BASE_CLIENT.completions.create(
        model=BASE_MODEL,
        prompt=prompt,
        max_tokens=100,
        temperature=0,
    )
    result = completion.model_dump()

    return result["choices"][0]["text"].strip()

We use a simple few-shot prompting technique with the LLama 2 prompt template. This allows the LLM to understand what the expected output is and generally improves the quality of the result.

[INST]
Your role is to correct all grammatical errors within the input text. Only answer with the corrected text and nothing else.

Text: Il est très importante de parler une langue étrangère.
[/INST]
Output: Il est très important de parler une langue étrangère.

[INST]
Text: Nadie dise ezo.
[/INST]
Output: Nadie dice eso.
[INST]
Text: What's your favorite part of being a member of SWE RMS?
[/INST]
Output: What's your favorite part of being a member of SWE RMS?
[INST]
Text: I looked, on the schedule.
[/INST]
Output: I looked at the schedule.
[INST]
Text: $text
[/INST]
Output:
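To generate the pseudo-labels, each unlabeled sentence is substituted for the `$text` placeholder and the resulting prompt is sent through `process_call` defined above. A minimal sketch, using Python's `string.Template` for the substitution (`PROMPT_TEMPLATE` is assumed to hold the few-shot prompt shown above; retries and rate limiting are omitted):

from string import Template

# PROMPT_TEMPLATE holds the few-shot prompt above, including the "$text" slot.
prompt_template = Template(PROMPT_TEMPLATE)

pseudo_labels = []
for sentence in unlabeled_sentences:
    prompt = prompt_template.substitute(text=sentence)
    corrected = process_call(prompt)  # query the teacher model
    pseudo_labels.append({"input": sentence, "output": corrected})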

The Student Model:

We’re using Tiny-LLama as our student model. The student model is what we’ll “train” on the grammar correction task using the pseudo-labels from the teacher model. Despite its smaller scale of 1.1 billion parameters, it’s highly efficient. Tiny-LLama can run on consumer GPUs with just a few gigabytes of memory.

This model can be run as a HuggingFace pipeline. We use BitsAndBytes for GPU quantization, which reduces the memory requirements of running LLMs.

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

llama_tokenizer = AutoTokenizer.from_pretrained(
    base_model_name, trust_remote_code=True
)
llama_tokenizer.padding_side = "right"

# 4-bit quantization config to cut the memory footprint.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)
# Model
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map={"": 0},
)

text_gen = pipeline(
    task="text-generation",
    model=model,
    tokenizer=llama_tokenizer,
    max_new_tokens=256,
    do_sample=False,
    return_full_text=False,
)

print(text_gen("Hello! Who are you?"))

You should get something like this in the output:

[{'generated_text': ' I am a writer, a poet, a musician, a dancer, a painter, a sculptor, a filmmaker, a photographer, a cartoonist, a journalist, a teacher, a student, a lover, a friend, a stranger, a human being, a cat, a dog, a bird, a tree, a rock, a sandstone, a mineral, a fossil, a plant, a fungus, a bacterium, a virus, a microbe, a parasite, a symbiosis, a symphony, a symmetry, a chaos, a harmony, a balance, a balance of forces, a balance of energies, a balance of opposites, a balance of opposing forces, a balance of opposing principles, a balance of opposing ideas, a balance of opposing emotions, a balance of opposing thoughts, a balance of opposing desires, a balance of opposing needs, a balance of opposing needs, a balance of opposing desires, a balance of opposing emotions, a balance of opposing principles, a balance of opposing forces, a balance of opposing energies, a balance of opposing symb'}]

We can also fine-tune it using the HuggingFace libraries PEFT and TRL. PEFT stands for “Parameter-Efficient Fine-Tuning” and implements several types of low-rank adapter fine-tuning methods for LLMs. TRL stands for “Transformer Reinforcement Learning” and implements general fine-tuning workflows.
You can read all about it here: https://huggingface.co/docs/trl/main/en/lora_tuning_peft

The implementation uses QLoRA, an approach that fine-tunes the adapter weights of a quantized version of the full model. This allows us to run the training with around 3 GB of VRAM using a mini-batch size of 8, which makes it possible to run on most consumer-grade GPUs.

LoRA adds low-rank adapter weights that are trained while the backbone is kept frozen. It allows building specialized models that can be trained with a much smaller VRAM and disk-space footprint. In our case, the adapter weights are only 4.5 MB and contain around a million parameters.
Here is the pseudo-code that shows how it works; the full code is linked at the end of the post:

import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

if __name__ == "__main__":
    # ... data loading, tokenizer, collator, and quantized base-model setup
    # omitted (see the full code linked at the end of the post) ...

    peft_parameters = LoraConfig(
        lora_alpha=8,
        lora_dropout=0.1,
        r=8,
        bias="none",
        task_type="CAUSAL_LM",
        # target_modules=target_modules,
    )

    base_model = prepare_model_for_kbit_training(base_model)
    base_model = get_peft_model(base_model, peft_parameters)

    # Training Params
    train_params = TrainingArguments(
        output_dir=str(BASE_PATH / "results_modified"),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=1,
        optim="paged_adamw_32bit",
        save_steps=len(training_data) // 10,
        logging_steps=len(training_data) // 100,
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_steps=100,
        weight_decay=0.05,
        fp16=True,
        max_steps=-1,
        group_by_length=False,
        max_grad_norm=0.3,
    )
    # Trainer
    fine_tuning = SFTTrainer(
        model=base_model,
        train_dataset=training_data,
        data_collator=collator,
        peft_config=peft_parameters,
        dataset_text_field="text",  # assumed name of the column holding the formatted training text
        tokenizer=llama_tokenizer,
        args=train_params,
        max_seq_length=llama_tokenizer.model_max_length,
    )

    fine_tuning.model.print_trainable_parameters()
    # Training
    fine_tuning.train()
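Once training is done, the saved LoRA adapter can be attached back onto the quantized base model for inference. Here is a minimal sketch; the checkpoint path is hypothetical, so point it at whatever directory the trainer actually wrote, and make sure the inference prompt uses the same format as the fine-tuning data:

from peft import PeftModel
from transformers import pipeline

# Attach the trained LoRA adapter to the quantized base model loaded earlier.
# "results_modified/checkpoint-final" is a hypothetical path.
distilled_model = PeftModel.from_pretrained(model, "results_modified/checkpoint-final")

distilled_gen = pipeline(
    task="text-generation",
    model=distilled_model,
    tokenizer=llama_tokenizer,
    max_new_tokens=256,
    do_sample=False,
    return_full_text=False,
)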

The Results:

To judge whether this whole workflow works, we can look at a few outputs of the base Tiny-LLama versus the version distilled from LLama 2–70B’s outputs. So let’s see:

Example 1:

Corrupted input:
* We dont live in Australia Were just visiting
Base model output:
* We don’t live in Australia, We’re just visiting.
Distilled model output:
* We don’t live in Australia. We are just visiting.

Here the base model fixed some of the issues but messed up the punctuation.

Example 2:

Corrupted input:
* Je ai été surprise.
Base model output:
* I was surprised.
Distilled model output:
* J’ai été surprise.

Here the base model fixed the sentence but produced its output in English instead of the original French, while the distilled model fixed it in French.

We can also compute the fraction of cases where the output of the model matches the expected output exactly. This metric is flawed, as there can be multiple ways to fix a sentence (“It is extremely hard to do away with bad habit.” can be corrected as “It is extremely hard to do away with bad habits.” or “It is extremely hard to do away with a bad habit.”), but it can serve as a good proxy for the quality of generation. We get the following scores:

LLama 2–70B: 42%
Base Tiny-LLama: 11%
Distilled Tiny-LLama: 31%
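For reference, the exact-match score itself is trivial to compute. A minimal sketch, assuming predictions and references are parallel lists of strings:

def exact_match(predictions, references):
    # Fraction of model outputs that equal the reference correction exactly,
    # after stripping surrounding whitespace.
    matches = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return matches / len(references)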

While we are still far from the performance of the teacher model, we were able to significantly improve the performance of the student model from 11% to 31%. The gap from 31% to 42% could be bridged by using either a bigger distillation dataset or a bigger student model.

Conclusion:

By distilling knowledge from a high-capacity teacher model, such as LLama 2–70B, into a more compact student model like Tiny-LLama, we navigate the trade-off between computational efficiency and task-specific accuracy. The process involves crafting prompts, acquiring unlabeled in-domain data, and fine-tuning the student model on pseudo-labels generated by the teacher model. This approach mitigates the computational and deployment expenses associated with larger LLMs.

The implementation showcased here, focusing on multi-language grammatical error correction, underscores the practicality and effectiveness of knowledge distillation. Despite the laborious and time-consuming nature of data annotation, distillation techniques offer a scalable solution by automating the generation of labeled data through targeted prompting. Furthermore, advancements in model quantization and training methodologies, such as QLoRA and PEFT, make it possible to train specialized models on consumer-grade GPUs.

Evaluation results reveal a notable improvement in the performance of the student model, going from an 11% exact-match score to 31%, albeit still below the benchmark set by the teacher model at 42%. Nonetheless, this progress underscores the efficacy of distillation techniques in bridging the gap between computational efficiency and task-specific accuracy.

Code: https://github.com/CVxTz/distill-llm
