A Deep Dive Into Low-Rank Adaptation (LoRA)


In the evolving landscape of large language models (LLMs) and their almost endless range of use cases, fine-tuning them efficiently and effectively remains a challenge, especially when both compute resources and data are limited. This is where Low-Rank Adaptation (LoRA) has proved to be a game-changer, making it possible to fine-tune very large LLMs for a specific task with limited resources and datasets.

LoRA introduces a seemingly simple yet powerful and cost-effective way to fine-tune LLMs and adapt them to a specific task by integrating low-rank matrices into the model’s architecture. Put simply, instead of retraining all of the model’s parameters on our data, we pick specific layers that we would like to adjust, represent the change to their weight matrices with much smaller matrices, and train only these smaller ones in a parameter-efficient way. This approach significantly reduces the number of trainable parameters while still enabling the base model to adapt to a specific task.

In this blog post, we will delve deep into how LoRA works under the hood, looking at its mathematical foundations by walking through a practical example. By the end of this post, you’ll see how useful and necessary LoRA is when it comes to fine-tuning large language models.

Understanding LoRA

So what are low-rank matrices about and how do they work? Let’s consider the linear equation:

y = Wx + b

Where W is the weight matrix (coefficients), b is the intercept (bias), and of course x is the input and y is the output.

Using LoRA, instead of updating the entire W, which can be very large, we introduce two new low-rank matrices A and B whose matrix product approximates the adaptation that we want to apply to W. Mathematically speaking, we want to achieve:

W’ (adapted) = W + AB

So now we can compute y by both leveraging the base model knowledge along with the adaptation to the task using the following equation:

y’ (adapted) = Wx + ABx + b
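
As a quick sanity check of the two equations above, the sketch below confirms numerically that adding the low-rank path ABx to the base output gives the same result as first folding AB into the weight. The sizes and random values are illustrative assumptions, not taken from any real model.

```python
import torch

torch.manual_seed(0)

d_out, d_in, r = 6, 8, 2          # illustrative sizes (assumption)
W = torch.randn(d_out, d_in)      # frozen base weight
b = torch.randn(d_out)            # frozen bias
A = torch.randn(d_out, r)         # low-rank factor, shape (d_out, r)
B = torch.randn(r, d_in)          # low-rank factor, shape (r, d_in)
x = torch.randn(d_in)

# Two-path form: base output plus the low-rank correction.
y_two_path = W @ x + A @ (B @ x) + b

# Merged form: fold the update into a single weight matrix first.
W_adapted = W + A @ B
y_merged = W_adapted @ x + b

same = torch.allclose(y_two_path, y_merged, atol=1e-4)
print("outputs match:", same)
print("rank of AB:", torch.linalg.matrix_rank(A @ B).item())
```

Note that AB has the same shape as W but a rank of at most r, which is what makes it cheap to store and train.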

So the main question here is: how can we get A and B that can really represent W?

Low-Rank Approximation with SVD

Singular Value Decomposition (SVD) is a technique that decomposes (factorizes) a matrix W into three other matrices:

W = USVt

Where U and Vt (V transposed) are orthogonal matrices and S is a diagonal matrix containing the singular values. For a low-rank approximation, we keep the r largest singular values and the corresponding columns of U and rows of Vt to obtain an approximation Wr of W. In mathematical notation:

Wr = UrSrVtr 
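
The properties described above can be checked directly. This is a minimal sketch that assumes a small random 4x4 matrix and uses `torch.linalg.svd`, which returns V already transposed (named `Vh` here).

```python
import torch

torch.manual_seed(0)
W = torch.rand(4, 4)

# full_matrices=False returns the "thin" SVD; Vh is V transposed.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

# U and Vh are orthogonal: multiplying by their transpose gives the identity.
I = torch.eye(4)
u_orthogonal = torch.allclose(U.T @ U, I, atol=1e-4)
v_orthogonal = torch.allclose(Vh @ Vh.T, I, atol=1e-4)

# Singular values come back sorted from largest to smallest.
sorted_desc = bool(torch.all(S[:-1] >= S[1:]))

# Keeping every singular value reproduces W.
W_full = U @ torch.diag(S) @ Vh
exact = torch.allclose(W_full, W, atol=1e-4)

print(u_orthogonal, v_orthogonal, sorted_desc, exact)
```

Because the singular values are sorted, truncating to the first r of them keeps the most important directions of W.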

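So how is this related to LoRA? Remember that we needed to define A and B. Note that the original LoRA method initializes one factor randomly and the other to zeros, so that the update starts at zero; in this post we take a different route and initialize them from the SVD. We can get A and B from the three low-rank matrices in the following way: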

A = Ur * SQRT(Sr) & B = SQRT(Sr) * Vtr

Let’s try to approximate a matrix using this logic in PyTorch.

import torch

# Original matrix
W = torch.rand(4, 4)
# Perform Singular Value Decomposition
# torch.linalg.svd returns V already transposed (Vh), unlike the deprecated torch.svd
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
# Rank r = 2 for low-rank approximation
# if we set r = 4, we reconstruct the same original matrix
r = 2
Ur = U[:, :r]
Sr = torch.diag(S[:r])
Vtr = Vh[:r, :]

# Initialize A and B using the SVD components
A = torch.mm(Ur, torch.sqrt(Sr))
B = torch.mm(torch.sqrt(Sr), Vtr) #V transpose

# Approximate W from A and B
W_approx = torch.mm(A, B)

print("Original W:\n", W)
print("Approximated W:\n", W_approx)

# Original W:
#  tensor([[0.0918, 0.4794, 0.8106, 0.0151],
#         [0.0153, 0.6036, 0.2318, 0.8633],
#         [0.9859, 0.1975, 0.0830, 0.4253],
#         [0.9149, 0.4799, 0.5348, 0.2695]])
# Approximated W:
#  tensor([[0.0374, 0.5079, 0.4906, 0.3878],
#         [0.0747, 0.5884, 0.5673, 0.4529],
#         [1.0059, 0.2161, 0.1760, 0.2814],
#         [0.8896, 0.4589, 0.4149, 0.4510]])

As you can see, with r = 2 the low-rank A and B already give a decent approximation of the original matrix. If we set r = 4 instead, we recover the original W exactly. Try it!
So instead of updating W, which can be very large, we use A and B as its approximate representation. Of course, this comes at a cost: somewhat lower accuracy.
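
To make the accuracy cost concrete, the following sketch measures the reconstruction error for every rank from 1 to 4 on a random 4x4 matrix. The helper name `low_rank_factors` is made up for this illustration. For a truncated SVD, the Frobenius error equals the norm of the singular values that were dropped, which the loop also prints for comparison.

```python
import torch

torch.manual_seed(0)
W = torch.rand(4, 4)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

def low_rank_factors(U, S, Vh, r):
    """Split the top-r SVD components into A (n x r) and B (r x m)."""
    sqrt_S = torch.diag(torch.sqrt(S[:r]))
    return U[:, :r] @ sqrt_S, sqrt_S @ Vh[:r, :]

errors, dropped_norms = [], []
for r in range(1, 5):
    A, B = low_rank_factors(U, S, Vh, r)
    err = torch.linalg.norm(W - A @ B).item()            # Frobenius norm of the residual
    dropped = torch.sqrt(torch.sum(S[r:] ** 2)).item()   # norm of discarded singular values
    errors.append(err)
    dropped_norms.append(dropped)
    print(f"r={r}: error={err:.4f}, dropped singular values norm={dropped:.4f}")
```

The error shrinks as r grows and vanishes (up to floating-point noise) at full rank.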

Great! Now, what value should we choose for r? We know it should be smaller than the dimensions of the original W, but how much smaller?

LoRA Rank and Alpha

The rank (r) sets the inner dimension of the matrices A and B, which essentially controls the complexity of the adaptation. In a few words, it determines how closely the low-rank matrices can match the original matrix: the higher r is, the closer we get to the original matrix and the higher the accuracy, but the slower the fine-tuning and the more resources we need. The opposite holds for a smaller r.

So how do we choose r? It depends on the specifics of your task and how you want to adapt the base model to it. If the base model already performs very well and you only want to adjust it slightly by teaching it some fine-grained details, a small r (in the range of 4 to 16) is a better choice. If instead you want to teach the base model a somewhat new domain, which calls for a more complex adaptation, you may need a higher r (32 to 128 or more). Be aware that a high r combined with too little data may lead to overfitting.
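
To get a feel for the resource side of this trade-off, here is a small sketch of how the number of trainable parameters grows with r. The 4096 x 4096 layer size is a hypothetical choice for illustration only, and `lora_trainable_params` is a helper made up for this example.

```python
def lora_trainable_params(d_out: int, d_in: int, r: int) -> int:
    """Parameters in A (d_out x r) plus B (r x d_in)."""
    return r * (d_out + d_in)

d_out = d_in = 4096  # hypothetical layer size, chosen only for illustration
full = d_out * d_in  # parameters in the full weight matrix

rows = []
for r in (4, 16, 32, 128):
    n = lora_trainable_params(d_out, d_in, r)
    rows.append((r, n, 100 * n / full))
    print(f"r={r:>3}: {n:>9,} trainable params ({100 * n / full:.2f}% of the full matrix)")
```

The count grows linearly with r, so even r = 128 trains only a small fraction of what a full update of the same layer would.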

There is another important parameter in LoRA: alpha. Alpha is a scaling factor applied to the product AB that determines the magnitude of the adaptation in the final prediction.
Let’s see the effect of alpha mathematically. Remember the equation to predict y’:

y’ (adapted) = Wx + ABx + b

So to add alpha we will have:

y’ (adapted) = Wx + (alpha * ABx) + b

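As you can see, the higher alpha is, the more effect AB has on the final prediction. Alpha helps balance retaining the pre-trained knowledge of the base model (when alpha is small) against adapting to the new task (when alpha is larger). Note that the original LoRA formulation, and libraries such as PEFT, scale the update by alpha / r rather than by alpha alone; we use the simpler form here.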
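
The effect of alpha can be seen with a few random tensors. This is only a sketch under assumed shapes and values; `adapted_output` is a name made up for this illustration and follows the equation y’ = Wx + (alpha * ABx) + b.

```python
import torch

torch.manual_seed(0)
W, b = torch.randn(3, 5), torch.randn(3)     # frozen base weight and bias
A, B = torch.randn(3, 2), torch.randn(2, 5)  # low-rank factors with r = 2
x = torch.randn(5)

def adapted_output(alpha: float) -> torch.Tensor:
    """y' = Wx + alpha * ABx + b."""
    return W @ x + alpha * (A @ (B @ x)) + b

base = W @ x + b
shifts = {}
for alpha in (0.0, 0.5, 1.0, 2.0):
    shifts[alpha] = torch.linalg.norm(adapted_output(alpha) - base).item()
    print(f"alpha={alpha}: distance from base output = {shifts[alpha]:.4f}")
```

With alpha = 0 the base model’s output is untouched, and the distance from the base output grows in proportion to alpha.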

Practical Fine-tuning Example with LoRA

Now that we know the ins and outs of LoRA, let’s walk through a simple neural network example to see how it works.
We will build a PyTorch model and train it on MNIST data. The trick is that we will train the base model only on handwritten images of the digits 0 to 7, leaving out 8 and 9 (as our task-specific data). Then we will take the base model, which never saw 8 and 9, fine-tune some of its layers using LoRA, and see if it learns to classify them.

Let’s first get the data:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, Subset
from copy import deepcopy

# Load and preprocess the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Keep only the digits 0 to 7 (i.e. filter out 8 and 9)
train_idx = mnist_train.targets < 8
test_idx = mnist_test.targets < 8

# Note: indexing .data directly bypasses `transform`, so these tensors hold
# raw pixel values in [0, 255] with shape (N, 28, 28)
mnist_Xtrain_0to7 = mnist_train.data[train_idx].float()
mnist_ytrain_0to7 = mnist_train.targets[train_idx]

mnist_Xtest_0to7 = mnist_test.data[test_idx].float()
mnist_ytest_0to7 = mnist_test.targets[test_idx]

# A small Dataset wrapper to pass to the DataLoader
# (mnist_train itself could be passed, but we split the data into several parts)
class MNISTDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return [self.X[idx], self.Y[idx]]

mnist_train_0to7 = MNISTDataset(mnist_Xtrain_0to7, mnist_ytrain_0to7)
mnist_test_0to7 = MNISTDataset(mnist_Xtest_0to7, mnist_ytest_0to7)
train_loader = DataLoader(mnist_train_0to7, batch_size=64, shuffle=True)
test_loader = DataLoader(mnist_test_0to7, batch_size=1000, shuffle=False)

Now let’s build the base model and train it

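class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),  # (N, 28, 28) -> (N, 784)
            nn.Linear(28 * 28, 512),
            nn.ReLU(),  # without non-linearities the stacked layers collapse into one linear map
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10),  # Output layer for 10 classes
        )

    def forward(self, x):
        return self.layers(x)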

model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

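# Train
for epoch in range(5):  # Train for 5 epochs for demonstration
    for data, target in train_loader:
        optimizer.zero_grad()  # reset gradients from the previous step
        output = model(data)
        loss = criterion(output, target)
        loss.backward()  # backpropagate
        optimizer.step()  # update the weights
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")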

# Epoch 1, Loss: 0.017020845785737038
# Epoch 2, Loss: 0.016323702409863472
# Epoch 3, Loss: 0.012023240327835083
# Epoch 4, Loss: 0.0017469731392338872
# Epoch 5, Loss: 5.717848034691997e-05

Now let’s evaluate the model on the test data.

with torch.no_grad():  # no gradients needed for evaluation
    y_hat = model(mnist_Xtest_0to7)
print("accuracy:", (torch.argmax(y_hat, dim=1) == mnist_ytest_0to7).sum() / mnist_ytest_0to7.size(0))

# accuracy: tensor(0.9808)

Great! We have a nice base model. Now let’s define the data for 8 and 9 digits and test the base model on them to see how it performs.

# The data for digits 8 and 9
train_idx_89 = mnist_train.targets >= 8
test_idx_89 = mnist_test.targets >= 8

mnist_Xtrain_89 = mnist_train.data[train_idx_89].float()
mnist_ytrain_89 = mnist_train.targets[train_idx_89]

mnist_Xtest_89 = mnist_test.data[test_idx_89].float()
mnist_ytest_89 = mnist_test.targets[test_idx_89]

mnist_train_89 = MNISTDataset(mnist_Xtrain_89, mnist_ytrain_89)
mnist_test_89 = MNISTDataset(mnist_Xtest_89, mnist_ytest_89)
train89_loader = DataLoader(mnist_train_89, batch_size=64, shuffle=True)
test89_loader = DataLoader(mnist_test_89, batch_size=1000, shuffle=False)

with torch.no_grad():  # no gradients needed for evaluation
    y_hat = model(mnist_Xtest_89)
print("accuracy:", (torch.argmax(y_hat, dim=1) == mnist_ytest_89).sum() / mnist_ytest_89.size(0))

# accuracy: tensor(0.)

Cool! As expected, the model cannot predict 8 or 9. Now we will define our LoRA model; for simplicity and demonstration, it applies the low-rank adaptation only to the last layer.

class LoRAMLP(nn.Module):
    def __init__(self, base_model, rank=16, alpha=1.0):
        super(LoRAMLP, self).__init__()
        self.base_model = base_model
        self.rank = rank
        self.alpha = alpha

        # Identify the layer to adapt (the last layer)
        layer_to_adapt = base_model.layers[-1]

        # Freeze the parameters of the base model
        for param in self.base_model.parameters():
            param.requires_grad = False

        # Perform SVD on the weight matrix of the layer to adapt
        # (torch.linalg.svd returns V already transposed as Vh)
        U, S, Vh = torch.linalg.svd(layer_to_adapt.weight.detach(), full_matrices=False)

        # Initialize A and B using the top-r singular values and vectors
        sqrt_S = torch.diag(S[:rank].sqrt())
        self.A = nn.Parameter(U[:, :rank] @ sqrt_S)
        self.B = nn.Parameter(sqrt_S @ Vh[:rank, :])

    def forward(self, x):
        x = self.base_model.layers[:-1](x)  # Use all layers except the last one
        # Apply the LoRA modification and scale by alpha
        Wx = torch.matmul(x, self.base_model.layers[-1].weight.t())
        AB = (self.A @ self.B)
        ABx = self.alpha * torch.matmul(x, AB.t())
        adapted_y = Wx + ABx + self.base_model.layers[-1].bias
        return adapted_y

# Initialize the LoRA-adapted model
lora_model = LoRAMLP(deepcopy(model), rank=4, alpha=1)
# Pass only the trainable LoRA parameters (A and B) to the optimizer
optimizer = optim.Adam([p for p in lora_model.parameters() if p.requires_grad], lr=5e-3)

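# Fine-tune on digits 8 and 9
for epoch in range(3):  # Fine-tune for 3 epochs
    for data, target in train89_loader:
        optimizer.zero_grad()  # reset gradients from the previous step
        output = lora_model(data)
        loss = criterion(output, target)  # you can use (target - 8) to adjust targets for 0-1 range
        loss.backward()  # gradients flow only into A and B; the base model is frozen
        optimizer.step()
    print(f"Fine-tuning Epoch {epoch+1}, Loss: {loss.item()}")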

# Fine-tuning Epoch 1, Loss: 0.18958234786987305
# Fine-tuning Epoch 2, Loss: 0.12145351618528366
# Fine-tuning Epoch 3, Loss: 0.10071035474538803

Let’s now test it on the data for the digits 8 and 9:

with torch.no_grad():  # no gradients needed for evaluation
    lora_y_hat = lora_model(mnist_Xtest_89)
print("accuracy:", (torch.argmax(lora_y_hat, dim=1) == mnist_ytest_89).sum() / mnist_ytest_89.size(0))

# accuracy: tensor(0.9375)

Awesome! We have successfully fine-tuned the model on our specific task, and we did so in a parameter-efficient way with LoRA. Finally, let’s look at how much we saved by comparing the number of trainable parameters in the base model and in the LoRA model.

base_model_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
lora_trainable_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
print(f"original trainable parameters count is: {base_model_trainable_params} which is 100%")
print(f"trainable parameters count with LoRA is: {lora_trainable_params} "
      f"which is {(lora_trainable_params / base_model_trainable_params) * 100:.2f}%")

# original trainable parameters count is: 535818 which is 100%
# trainable parameters count with LoRA is: 1064 which is 0.20%

We reduced the number of trainable parameters significantly while still having a very good fine-tuned model.
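
The two counts above can also be derived by hand from the layer shapes, which makes it clear where the saving comes from. This sketch uses only the sizes of the MLP and the rank-4 adapter from this post; `linear_params` is a helper made up for the illustration.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return d_in * d_out + d_out

# Layer sizes of the MLP defined earlier in the post.
base_total = (
    linear_params(28 * 28, 512)
    + linear_params(512, 256)
    + linear_params(256, 10)
)

# LoRA factors for the last layer (weight shape 10 x 256) with rank 4.
rank, d_out, d_in = 4, 10, 256
lora_total = d_out * rank + rank * d_in  # A is 10 x 4, B is 4 x 256

print(base_total, lora_total, f"{100 * lora_total / base_total:.2f}%")
# prints: 535818 1064 0.20%
```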

LoRA in Fine-tuning LLMs

The same technique is applied to LLMs, but across many different layers, which would be very difficult to do manually. Luckily, Hugging Face provides the PEFT library, which makes LoRA adaptation straightforward. We will look at a full use case in detail in upcoming blog posts. As a quick illustration, here’s how we can do it with peft.

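# base_model, data, max_seq_len and formatting_func are assumed to be defined elsewhere
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training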
from transformers import TrainingArguments
from trl import SFTTrainer

lora_config = LoraConfig(r = 32, lora_alpha = 32,
                        lora_dropout = 0.05, bias = "none",
                        task_type = "CAUSAL_LM",
                        target_modules = ["q_proj", "v_proj", "k_proj"])
model = prepare_model_for_kbit_training(base_model)
model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
   per_device_train_batch_size = 3,
   warmup_ratio = 0.03,  # fraction of the total training steps used for warm-up
   fp16 = not torch.cuda.is_bf16_supported(),
   bf16 = torch.cuda.is_bf16_supported(),
   learning_rate = 5e-4,
   weight_decay = 0.1,
   lr_scheduler_type = "cosine",
   seed = 1311,
   output_dir = "outputs",
)

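# Depending on your trl version, max_seq_length and packing may need to be set via SFTConfig instead
trainer = SFTTrainer(
   model = model,
   train_dataset = data,
   eval_dataset = data,
   max_seq_length = max_seq_len,
   args = training_args,
   packing = True,
   formatting_func = formatting_func,
)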

trainer_stats = trainer.train()


