Training LLMs Taking Too Much Time? The Technique You Need to Know to Train Them Faster

Training and using LLMs can be a hefty task. By using this technique, you can achieve more output with less time and fewer resources.


The Challenges of Training LLMs: Lots of Time and Resources

Suppose you want to train a Large Language Model (LLM) that can understand and produce human-like text. You want to ask it questions related to your organization and get answers from it.

The problem is that the LLM doesn't know your organization; it only knows general things. That is where techniques like finetuning, RAG, and many others come in.

Training big LLMs requires a lot of resources and time, so it's a hefty task unless you have the proper machine for the job.

Story of How We Solved The Problem of Time and Resources

Suppose we want to train the Llama 2 LLM on information about our organization, and we are using Google Colab to train it. The free version of Colab provides a single Nvidia T4 GPU with 16GB of memory.

But training the Llama 2 7-billion-parameter model requires about 28GB of memory. This is a problem: we can't train the model with only 16GB.

To solve this, we researched optimization techniques and found LoRA, which stands for Low-Rank Adaptation of Large Language Models.

LoRA adds a small layer of fine-tuned weights on top of the model without modifying the existing weights, which consumes far less time and memory.

By using LoRA, I was able to finetune the Llama 2 model and get outputs from it on a single T4 GPU.

Refer to the above image. I asked the Llama 2 model, before finetuning, the question "How many servers does Hexmos have?" It replied that it was unable to provide that information.

After finetuning, I asked the same question, and it gave me this reply:
Hexmos has 2 servers in Azure and 4 servers in AWS

Let's see how LoRA helped me achieve this.

How LoRA Helps with Finetuning More Efficiently

Let's take a deeper dive into how LoRA works.

Large models like GPT-3 have as many as 175 billion parameters. Parameters are numbers stored in matrices; they are like the knobs and dials the model tweaks to get better at its task. Fully finetuning all of them to our needs is a daunting task and requires a lot of computational resources.

LoRA takes a different approach to this problem: instead of fine-tuning the entire model, it focuses on training a much smaller set of parameters.


Consider the above two boxes. One represents the weights of the existing model; the other represents our fine-tuned weights (based on our custom dataset). These are added together to form our fine-tuned model.

With this method, we don't need to change the existing weights in the model. Instead, we add our fine-tuned weights on top of the original weights (in effect, W_finetuned = W_original + ΔW), which makes the process far less computationally expensive.

Another question may arise: how are these fine-tuned weights calculated?
In matrices, we have a concept called rank.

Rank, in simple terms, determines the precision of the model after finetuning. If the rank is low, there is more optimization (fewer values to train), but you sacrifice some of the model's accuracy.

If the rank is high, the precision is higher, but there is less optimization.

The LoRA weight matrix is calculated by multiplying two much smaller matrices.

For example, multiplying a 5x1 matrix by a 1x5 matrix produces a 5x5 LoRA weight matrix, while storing only 10 values instead of 25.

We can set the rank of these smaller matrices to strike the balance between precision and optimization.
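Here is a minimal sketch of this idea in plain Python with NumPy (the sizes are illustrative; real model layers are much larger):

import numpy as np

d = 5  # layer dimension (illustrative; real layers are in the thousands)
r = 1  # LoRA rank: smaller rank means fewer trainable values

B = np.random.randn(d, r)  # 5x1 matrix
A = np.random.randn(r, d)  # 1x5 matrix

delta_W = B @ A  # 5x5 LoRA weight matrix built from the two small matrices

# We store and train only 2*d*r = 10 values instead of d*d = 25
print(delta_W.shape)    # (5, 5)
print(B.size + A.size)  # 10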


Real-Life Example: Training a Llama 2 Model with a Custom Dataset

You can refer to the Google Colab notebook for training a Llama 2 model on your own.

Here I will give a brief overview of how we can train the LLM on a custom dataset using LoRA.

Step 1: Installing libraries

To train our Llama 2 model, we will be using the following libraries (a sample install command follows the list):

  • accelerate: used to reduce the training time
  • peft: provides the LoRA technique
  • bitsandbytes: used for achieving smaller model sizes and faster processing speeds
  • transformers: a library designed for working with transformer models
  • trl: a library for training Large Language Models (LLMs)
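For instance, they can be installed in Colab with a command like the following (version pins omitted; pick versions compatible with your environment):

# Install the required libraries in Colab
!pip install -q accelerate peft bitsandbytes transformers trl
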
Step 2: Prepare our dataset

We need a dataset that matches the Llama 2 format.

This is what the Llama 2 format looks like:

<s>[INST] <<SYS>>
System prompt
<</SYS>>

User prompt [/INST] Model answer

For example, if we want to ask the model about the capital of France:

<s>[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>

What is the capital of France?[/INST]
The capital of France is Paris </s>

Here is the Llama 2 data format in more detail (a small formatting helper follows the list):

  • Start token (<s>): marks the beginning of each training instance.
  • Instruction token ([INST]): indicates the start of a specific instruction or prompt.
  • System prompt (<<SYS>> ... <</SYS>>): the system's initial instructions, wrapped between the <<SYS>> and <</SYS>> tags.
  • User prompt: the user's question or message, terminated by the [/INST] token.
  • Model answer: the output the LLM should generate for that prompt.
  • End token (</s>): marks the end of each training instance.
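As an illustration, here is a small helper (hypothetical, not part of the notebook) that assembles one training instance in this format:

def format_llama2_example(system_prompt, user_prompt, model_answer):
    # Build one training instance in the Llama 2 chat format
    return (
        f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
        f"{user_prompt} [/INST] {model_answer} </s>"
    )

print(format_llama2_example(
    "You are a helpful assistant.",
    "What is the capital of France?",
    "The capital of France is Paris",
))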

To try out finetuning, you can use the guanaco-llama2-1k dataset, which already follows the Llama 2 format, or you can make your own by referring to it.

In our case, we had to make a dataset of our own.

For reference, this is a brief idea of what our dataset looks like, which I have made in a .jsonl file (a sketch for generating such a file follows the examples):

{"text": "<s>[INST] How many servers does Hexmos have? [/INST] Hexmos has 2 Servers in Azure and 4 Servers in AWS </s>"}
{"text": "<s>[INST] What is the total server count for Hexmos? [/INST] Hexmos has 2 Servers in Azure and 4 Servers in AWS </s>"}
{"text": "<s>[INST] Can you tell me how many servers Hexmos manages in total? [/INST] Hexmos has 2 Servers in Azure and 4 Servers in AWS </s>"}
{"text": "<s>[INST] I'm curious about the number of servers Hexmos has deployed. Can you share that information? [/INST] Hexmos has 2 Servers in Azure and 4 Servers in AWS </s>"}
{"text": "<s>[INST] If I consider both Azure and AWS, how many servers does Hexmos have in total? [/INST] Hexmos has 2 Servers in Azure and 4 Servers in AWS </s>"}
{"text": "<s>[INST] Including cloud and on-premise servers, what is the total server count for Hexmos? [/INST] Hexmos has 2 Servers in Azure and 4 Servers in AWS </s>"}
{"text": "<s>[INST] How many servers does Hexmos have in Azure? [/INST] Hexmos has 2 Servers in Azure and 4 Servers in AWS </s>"}
{"text": "<s>[INST] Can you tell me the number of servers Hexmos has deployed in AWS? [/INST] Hexmos has 2 Servers in Azure and 4 Servers in AWS </s>"}
{"text": "<s>[INST] What is Hexmos' server capacity? [/INST] Hexmos has 2 Servers in Azure and 4 Servers in AWS </s>"}
{"text": "<s>[INST] How extensive is Hexmos' server infrastructure? [/INST] Hexmos has 2 Servers in Azure and 4 Servers in AWS </s>"}
Step 3: Loading the model and setting configuration

Here we define the model names:

  • We will be using the Llama-2-7b-chat-hf model
  • After finetuning, we will name it Llama-2-7b-chat-finetune
# The model that you want to train from the Hugging Face hub
model_name = "NousResearch/Llama-2-7b-chat-hf"
# Fine-tuned model name
new_model = "Llama-2-7b-chat-finetune"

After that, we set the LoRA parameters

lora_r

Denotes the rank of the low-rank matrices that make up the LoRA weights.

Set a lower value if you want the model to learn general patterns from the data, and a higher value if you want the model to absorb new information from the dataset.

For example, for training with a dataset like guanaco, a low rank (below 32) can be used. The aim of the guanaco dataset is to teach the model how to behave as an assistant, so we only need the manner of speaking, not the specifics of the words used.

In my case, however, I want to teach the model about our organization, so using a rank above 32 is better.

I have set the rank to 64. You can experiment with this value and see what works best.

lora_alpha

It defines the balance between using the knowledge in the pre-trained model and adapting it to the new task; the LoRA update is effectively scaled by alpha/rank. A general rule of thumb is to set alpha equal to the rank, so we have set it to 64 here.

lora_dropout

To understand the LoRA dropout rate, consider this example. Suppose you have 100 neurons in a layer of your machine-learning model, and all 100 neurons train on a specific dataset. This can lead to overfitting: the model becomes too reliant on the training data, and if it is given something outside that dataset, it won't perform well.

To solve this problem, we introduce a concept called dropout. With a dropout rate of 10%, during training 10% of the neurons are randomly dropped. The remaining 90 neurons train on the data, which forces the model to generalize instead of memorizing.

We have set the dropout rate to 0.1 (10%).

# LoRA attention dimension
lora_r = 64

# Alpha parameter for LoRA scaling
lora_alpha = 64

# Dropout probability for LoRA layers
lora_dropout = 0.1

Usually, models store their weights as 32-bit floating-point numbers. This allows for a high degree of accuracy but consumes a significant amount of memory.

To address this, we can use 4-bit precision. This technique reduces the number of bits used to represent each value from 32 to 4, which leads to a drastic reduction in memory.
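A quick back-of-the-envelope calculation for our 7-billion-parameter model shows why this matters:

# Approximate memory just to hold 7 billion weights at different precisions
params = 7_000_000_000
print(f"32-bit: {params * 4 / 1e9:.1f} GB")    # 4 bytes each   -> 28.0 GB
print(f"16-bit: {params * 2 / 1e9:.1f} GB")    # 2 bytes each   -> 14.0 GB
print(f" 4-bit: {params * 0.5 / 1e9:.1f} GB")  # 0.5 bytes each -> 3.5 GB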

# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

These are some of the important pieces of code you'll need to know; the rest are provided in the notebook.

Step 4: Start the Finetuning process

Since we have the model ready, let's start loading the dataset with the configuration we defined earlier.

# Imports used in the following steps
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline
from trl import SFTTrainer

# Load dataset (you can process it here)

# Guanaco dataset
# dataset_name = "mlabonne/guanaco-llama2-1k"
# dataset = load_dataset(dataset_name, split="train")

# Your dataset
dataset_name = "loracustomdataset.jsonl"
dataset = Dataset.from_json(dataset_name)

# Load tokenizer and model with QLoRA configuration
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

We will load the base model from Hugging Face using AutoModelForCausalLM with our configuration defined earlier.

We use AutoModelForCausalLM for two reasons:

  • Causal language model: when you push the first domino, it topples and knocks over the next one, causing a chain reaction. This cause-and-effect relationship is what we call causality. Text generation works the same way: the model generates text sequentially, and each word depends on the words that came before it.
  • Automatic loading capability: we can simply pass the model name to AutoModelForCausalLM and it will load everything up for us.
# Load base model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)
model.config.use_cache = False   # no need to cache key/value pairs during training
model.config.pretraining_tp = 1  # use the standard (non tensor-parallel) linear layers

Every LLM needs a tokenizer; it's responsible for converting our input text into tokens that the model can process.
So we'll load the tokenizer and the LoRA configuration.

# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

# Load LoRA configuration
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)
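The SFTTrainer we use later applies this configuration for us, but if you're curious how few weights LoRA actually trains, you can wrap the model yourself and inspect it (purely optional, for illustration):

from peft import get_peft_model

# Wrap the base model with LoRA adapters and report the trainable fraction
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()  # prints trainable params vs. total params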

We have set everything up, now we can start training our model.
We will define the training arguments first.

# Set training parameters
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="tensorboard"
)
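These variables come from earlier cells in the notebook. For reference, a typical set of values for a run like this might look like the following (these particular numbers are assumptions; tune them for your own dataset):

# Illustrative values; the actual ones are defined earlier in the notebook
output_dir = "./results"
num_train_epochs = 1
per_device_train_batch_size = 4
gradient_accumulation_steps = 1
optim = "paged_adamw_32bit"   # memory-efficient optimizer from bitsandbytes
save_steps = 0
logging_steps = 25
learning_rate = 2e-4
weight_decay = 0.001
fp16 = False
bf16 = False
max_grad_norm = 0.3
max_steps = -1                # -1 means train for the full number of epochs
warmup_ratio = 0.03
group_by_length = True        # batch similar-length sequences to save memory
lr_scheduler_type = "cosine"
max_seq_length = None         # used by SFTTrainer below
packing = False               # don't pack multiple short examples into one sequence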

Using these arguments we will start our supervised fine-tuning.

# Set supervised fine-tuning parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)

# Train model
trainer.train()

After the training is done, we can save the fine-tuned model:

# Save trained model
trainer.model.save_pretrained(new_model)
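This saves only the small LoRA adapter weights, not a full standalone model. To obtain a merged model that you can load later without the adapter machinery, a sketch like this should work (reloading the base model in half precision to fit in memory):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Reload the base model and fold the LoRA adapter weights into it
base_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map=device_map
)
merged_model = PeftModel.from_pretrained(base_model, new_model)
merged_model = merged_model.merge_and_unload()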

Now let's run a query through our trained model and see the output.

Here I have run the prompt, "How many servers are there in Hexmos?"

# Run text generation pipeline with our fine-tuned model
prompt = "How many servers are there in Hexmos?"
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)
result = pipe(f"<s>[INST] {prompt} [/INST]")
print(result[0]['generated_text'])

Here we get the following result from the model:

[Image: LoRA output from the fine-tuned model]

I have tried prompting in different ways as well, and it gives the relevant answer each time.

[Images: model responses to alternative phrasings of the question]

Conclusion

As you saw throughout the article, LoRA allows us to harness the power of large language models and train them on our custom dataset with minimal resources. The range of opportunities this opens up is amazing, from building a finetuned LLM for your own use case to building your own LLM product. Although optimizations like these will never quite match full fine-tuning, they are a great benefit to those with limited hardware, making this method accessible to everyone.
