June 2023
tl;dr: LoRA (Low-Rank Adaptation) is an efficient fine-tuning technique that speeds up training of Large Language Models (LLMs) by updating only a small set of low-rank weight matrices.
LoRA adds a trainable low-rank update to a frozen pretrained weight matrix:

$W' = W + \alpha \, (W_A W_B)$

where $W_A \in \mathbb{R}^{d \times r}$ and $W_B \in \mathbb{R}^{r \times k}$ form the low-rank representation and $\alpha$ is the scale factor. We keep the original weight $W$ frozen and only train the new matrices $W_A$ and $W_B$.
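To make the efficiency concrete, here is a back-of-the-envelope count using the toy dimensions from the code below ($d = 128$, $k = 512$, $r = 4$): full fine-tuning of $W$ would update $128 \times 512 = 65{,}536$ parameters, whereas LoRA trains only $128 \times 4 + 4 \times 512 = 2{,}560$ parameters in $W_A$ and $W_B$, roughly a 25× reduction in trainable weights.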
import torch
import torch.nn as nn

embed_dim = 128    # input (embedding) dimension
output_dim = 512   # output dimension
rank = 4           # the rank 'r' of the low-rank matrices
scale_factor = 5   # alpha, scales the contribution of the LoRA update

W = ...  # pretrained weight from the network, shape (embed_dim, output_dim); kept frozen

W_A = nn.Parameter(torch.randn(embed_dim, rank))   # LoRA weight A, random init
W_B = nn.Parameter(torch.zeros(rank, output_dim))  # LoRA weight B, zero init so W_A @ W_B starts as a no-op

def lora_forward(x):
    hidden_layer = x @ W                              # the normal forward pass
    hidden_layer += scale_factor * (x @ (W_A @ W_B))  # add the scaled low-rank update
    return hidden_layer
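Below is a minimal training-step sketch showing that only W_A and W_B receive gradient updates. It assumes the definitions above; the random stand-in for W, the dummy batch and target, the MSE loss, and the AdamW optimizer are illustrative assumptions, not part of the original snippet.

W = torch.randn(embed_dim, output_dim)  # stand-in for the real pretrained weight; requires_grad stays False, so it is frozen

optimizer = torch.optim.AdamW([W_A, W_B], lr=1e-3)  # optimize only the LoRA matrices

x = torch.randn(8, embed_dim)        # dummy batch of 8 inputs
target = torch.randn(8, output_dim)  # dummy regression target

output = lora_forward(x)             # frozen W contribution plus the scaled low-rank update
loss = torch.nn.functional.mse_loss(output, target)
loss.backward()                      # gradients flow only into W_A and W_B
optimizer.step()
optimizer.zero_grad()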