Fine-tuning TinyLlama with Axolotl and JarvisLab

Fine-tuning TinyLlama to generate David Attenborough style narration - using Axolotl, JarvisLab, and OpenAI’s Batch API
llm
code
Author

Gabriel Chua

Published

June 16, 2024

I recently signed up for the LLM Fine-Tuning course/conference. Prior to this, most of my fine-tuning experience was either through OpenAI's fine-tuning API for GPT-3.5 or through the starter scripts in the MLX Examples repo.

One tool I learnt about in the course was Axolotl. With Axolotl, I can easily fine-tune LLMs through the CLI based on configurations defined in a .yml file. The Axolotl GitHub repo also contains many helpful example .yml files that one can easily re-use.

To run Axolotl, I'm using Jarvis Labs - a GPU cloud provider - which was also one of the generous GPU credit sponsors for the course. What I appreciated was how easily I could launch a Jupyter Notebook instance from a template with the Axolotl repo already cloned and the other dependencies installed. For more detailed steps on getting started with Axolotl through Jarvis Labs, you can refer to this post by Andresckamilo, which I also referred to.

To get my hands dirty and apply the above, I put together this toy example where I fine-tuned TinyLlama to generate David Attenborough style narration. The final model can be found here on HuggingFace, which you can also try.

The aim of this blog post is to document my learning process. Overall, preparing the data and fine-tuning the model took about 1 hour end to end, of which the actual fine-tuning took about 15 minutes. As such, this post leans more towards the data generation and preparation steps. In terms of cost, the fine-tuning alone came to slightly less than 1 USD.

Step 1: Generating the synthetic data

We begin by generating the conversation pairs. For this example, I used OpenAI's models, specifically through the Batch API. This helps to reduce cost, as the Batch API is priced at a 50% discount compared to synchronous requests.

For the Batch API, we need to provide a .jsonl file where each line is a POST request to OpenAI's chat completions endpoint. Here is a sample of what that looks like:

{"custom_id": "0", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4-turbo-2024-04-09", "messages": [{"role": "system", "content": "Imagine you are David Attenborough. I will give you an activity, and you will give a ~300 word narration for a documentary."}, {"role": "user", "content": "A young girl planting flowers in her backyard garden."}], "temperature": 1, "max_tokens": 500}}

We begin by importing the relevant packages:

import json
import random
import asyncio
from openai import AsyncOpenAI
from tqdm import tqdm

For this example, we need a persona and activity for which we’ll generate the David Attenborough style narration.

In the code block below, we define a list of possible personas in PERSONAS, and use GPT-3.5 to generate a random activity for each persona. The temperature is set to 0.6 (an arbitrary choice) to allow for some creativity in the generated activities. Note that we are using the async client and functions, which will be relevant in the next code block.

Feel free to modify this seeding method for your own use case.

PERSONAS = ["young kid", "young girl", "young boy", "teenager", "middle-aged woman", "middle-aged man", "mother", "father", "grandmother", "grandfather"]

client = AsyncOpenAI()

async def fetch_activity(session, persona):
    response = await session.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{
            "role": "user",
            "content": f"Generate one random activity done by a {persona}. For example, `a young boy building a sand castle`, or `a lady in the grocery store`. Be creative."
        }],
        temperature=0.6,
        max_tokens=30
    )
    return response.choices[0].message.content

We then define the gather_activities function, which makes the GPT-3.5 API calls in chunks of 20 concurrent requests to generate the random activities.

async def gather_activities(number_rows):
    async with client as session:
        tasks = [fetch_activity(session, random.choice(PERSONAS)) for _ in range(number_rows)]
        activities = []
        for i in range(0, number_rows, 20):  # Process in chunks of 20
            chunk = tasks[i:i+20]
            activities.extend(await asyncio.gather(*chunk))
        return activities

Lastly, we have the main function that creates the .jsonl file. It takes two inputs: the name of the .jsonl file you want to export and the number of rows, which is the number of requests sent to OpenAI's Batch API.

Note how:

  • asyncio.run is used to execute the async functions.
  • tqdm is used to track the JSONL generation.
  • For each line in the .jsonl file, we have the main API call to GPT-4 Turbo, with the system prompt defined in the code below.
  • For increased variation, we set the temperature to 1, which was again an arbitrary figure.

def generate_jsonl(filename, number_rows):
    """
    Generate a JSONL file of Batch API requests with the specified filename and number of rows
    """
    activities = asyncio.run(gather_activities(number_rows))

    # Write jsonl file
    with open(filename, 'w') as file:
        for index in tqdm(range(0, number_rows), desc="Generating JSONL File"):
            activity = activities[index]
            request = {
                "custom_id": str(index),
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": "gpt-4-turbo-2024-04-09",
                    "messages": [{
                        "role": "system",
                        "content": "Imagine you are David Attenborough. I will give you an activity, and you will give a ~300 word narration for a documentary."
                    },
                    {
                        "role": "user",
                        "content": f"{activity}"
                    }],
                    "temperature": 1,
                    "max_tokens": 500,
                }
            }
            file.write(json.dumps(request) + '\n')

We then execute the .jsonl creation. It's worth stressing that the resulting batch_api_input.jsonl file is NOT the fine-tuning dataset. Instead, it's a series of API calls that OpenAI will execute in batch, and the corresponding output can then be formatted into our fine-tuning dataset.

# Specify the filename and number of rows
generate_jsonl('batch_api_input.jsonl', 500)

In this example, I used OpenAI's developer portal GUI to upload the .jsonl file and start the batch job, although this can also be done programmatically. After about 3 minutes, the 500 API calls were successfully executed.
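
For reference, here is a minimal sketch of how the same upload and batch submission could be done programmatically with the OpenAI Python SDK. The file names match those used above; the status check at the end is just illustrative, since in practice you would wait or poll until the batch completes.

from openai import OpenAI

client = OpenAI()

# Upload the .jsonl file of requests
batch_input_file = client.files.create(
    file=open("batch_api_input.jsonl", "rb"),
    purpose="batch"
)

# Create the batch job against the chat completions endpoint
batch = client.batches.create(
    input_file_id=batch_input_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h"
)

# Later, once the batch has finished, download the results
batch = client.batches.retrieve(batch.id)
if batch.status == "completed":
    output = client.files.content(batch.output_file_id)
    output.write_to_file("batch_api_output.jsonl")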

Step 2: Preparing the fine-tuning dataset

The next step is to prepare the fine-tuning dataset for Axolotl.

The output from OpenAI's Batch API is also a .jsonl file, with one line per successfully completed API call. Here is what the first line of my output looked like. Note that the original prompt is not included in the output, so I needed to merge it back in.

{"id": "batch_req_HPuUiqtanStr3Bww9ydozYAr", "custom_id": "0", "response": {"status_code": 200, "request_id": "b516cd0cd1879aaf525077b91ae5b816", "body": {"id": "chatcmpl-9T0yV3wVCPqREUWzNoUA51psLqYY7", "object": "chat.completion", "created": 1716702287, "model": "gpt-4-turbo-2024-04-09", "choices": [{"index": 0, "message": {"role": "assistant", "content": "In the bustling heart of an urban sprawl, where the concrete jungle often dwarfs splashes of green, a small, vibrant oasis emerges in the backyard of a suburban home. Here, amidst the buzz of daily life, we discover a young girl engaged in a ritual as timeless as nature itself: the planting of flowers.\n\nHer small hands delve into the earth, each movement thoughtful, as if she understands the profound dialogue between human and soil. This garden, her canvas, waits patiently, offering up its nutrients and space, eager to cradle new life. With a gentle touch, she places each seedling into its own nook, crafting a mosaic of potential blooms that could transform this patch into a riot of color and fragrance.\n\nAs she works, her focus is palpable. She is an artist, and the earth beneath her is both medium and collaborator. Each flower, carefully selected for its hue, texture, and the role it will play in the garden\u2019s grand symphony, is nurtured into its place. Marigolds, with their cheerful orange faces, promise a fireworks of blossoms, a natural pest repellent safeguarding her nascent creation. Lavender, with its soothing aroma, whispers of calm afternoons and the hum of busy bees.\n\nThis activity is more than mere horticulture; it's a lesson in patience, responsibility, and the interconnectedness of life. Our young gardener learns to foster growth, to respect cycles of life, and perhaps most poignantly, to deal with the inevitable losses that gardening, like life itself, occasionally demands.\n\nThe sun begins its descent, casting long shadows over freshly turned soil, transforming the ordinary into the magical. As the day ends, the garden is small still, its potential just beginning to bud. Yet, in the heart of this young botanist, dreams of spring bloom wildly\u2014a vivid, hopeful imagining of what is yet to unfold. In her hands lies not just the beauty of flowers, but the stewardship of nature itself, one backyard at a time."}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 49, "completion_tokens": 413, "total_tokens": 462}, "system_fingerprint": "fp_f17df3916f"}}, "error": null}

We then load and extract the assistant’s content from the output .jsonl file:

def load_and_extract_chat_completion(jsonl_file_path):
    contents = []
    
    with open(jsonl_file_path, 'r') as file:
        for line in file:
            data = json.loads(line)
            assistant_content = data.get("response", {}).get("body", {}).get("choices", [])[0].get("message", {}).get("content", "")
            if assistant_content:
                contents.append(assistant_content)
    
    return contents

answers = load_and_extract_chat_completion('batch_api_output.jsonl')

We also load and extract the original prompts.

def load_and_extract_prompt(jsonl_file_path):
    user_contents = []
    
    with open(jsonl_file_path, 'r') as file:
        for line in file:
            data = json.loads(line)
            messages = data.get("body", {}).get("messages", [])
            if len(messages) > 1:
                user_content = messages[1].get("content", "")
                if user_content:
                    user_contents.append(user_content)
    
    return user_contents

prompt = load_and_extract_prompt("batch_api_input.jsonl")

We've now stored the original prompts and the responses in two separate Python lists. We will now combine them into the ShareGPT format.

# Create fine-tuning dataset
with open('david_attenborough_conversations.jsonl', 'w') as outfile:
    for i in range(len(prompt)):
        example = {
            "conversations": [
                {"role": "human", "value": f"Write a David Attenborough style commentary based on the following prompt: {prompt[i]}"},
                {"role": "assistant", "value": f"{answers[i]}"}
            ]
        }
        json.dump(example, outfile)
        outfile.write('\n')
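
One caveat with the position-based pairing above: as far as I can tell, the Batch API output lines are not guaranteed to come back in the same order as the input, so a safer variant is to key both files on custom_id before pairing. Here is a minimal sketch that reuses the same file names:

def pair_prompts_and_answers(input_path, output_path):
    # Index prompts and responses by custom_id rather than by line position
    prompts_by_id, answers_by_id = {}, {}

    with open(input_path, 'r') as file:
        for line in file:
            data = json.loads(line)
            prompts_by_id[data["custom_id"]] = data["body"]["messages"][1]["content"]

    with open(output_path, 'r') as file:
        for line in file:
            data = json.loads(line)
            choices = data.get("response", {}).get("body", {}).get("choices", [])
            if choices:
                answers_by_id[data["custom_id"]] = choices[0]["message"]["content"]

    # Only keep pairs where both the prompt and a completed response exist
    return [(prompts_by_id[cid], answers_by_id[cid]) for cid in prompts_by_id if cid in answers_by_id]

pairs = pair_prompts_and_answers("batch_api_input.jsonl", "batch_api_output.jsonl")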

Step 3: Uploading the dataset to HuggingFace

The next step is to upload the dataset to HuggingFace, which we can then pull via Axolotl when running the fine-tuning job.
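
One way to do this is with the datasets library, as in the short sketch below. This assumes you are already authenticated (for example via huggingface-cli login); the repo name is the one referenced in the Axolotl config later.

from datasets import load_dataset

# Load the local .jsonl file and push it to the Hub as a dataset repo
dataset = load_dataset("json", data_files="david_attenborough_conversations.jsonl")
dataset.push_to_hub("cyzgab/david_attenborough_style_narration")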

You can find the final dataset here on HuggingFace.

Step 4: Running Axolotl through JarvisLab

Now that we've prepared our data, we can begin the fine-tuning proper.

Ironically, this step took the least amount of time, partly because data preparation (rightfully) takes more time, but also because of the convenient abstractions provided by Axolotl.

I modified the dataset keys in the lora.yml for TinyLlama, located at examples/tiny-llama/lora.yml:

datasets:
  - path: cyzgab/david_attenborough_style_narration
    type: sharegpt
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
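
Before launching the actual training run, one optional sanity check is Axolotl's preprocess command, which loads and tokenizes the dataset using the same config. It is a quick way to catch dataset formatting issues before spending GPU time.

# optional: verify the dataset loads and tokenizes correctly
python3 -m axolotl.cli.preprocess examples/tiny-llama/lora.yml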

The next step is to start the training job, which ran on about 475 training examples (the 500 generated examples less the 5% validation split) for 4 epochs. This took about 15 minutes to complete.

# start training job
accelerate launch -m axolotl.cli.train examples/tiny-llama/lora.yml

Once the training job was done, I ran the following command to launch a Gradio interface to interact with the model:

# gradio 
accelerate launch -m axolotl.cli.inference examples/tiny-llama/lora.yml --lora_model_dir="./outputs/lora-out" --gradio

Lastly, I ran the following command to merge the LoRA adapter with the original base model weights, which I then uploaded to HuggingFace:

# fuse model
python3 -m axolotl.cli.merge_lora examples/tiny-llama/lora.yml --lora_model_dir="./outputs/lora-out"
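
To push the merged model to the Hub, one option is the huggingface_hub client. Below is a minimal sketch, assuming the merged weights end up in a merged folder under the output directory; the repo name is a placeholder.

from huggingface_hub import HfApi

api = HfApi()

# Create the model repo (no-op if it already exists) and upload the merged weights
api.create_repo("your-username/tinyllama-attenborough", repo_type="model", exist_ok=True)
api.upload_folder(
    folder_path="./outputs/lora-out/merged",  # assumed merge output path
    repo_id="your-username/tinyllama-attenborough",  # placeholder repo name
    repo_type="model",
)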

Conclusion

Overall, this was a fun, gentle introduction to fine-tuning an open-source model. Using synthetic data allowed me to generate unique content tailored to the use case, and the process was further streamlined by the powerful abstractions provided by tools like Axolotl and the Batch API, which made it quick and efficient.