Parameter-Efficient Fine-Tuning (PEFT)

Three years in Singapore and still quite fascinated by Singlish… Then one random day, I gave myself a challenge: what if I fine-tuned a language model to translate English into Singlish? Just to kill time on a Sunday, lah 😉

Theory time: base models are essentially intelligent token tumblers, or next-token generators. Out of the box, they are not good at specific tasks such as Q&A, reasoning, or translation. Fine-tuning is the process of adapting a pre-trained model to a specific task by training it further on a task-specific dataset. One fine-tuning technique is Parameter-Efficient Fine-Tuning (PEFT), which lets you fine-tune large language models (LLMs) by training only a small number of parameters.

The task at hand is to translate English sentences into Singlish. The plan is to:

  • load a quantized model using BitsAndBytes

  • configure low-rank adapters (LoRA) using Hugging Face’s peft

  • load and format a dataset

  • fine-tune the model using the supervised fine-tuning trainer (SFTTrainer) from Hugging Face’s trl library

  • use the fine-tuned model to generate a few sentences

🔘 Before we get going, let’s brush up on some quick theory 🔘

LoRA (Low-Rank Adaptation)

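LoRA attaches small trainable low-rank layers to existing layers so the model can learn task-specific patterns without modifying the original weights. For a frozen weight matrix W, it learns two thin matrices, A (r × in_features) and B (out_features × r), where the rank r is much smaller than the layer’s dimensions. The adapted layer then computes W·x + (lora_alpha / r)·B·A·x.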
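The update rule above fits in a few lines. Below is a minimal NumPy sketch of a LoRA-wrapped linear layer; the class name LoRALinear and the random initialization scale of A are illustrative assumptions, not PEFT’s actual implementation. Like PEFT’s default, it starts with B set to zero, so a freshly attached adapter leaves the layer’s output unchanged.

```python
import numpy as np


class LoRALinear:
    """Illustrative LoRA wrapper around a frozen linear layer (not PEFT's code)."""

    def __init__(self, weight, r=8, lora_alpha=16, seed=0):
        rng = np.random.default_rng(seed)
        out_features, in_features = weight.shape
        self.weight = weight                               # frozen base weight W, shape (out, in)
        # A projects the input down to rank r, B projects it back up
        self.lora_A = rng.normal(0.0, 0.01, size=(r, in_features))
        self.lora_B = np.zeros((out_features, r))          # zero init: adapter starts as a no-op
        self.scaling = lora_alpha / r                      # same ratio PEFT uses

    def __call__(self, x):
        base = x @ self.weight.T                           # frozen path
        update = ((x @ self.lora_A.T) @ self.lora_B.T) * self.scaling  # trainable low-rank path
        return base + update

    def trainable_parameters(self):
        # Only A and B would receive gradients; W stays frozen
        return self.lora_A.size + self.lora_B.size


# Usage with the shape of Phi-3's o_proj layer (3072 x 3072) and r=8
W = np.random.default_rng(1).normal(size=(3072, 3072))
layer = LoRALinear(W, r=8, lora_alpha=16)
x = np.ones((2, 3072))
print(np.allclose(layer(x), x @ W.T))          # True: B starts at zero
print(layer.trainable_parameters(), W.size)    # 49152 9437184
```

With r=8, this single adapter adds 49,152 trainable parameters next to 9,437,184 frozen ones.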

PEFT (Parameter-Efficient Fine-Tuning)

  • Freezes the pretrained model weights and updates only a small set of extra parameters (here, the LoRA adapters).

  • Why?

    • Reduces memory usage and computational costs while largely retaining model performance

    • Preserves the model’s broad knowledge base, since the original weights stay untouched

!pip install transformers==4.46.2 peft==0.13.2 accelerate==1.1.1 trl==0.12.1 bitsandbytes==0.45.2 datasets==3.1.0 huggingface-hub==0.26.2 safetensors==0.4.5 pandas==2.2.2 matplotlib==3.8.0 numpy==1.26.4
import os
import torch
from datasets import load_dataset, Dataset
# PEFT (Parameter-Efficient Fine-Tuning)
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

Loading a Quantized Base Model

We start by loading a quantized model to reduce its memory footprint on the GPU. A quantized model replaces the original weights with approximate values represented using fewer bits. (🚧 TODO: need to learn quantization in depth and document in a separate notebook 🚧) A simple and effective option is to convert the weights from 32-bit floating-point numbers (FP32) to the 4-bit NormalFloat (NF4) data type. This change alone can reduce the memory taken by the quantized weights by roughly a factor of eight (32 bits down to 4).
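A quick back-of-the-envelope check of that factor of eight, counting only the weights and assuming 3.8 billion parameters (the size quoted for Phi-3-mini below). It ignores the small per-block scaling constants that NF4 also stores, as well as any layers left unquantized, so real footprints come out somewhat higher.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory taken by the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


num_params = 3.8e9  # assumption: roughly Phi-3-mini's parameter count
fp32_gb = weight_memory_gb(num_params, 32)
nf4_gb = weight_memory_gb(num_params, 4)
print(f"FP32: {fp32_gb:.1f} GB, NF4: {nf4_gb:.1f} GB, ratio: {fp32_gb / nf4_gb:.0f}x")
# FP32: 15.2 GB, NF4: 1.9 GB, ratio: 8x
```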

We can use an instance of BitsAndBytesConfig as the quantization_config argument while loading a model using the from_pretrained() method.

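bnb_config = BitsAndBytesConfig(
   load_in_4bit=True,                     # store the linear layers' weights in 4 bits
   bnb_4bit_quant_type="nf4",             # use the 4-bit NormalFloat data type
   bnb_4bit_use_double_quant=True,        # also quantize the quantization constants to save a bit more memory
   bnb_4bit_compute_dtype=torch.float32   # dtype used for the actual computations
)
repo_id = 'microsoft/Phi-3-mini-4k-instruct'
model = AutoModelForCausalLM.from_pretrained(repo_id,
            device_map="cuda:0",          # load the whole model onto the first GPU
            quantization_config=bnb_config
)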

"The Phi-3-Mini-4K-Instruct is a 3.8B parameters, lightweight, state-of-the-art open model trained with the Phi-3 datasets that includes both synthetic data and the filtered publicly available websites data with a focus on high-quality and reasoning dense properties. The model belongs to the Phi-3 family with the Mini version in two variants 4K and 128K which is the context length (in tokens) that it can support."
Source: Hugging Face Hub

Let’s check how much memory it occupies using the get_memory_footprint() method (the value below is in megabytes):

print(model.get_memory_footprint()/1e6)
2206.347264

Even after being quantized, the model still takes up a bit more than 2 gigabytes of GPU memory. The quantization procedure targets the linear layers within the Transformer decoder blocks (also referred to as “layers” in some cases), while the embedding layer (embed_tokens) and the output head (lm_head) stay unquantized:

model
Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3SdpaAttention(
          (o_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear4bit(in_features=3072, out_features=9216, bias=False)
          (rotary_emb): Phi3RotaryEmbedding()
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear4bit(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear4bit(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
        (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): Phi3RMSNorm((3072,), eps=1e-05)
  )
  (lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)

A quantized model can be used directly for inference, but it cannot be trained any further. Those pesky Linear4bit layers take up much less space, which is the whole point of quantization; however, we cannot update them.

Setting Up Low-Rank Adapters (LoRA)

Low-rank adapters can be attached to each of the quantized layers. These adapters are typically standard Linear layers that can be updated during training. The trick here is that the adapters are significantly smaller than the layers they augment.

Since the quantized layers are frozen (i.e., not updated during training), adding LoRA adapters drastically reduces the number of trainable parameters—often to just 1% (or less) of the model’s original size.

We can set up LoRA adapters in three easy steps:

  • Call prepare_model_for_kbit_training() to improve numerical stability during training.

  • Create an instance of LoraConfig.

  • Apply the configuration to the quantized base model using the get_peft_model() method.

"""
# PEFT (Parameter-Efficient Fine-Tuning)
function: prepare_model_for_kbit_training

this method only works for `transformers` models.

    This method wraps the entire protocol for preparing a model before running a training. This includes:
        1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
        head to fp32 4- Freezing the base model layers to ensure they are not updated during training
"""

model_with_lora = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r=8,                   # rank of the adapters: the lower the rank, the fewer parameters to train
    lora_alpha=16,         # scaling factor: adapter outputs are multiplied by lora_alpha / r (2*r is a common choice)
    bias="none",           # BEWARE: training biases *modifies* base model's behavior
    lora_dropout=0.05,     # dropout applied to the adapters' inputs during training
    task_type="CAUSAL_LM",
    # Newer models, such as Phi-3 at time of writing, may require
    # manually setting target modules
    target_modules=['o_proj', 'qkv_proj', 'gate_up_proj', 'down_proj'],
)

model_with_lora = get_peft_model(model_with_lora, config)
model_with_lora
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (embed_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3SdpaAttention(
              (o_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (qkv_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=9216, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=9216, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (rotary_emb): Phi3RotaryEmbedding()
            )
            (mlp): Phi3MLP(
              (gate_up_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=16384, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=16384, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (down_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=8192, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=8192, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (activation_fn): SiLU()
            )
            (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
            (resid_attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
            (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
          )
        )
        (norm): Phi3RMSNorm((3072,), eps=1e-05)
      )
      (lm_head): Linear(in_features=3072, out_features=32064, bias=False)
    )
  )
)

All four target modules (o_proj, qkv_proj, gate_up_proj, and down_proj) now have LoRA adapters attached.

The quantized layers (Linear4bit) have turned into lora.Linear4bit modules where the quantized layer itself became the base_layer with some regular Linear layers (lora_A and lora_B) added to the mix.

These extra layers would make the model only slightly larger. However, the model preparation function (prepare_model_for_kbit_training()) turned every non-quantized layer to full precision (FP32), resulting in a model that is about 20% larger (2651 MB vs. 2206 MB):

print(f'Size of model with LoRA adapters: {model.get_memory_footprint()/1e6:.2f}M')
Size of model with LoRA adapters: 2651.08M
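Both footprints can be reproduced, to within a few kilobytes, from the layer shapes in the printed architecture. This sketch assumes that 4-bit weights take half a byte each, that the non-quantized parameters (embeddings, lm_head, RMSNorm weights) were loaded in FP16, which to our understanding is what transformers falls back to for bitsandbytes models when no torch_dtype is passed, and that the adapters are kept in FP32. Buffers and quantization constants are ignored.

```python
# Shapes read off the printed Phi3ForCausalLM architecture
HIDDEN, VOCAB, N_LAYERS = 3072, 32064, 32
LINEAR4BIT = {  # name: (in_features, out_features), one of each per decoder layer
    "o_proj": (3072, 3072),
    "qkv_proj": (3072, 9216),
    "gate_up_proj": (3072, 16384),
    "down_proj": (8192, 3072),
}
R = 8  # LoRA rank from LoraConfig

quantized = N_LAYERS * sum(i * o for i, o in LINEAR4BIT.values())
# embed_tokens + lm_head + two RMSNorms per decoder layer + the final norm
unquantized = 2 * VOCAB * HIDDEN + (2 * N_LAYERS + 1) * HIDDEN
# each adapter: lora_A (R x in) + lora_B (out x R)
lora = N_LAYERS * sum(R * (i + o) for i, o in LINEAR4BIT.values())


def footprint_mb(bytes_unquantized, bytes_lora):
    # 4-bit weights pack two values per byte
    return (quantized * 0.5 + unquantized * bytes_unquantized + lora * bytes_lora) / 1e6


before = footprint_mb(bytes_unquantized=2, bytes_lora=0)  # assumed FP16, no adapters yet
after = footprint_mb(bytes_unquantized=4, bytes_lora=4)   # FP32 after preparation + FP32 adapters
print(f"before: {before:.2f} MB, after: {after:.2f} MB, growth: {after / before - 1:.0%}")
# before: 2206.34 MB, after: 2651.07 MB, growth: 20%
```

Almost all of the growth comes from the embedding matrix and lm_head (about 197M parameters each way combined) doubling from 2 to 4 bytes per value; the adapters themselves add only about 50 MB.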

Since most parameters are frozen, only a tiny fraction of the total number of parameters is currently trainable:

trainable_parms, tot_parms = model_with_lora.get_nb_trainable_parameters()
print(f'Trainable parameters:             {trainable_parms/1e6:.2f}M')
print(f'Total parameters:                 {tot_parms/1e6:.2f}M')
print(f'% of trainable parameters:        {100*trainable_parms/tot_parms:.2f}%')
Trainable parameters:             12.58M
Total parameters:                 3833.66M
% of trainable parameters:        0.33%
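The 12.58M figure follows from the adapter shapes in the printout: each adapted layer gets lora_A (r × in_features) and lora_B (out_features × r), with no bias terms since bias="none". A small sketch of that count; the helper name lora_params is just for illustration, and the total is taken (rounded) from the output above.

```python
R = 8
N_LAYERS = 32
TARGET_MODULES = {  # (in_features, out_features) of each adapted layer, per decoder layer
    "o_proj": (3072, 3072),
    "qkv_proj": (3072, 9216),
    "gate_up_proj": (3072, 16384),
    "down_proj": (8192, 3072),
}
TOTAL_PARAMS = 3833.66e6  # from get_nb_trainable_parameters() above (rounded)


def lora_params(in_features, out_features, r):
    # lora_A: r x in_features, lora_B: out_features x r, no biases
    return r * in_features + out_features * r


per_layer = {name: lora_params(i, o, R) for name, (i, o) in TARGET_MODULES.items()}
trainable = N_LAYERS * sum(per_layer.values())
pct = 100 * trainable / TOTAL_PARAMS

print(per_layer)  # {'o_proj': 49152, 'qkv_proj': 98304, 'gate_up_proj': 155648, 'down_proj': 90112}
print(f"{trainable:,} trainable parameters ({pct:.2f}%)")  # 12,582,912 trainable parameters (0.33%)
```

Because the count is linear in r, doubling the rank to 16 would double the trainable parameters to about 25.17M.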

English to Singlish Dataset

This dataset was synthetically generated using ChatGPT 4.1. It is a simple list of 500 English sentences and their Singlish translations, stored as a CSV file. Let’s load it and shape it into prompt/completion pairs for training…

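import pandas as pd

# The CSV has 'index', 'english' and 'singlish' columns
df = pd.read_csv('/singlish_to_english_v0.1.csv')

dataset = Dataset.from_pandas(df)

# Rename to the prompt/completion layout that SFTTrainer (trl 0.12) recognizes;
# it then applies the tokenizer's chat template to each pair
dataset = dataset.rename_column("english", "prompt")
dataset = dataset.rename_column("singlish", "completion")
dataset = dataset.remove_columns(["index"])  # keep only prompt and completion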

print("Sample first data from dataset")
print(dataset[0])

print("Prompt structure")
messages = [
    {"role": "user", "content": dataset[0]['prompt']},
    {"role": "assistant", "content": dataset[0]['completion']}
]
messages
Sample first data from dataset
{'completion': 'Eh, you know how to chop the garlic not?', 'prompt': 'Hey, do you know how to chop the garlic?'}
Prompt structure
[{'role': 'user', 'content': 'Hey, do you know how to chop the garlic?'},
 {'role': 'assistant', 'content': 'Eh, you know how to chop the garlic not?'}]

Tokenizer

We need to load the tokenizer that corresponds to our model. The tokenizer is an important part of this process, determining how to convert text into tokens in the same way used to train the model.

For instruction/chat models, the tokenizer also contains its corresponding chat template that specifies:

  • Which special tokens should be used, and where they should be placed.

  • Where the system directives, user prompt, and model response should be placed.

  • What is the generation prompt, that is, the special token that triggers the model’s response (more on that in the “Querying the Model” section)

tokenizer = AutoTokenizer.from_pretrained(repo_id)
tokenizer.chat_template
print(tokenizer.apply_chat_template(messages, tokenize=False))
<|user|>
Hey, do you know how to chop the garlic?<|end|>
<|assistant|>
Eh, you know how to chop the garlic not?<|end|>
<|endoftext|>
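To make the layout concrete, here is a hypothetical helper, format_phi3_chat (not part of transformers), that reproduces the string printed above for user and assistant turns, with and without the generation prompt. It only mimics the output shown here; in practice, always rely on tokenizer.apply_chat_template(), which runs the model’s real Jinja template.

```python
def format_phi3_chat(messages, eos_token="<|endoftext|>", add_generation_prompt=False):
    """Hypothetical re-creation of the Phi-3 layout printed above (illustration only)."""
    role_tags = {"user": "<|user|>", "assistant": "<|assistant|>"}
    text = ""
    for message in messages:
        # Each turn: role tag, newline, content, end-of-turn token, newline
        text += f"{role_tags[message['role']]}\n{message['content']}<|end|>\n"
    if add_generation_prompt:
        text += "<|assistant|>\n"  # cue the model that it is its turn to write
    else:
        text += eos_token          # a finished conversation ends with EOS
    return text


messages = [
    {"role": "user", "content": "Hey, do you know how to chop the garlic?"},
    {"role": "assistant", "content": "Eh, you know how to chop the garlic not?"},
]
print(format_phi3_chat(messages))  # same text as the apply_chat_template() output above
```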

Fine-Tuning with SFTTrainer

Fine-tuning a model, whether large or otherwise, follows exactly the same training procedure as training a model from scratch. We could write our own training loop in pure PyTorch, or we could use Hugging Face’s Trainer to fine-tune our model.

It is much easier, however, to use SFTTrainer instead (which uses Trainer underneath, by the way), since it takes care of most of the nitty-gritty details for us, as long as we provide it with the following four arguments:

  • a model

  • a tokenizer

  • a dataset

  • a configuration object

SFTConfig

There are many parameters that we can set in the configuration object. We have divided them into four groups:

  • Memory usage optimization parameters related to gradient accumulation and checkpointing

  • Dataset-related arguments, such as the max_seq_length required by your data, and whether or not you are packing the sequences

  • Typical training parameters such as the learning_rate and the num_train_epochs

  • Environment and logging parameters such as output_dir (this will be the name of the model if you choose to push it to the Hugging Face Hub once it’s trained), logging_dir, and logging_steps.
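For the first group, the number that matters for each optimizer update is the effective batch size. A tiny helper (hypothetical, not part of TRL) makes the relation between micro-batch size and gradient accumulation explicit:

```python
def effective_batch_size(per_device_train_batch_size, gradient_accumulation_steps, num_devices=1):
    # Number of sequences whose gradients are combined before each optimizer step
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


print(effective_batch_size(16, 1))  # 16: the configuration used below
print(effective_batch_size(8, 2))   # 16 as well, with half the micro-batch in memory at once
```

Trading a smaller per_device_train_batch_size for more gradient_accumulation_steps keeps the effective batch size while holding roughly half the activations in memory per forward/backward pass, at the cost of more passes per update.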

While the learning rate is a very important parameter (as a starting point, you can try the learning rate used to train the base model in the first place), it’s actually the maximum sequence length that’s more likely to cause out-of-memory issues.

Make sure to always pick the shortest max_seq_length that makes sense for your use case. In ours, the sentences, both in English and in Singlish, are quite short, and a sequence of 64 tokens is more than enough to cover the prompt, the completion, and the added special tokens.
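A toy sketch of what packing=True does conceptually: tokenized examples are concatenated, separated by an end-of-sequence token, and cut into blocks of exactly max_seq_length tokens, so no padding is needed. The function pack_sequences is hypothetical and simplified; TRL’s own implementation differs in the details. This is also why the packed batch printed further down appears to start in the middle of a sentence.

```python
def pack_sequences(tokenized_examples, seq_length, eos_token_id):
    """Toy packing: concatenate examples with an EOS separator, cut into fixed-size blocks."""
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids)
        stream.append(eos_token_id)  # marks the boundary between examples
    # Keep only full blocks so every sequence has exactly seq_length tokens (no padding)
    n_blocks = len(stream) // seq_length
    return [stream[i * seq_length:(i + 1) * seq_length] for i in range(n_blocks)]


examples = [[10, 11, 12], [20, 21], [30, 31, 32, 33]]
print(pack_sequences(examples, seq_length=4, eos_token_id=0))
# [[10, 11, 12, 0], [20, 21, 0, 30], [31, 32, 33, 0]]
```

Note the trade-off: an example can be split across two blocks, and in a simple scheme like this one, tokens can attend to the preceding example within the same block.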

Flash attention, a memory-efficient implementation of the attention computation, allows for more flexibility when working with longer sequences, reducing the risk of OOM errors.

sft_config = SFTConfig(
    ## GROUP 1: Memory usage
    # These arguments will squeeze the most out of your GPU's RAM
    # Checkpointing
    gradient_checkpointing=True,
    # this saves a LOT of memory
    # Set this to avoid exceptions in newer versions of PyTorch
    gradient_checkpointing_kwargs={'use_reentrant': False},
    # Gradient Accumulation / Batch size
    # Actual batch (for updating) is same (1x) as micro-batch size
    gradient_accumulation_steps=1,
    # The initial (micro) batch size to start off with
    per_device_train_batch_size=16,
    # If batch size would cause OOM, halves its size until it works
    auto_find_batch_size=True,

    ## GROUP 2: Dataset-related
    max_seq_length=64,
    # Dataset
    # packing a dataset means no padding is needed
    packing=True,

    ## GROUP 3: These are typical training parameters
    num_train_epochs=10,
    learning_rate=3e-4,
    # Optimizer
    # 8-bit Adam optimizer - doesn't help much if you're using LoRA!
    optim='paged_adamw_8bit',

    ## GROUP 4: Logging parameters
    logging_steps=10,
    logging_dir='./logs',
    output_dir='./phi3-mini-singlish-adapter',
    report_to='none'
)
trainer = SFTTrainer(
    model=model_with_lora,
    processing_class=tokenizer,
    args=sft_config,
    train_dataset=dataset,
)
dl = trainer.get_train_dataloader()
batch = next(iter(dl))
batch['input_ids'][0], batch['labels'][0]
(tensor([ 1809,   297,   278,  1510,   261,  2745,   306,  2609,   260,   801,
           273, 29889, 32007, 32000, 32000, 32010, 18637, 29892,   437,   366,
          1073,   920,   304,  1708,   278, 11210, 29973, 32007, 32001,   382,
         29882, 29892,   366,  1073,   920,   304,  1708, 11210,   470,   451,
         29973, 32007, 32000, 32000, 32010, 18637, 29892,   366,  2253,   748,
           304,  6592,  1286, 29892,   366,  1603,   505,   304,   281,  1296,
           701,  4688,  6454, 22396], device='cuda:0'),
 tensor([ 1809,   297,   278,  1510,   261,  2745,   306,  2609,   260,   801,
           273, 29889, 32007, 32000, 32000, 32010, 18637, 29892,   437,   366,
          1073,   920,   304,  1708,   278, 11210, 29973, 32007, 32001,   382,
         29882, 29892,   366,  1073,   920,   304,  1708, 11210,   470,   451,
         29973, 32007, 32000, 32000, 32010, 18637, 29892,   366,  2253,   748,
           304,  6592,  1286, 29892,   366,  1603,   505,   304,   281,  1296,
           701,  4688,  6454, 22396], device='cuda:0'))

The labels were added automatically, and they’re exactly the same as the inputs. Thus, this is a case of self-supervised fine-tuning.

The shifting of the labels will be handled automatically as well; there’s no need to be concerned about it.
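What “shifting” means, in a pure-Python sketch: the token at position t is used to predict the label at position t + 1, and positions labelled -100 (the Hugging Face convention for “ignore in the loss”) are skipped. The function name causal_lm_targets is illustrative; in transformers, the shift is applied to the logits and label tensors inside the model’s loss computation.

```python
IGNORE_INDEX = -100  # labels with this value do not contribute to the loss


def causal_lm_targets(input_ids, labels):
    """Pair each input token with the label it must predict (shifted by one position)."""
    pairs = zip(input_ids[:-1], labels[1:])
    return [(token, target) for token, target in pairs if target != IGNORE_INDEX]


ids = [32010, 18637, 29892, 437]  # a few ids from the batch above
print(causal_lm_targets(ids, ids))
# [(32010, 18637), (18637, 29892), (29892, 437)]
```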

Although this is a 3.8 billion-parameter model, a configuration like the one above (with auto_find_batch_size shrinking the mini-batch to eight if needed) allows us to squeeze training into an old setup with a consumer-grade GPU such as a GTX 1060 with only 6 GB RAM. True story!

Next, we call the train() method and wait:

import time

start_time = time.time()
trainer.train()

end_time = time.time()

elapsed_time = end_time - start_time
minutes, seconds = divmod(elapsed_time, 60)
print(f"Total training time: {int(minutes)}m {int(seconds)}s")
[190/190 20:55, Epoch 10/10]
Step Training Loss
10 3.010700
20 1.871600
30 1.568500
40 1.390200
50 1.188700
60 1.077000
70 0.851500
80 0.749200
90 0.554100
100 0.469200
110 0.374000
120 0.328100
130 0.293400
140 0.256900
150 0.255200
160 0.230800
170 0.230300
180 0.206000
190 0.204000

Total training time: 22m 0s
trainer.save_model('/model/lora-phi3-mini-singlish-adapter')

Querying the Model

The model requires its inputs to be properly formatted: we need to build a list of “messages” from the user and prompt the model to answer by indicating that it’s its turn to write. This is the purpose of the add_generation_prompt argument: it appends <|assistant|> to the end of the conversation, so the model predicts the next token and keeps doing so until it predicts an end-of-sequence token (<|endoftext|>) or reaches the max_new_tokens limit.
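The stopping rule can be sketched as a greedy decoding loop. Here next_token is a hypothetical stand-in for the model that returns the id of the most likely next token; model.generate() does considerably more, but it likewise stops at the end-of-sequence token or when max_new_tokens runs out.

```python
def greedy_generate(next_token, prompt_ids, eos_token_id, max_new_tokens=64):
    """Toy greedy decoding: append one token at a time until EOS or the budget runs out."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        token = next_token(ids)      # hypothetical model call
        ids.append(token)
        if token == eos_token_id:    # the model signalled it is done
            break
    return ids[len(prompt_ids):]     # only the newly generated part


script = iter([101, 102, 103, 0, 104])
print(greedy_generate(lambda ids: next(script), [1, 2], eos_token_id=0))  # [101, 102, 103, 0]
print(greedy_generate(lambda ids: 7, [1], eos_token_id=0, max_new_tokens=5))  # [7, 7, 7, 7, 7]
```

The second case is why several of the original model’s answers below stop mid-sentence: they do not produce an end-of-sequence token within the 64-token budget of the generate() helper.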

The helper function below assembles a message (in the conversational format) and applies the chat template to it, appending the generation prompt to its end.

def gen_prompt(tokenizer, sentence):
    converted_sample = [
        {"role": "user", "content": sentence},
    ]
    prompt = tokenizer.apply_chat_template(converted_sample,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return prompt

Helper Function

  • It tokenizes the prompt into a tensor of token IDs (add_special_tokens is set to False because the tokens were already added by the chat template).

  • It sets the model to evaluation mode.

  • It calls the model’s generate() method to produce the output (generated token IDs).

  • It decodes the generated token IDs back into readable text, and extract_response() keeps only the text after <|assistant|>, without the special tokens.

def extract_response(text):
    # The generation prompt marks where the model's answer starts
    marker = '<|assistant|>'
    start_pos = text.find(marker)

    if start_pos == -1:
        return "I don't understand you..."

    # Keep only the answer and drop Phi-3's end-of-turn and end-of-sequence tokens
    extracted_text = text[start_pos + len(marker):]
    extracted_text = extracted_text.replace("<|end|>", "")
    extracted_text = extracted_text.replace("<|endoftext|>", "")
    return extracted_text.strip()  # remove leading/trailing whitespace and newlines


def generate(model, tokenizer, prompt, max_new_tokens=64, skip_special_tokens=False):
    tokenized_input = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)

    model.eval()
    generation_output = model.generate(**tokenized_input,
                                       eos_token_id=tokenizer.eos_token_id,
                                       max_new_tokens=max_new_tokens)

    output = tokenizer.batch_decode(generation_output,
                                    skip_special_tokens=skip_special_tokens)
    return extract_response(output[0])

Testing the Model

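We’ll load the original model again and compare its responses with those of the fine-tuned model. The original model is the quantized Phi-3-Mini-4K-Instruct, and the fine-tuned model is the one we just trained with LoRA adapters. A fresh copy is needed because get_peft_model() injects the adapters into the base model in place, so model itself now contains the LoRA layers as well.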

# Loading the original Model

bnb_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.float32
)
repo_id = 'microsoft/Phi-3-mini-4k-instruct'
original_model = AutoModelForCausalLM.from_pretrained(repo_id,
                                             device_map="cuda:0",
                                             quantization_config=bnb_config
)
test_sentences = [
    "Wow, this stir-fried rice noodle dish is really amazing!",
    "I sneezed once and my mum already thinks I'm dying.",
    "No matter how many times I explain, he still doesn't get it.",
    # This comma was missing in the run shown below, so Python concatenated this string
    # with the next one and Test 4 received both sentences as a single prompt
    "I'm telling you, the pillow was so comfortable, I fell asleep immediately.",
    "She dressed up like she was going to a gala, just to buy groceries.",
    "The train broke down again and I was already late.",
    "He keeps saying he's an expert but does nothing at all.",
    "He got angry over nothing, like someone stole his lunch."
]

for i, sentence in enumerate(test_sentences, 1):
    prompt = gen_prompt(tokenizer, sentence)
    original_output = generate(original_model, tokenizer, prompt)
    lora_output = generate(model_with_lora, tokenizer, prompt)

    print(f"\n ===== Test {i} =====")
    print(f"Prompt: \n{prompt}")
    print("----------")
    print(f"⚪️ Original Model response: \n💬 {original_output}")
    print("----------")
    print(f"⭕️ LoRA Fine-Tuned response: \n💬 {lora_output}")
    print("----------")
 ===== Test 1 =====
Prompt: 
<|user|>
Wow, this stir-fried rice noodle dish is really amazing!<|end|>
<|assistant|>

----------
⚪️ Original Model response: 
💬 I'm glad you're enjoying it! Stir-fried rice noodles can be a delightful meal with the right combination of flavors and ingredients. What do you think makes this dish stand out for you?
----------
⭕️ LoRA Fine-Tuned response: 
💬 Wah, this char kway teow really power lah!
----------

 ===== Test 2 =====
Prompt: 
<|user|>
I sneezed once and my mum already thinks I'm dying.<|end|>
<|assistant|>

----------
⚪️ Original Model response: 
💬 I'm sorry to hear that you're going through this. Sneezing is a common reflex that helps clear irritants from your nose and throat. It's important to remember that sneezing doesn't necessarily indicate a serious health issue. However, if you're concerned
----------
⭕️ LoRA Fine-Tuned response: 
💬 I sneeze once, my mother think I die sia.
----------

 ===== Test 3 =====
Prompt: 
<|user|>
No matter how many times I explain, he still doesn't get it.<|end|>
<|assistant|>

----------
⚪️ Original Model response: 
💬 It sounds like you're facing a communication challenge. To help him understand, try these strategies:

1. Simplify your language: Use clear, simple words and avoid jargon or complex sentences.
2. Be patient: Give him time to process the information.
3. Use visual
----------
⭕️ LoRA Fine-Tuned response: 
💬 I try to explain many times, but he still don't understand one.
----------

 ===== Test 4 =====
Prompt: 
<|user|>
I'm telling you, the pillow was so comfortable, I fell asleep immediately.She dressed up like she was going to a gala, just to buy groceries.<|end|>
<|assistant|>

----------
⚪️ Original Model response: 
💬 It seems like you're expressing two separate scenarios. Let's address them one by one.

1. Regarding the pillow:
"I'm telling you, the pillow was so comfortable, I fell asleep immediately."

This sentence suggests that you were very tired and the pill
----------
⭕️ LoRA Fine-Tuned response: 
💬 I tell you ah, the pillow so comfy, I straight away fall asleep leh.Dress like go sewing event, actually just go buy NTUC.
----------

 ===== Test 5 =====
Prompt: 
<|user|>
The train broke down again and I was already late.<|end|>
<|assistant|>

----------
⚪️ Original Model response: 
💬 I'm sorry to hear that your train experience was not smooth today. Trains are a common mode of transportation, and while they are generally reliable, occasional breakdowns can happen. It'ieves a lot of people, especially those who rely on punctuality for their daily routines. If
----------
⭕️ LoRA Fine-Tuned response: 
💬 The train kena spoil again, I already late.
----------

 ===== Test 6 =====
Prompt: 
<|user|>
He keeps saying he's an expert but does nothing at all.<|end|>
<|assistant|>

----------
⚪️ Original Model response: 
💬 It seems like you're expressing frustration with someone who claims to be an expert but isn't taking any action. If you're dealing with a specific individual, it might be helpful to communicate your concerns directly to them. You could say something like, "I've noticed that despite your claims of
----------
⭕️ LoRA Fine-Tuned response: 
💬 The one say expert but do nothing sia.
----------

 ===== Test 7 =====
Prompt: 
<|user|>
He got angry over nothing, like someone stole his lunch.<|end|>
<|assistant|>

----------
⚪️ Original Model response: 
💬 It seems like you're describing a situation where someone is overreacting to a minor issue. This can happen when a person's emotional response is disproportionate to the event that triggered it. In this case, the person is upset about a personal belonging, specifically their lunch, being
----------
⭕️ LoRA Fine-Tuned response: 
💬 He kena angry over nothing, like someone steal his lunch.
----------