generate-synthetic-data-distillation (Python)


Synthetic data generation with Llama 3.1 405B Instruct

This notebook demonstrates how to use Llama 3.1 405B to generate responses to a set of prompts. The resulting set of prompts with generated responses can be used to fine-tune a smaller model.

Introduction to synthetic data

Fine-tuning with Mosaic AI Model Training is a powerful technique for incorporating your enterprise data into a model, specializing it for specific tasks and improving application quality. However, you often may not have enough data to fine-tune with. In those cases, synthetic data generation can create additional high-quality data for fine-tuning. Synthetic data is data generated by a model, as opposed to data written or curated by humans.

In synthetic data generation, you use a high-quality LLM, the 'teacher' model, to generate synthetic data that is then used to fine-tune a smaller 'student' model. In this notebook, you use Meta's Llama-3.1-405B model, a leading high-quality LLM, as the teacher model and Llama-3-8B as the student model.

Generation pipeline

The synthetic data generation pipeline works as follows:

1. Seed prompts

First, provide a list of 'seed' prompts to ground the data generation. For example, suppose you are building an LLM for customer support in the telecommunications domain. You could provide the following examples of customer support questions as input:

{"messages": [{"role": "user", "content": "What phone plans do you offer?"}]}
{"messages": [{"role": "user", "content": "Can I share minutes with my family?"}]}
{"messages": [{"role": "user", "content": "How do I purchase unlimited data?"}]}
{"messages": [{"role": "user", "content": "My phone plan is too expensive."}]}
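
The pipeline reads these seed prompts from a JSONL file, one JSON object per line (see PROMPT_PATH below). A minimal sketch of producing such a file; the path is a placeholder to replace with a location your workspace can read:

import json

# The four seed prompts above, written one JSON object per line.
seed_prompts = [
    {"messages": [{"role": "user", "content": "What phone plans do you offer?"}]},
    {"messages": [{"role": "user", "content": "Can I share minutes with my family?"}]},
    {"messages": [{"role": "user", "content": "How do I purchase unlimited data?"}]},
    {"messages": [{"role": "user", "content": "My phone plan is too expensive."}]},
]

with open("/tmp/seed_prompts.jsonl", "w") as f:  # placeholder path
    for prompt in seed_prompts:
        f.write(json.dumps(prompt) + "\n")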

2. Prompt expansion

Next, use the teacher model to perform prompt expansion, meaning you ask the teacher model to generate more prompts in the same domain as the seed prompts above, creating a broader array of prompts that mimic real-world scenarios.

For instance, given the above seed prompts in the telecommunications customer support space, the teacher might generate a new prompt like:

{
    "messages": [
        {
            "role": "user",
            "content": "Are there any discounts for adding multiple lines to my account?"
        }
    ]
}
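
Concretely, prompt expansion amounts to asking the teacher for new questions in the same domain as a handful of examples. Below is a sketch of the kind of instruction you might send; the exact template this notebook uses may differ, and the <prompt> XML tags are an assumption explained in the "Prompting and XML" section later on:

# Illustrative expansion instruction for the teacher model. The XML tags
# make the output easy to parse (see "Prompting and XML" below).
EXPANSION_PROMPT_TEMPLATE = """You write training data for a customer support assistant.
Here are example user questions:

{seed_examples}

Write {n} new, distinct user questions in the same domain.
Wrap each question in <prompt></prompt> tags and output nothing else."""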

3. Generate responses

After prompt expansion, you call the teacher model again to generate a response to each prompt, forming a complete dialog.

With a generated response appended, the example prompt from the previous step looks like the following:

{
    "messages": [
        {
            "role": "user",
            "content": "Are there any discounts for adding multiple lines to my account?"
        },
        {
            "role": "assistant",
            "content": "Yes, we definitely offer discounts for adding multiple lines. Our Family and Group Plans provide a lower cost per line when you add more lines to your account. This is an excellent way to save money while ensuring everyone in your family or group stays connected. For specific details and the most current promotions, I recommend visiting our website or contacting customer service directly."
        }
    ]
}

4. Storage and fine-tuning

Finally, save the generated data in Unity Catalog and then call the Mosaic AI Model Training service to fine-tune the Llama 3 8B Instruct model.

Getting started

To use the notebook, define the inputs to the pipeline:

  • Teacher model name
  • Location to save synthetically generated prompts as a UC table
  • Location to save the final labeled data as a UC table

Optional parameters:

  • Inference parameters for the teacher model: temperature, top_p
  • Number of responses per prompt (default: 1)
  • Minimum number of prompts to synthetically generate (default: None)

1. Set seed prompts and other inputs

Data locations

  • PROMPT_PATH (str): Location of a JSONL file containing your seed prompts
  • OUTPUT_UC_TABLE_PATH (str): Name of the UC table to store the final synthetically generated dataset, which includes both user prompt and assistant responses
  • SYNTHETIC_PROMPT_FILENAME (str): Name of the UC table to store the synthetically generated prompts without the final labels

Sampling parameters

  • TEACHER_MODEL (str): The name of the PAYGO or Provisioned Throughput model endpoint you want to distill data from
  • TEMPERATURE (float): Inference parameter for the teacher model; lower makes it more deterministic, higher makes the responses more random
  • TOP_P (float): Inference parameter for the teacher model

Size parameters

  • MINIMUM_NUM_PROMPTS (Optional[int]): Total number of prompts to synthetically generate. This is essentially the stopping condition for the synthetic prompt generation pipeline. If set to its default, None, no prompt expansion is performed, meaning only the prompts from the input JSONL file are used
  • RESPONSES_PER_PROMPT (int): Number of responses per prompt the model generates. Default is 1 response per prompt
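
Putting these together, an example configuration cell. Every value below, including the catalog, schema, path, and endpoint names, is a placeholder to replace with your own:

from typing import Optional

# Data locations (hypothetical catalog/schema names).
PROMPT_PATH = "/Volumes/main/default/raw/seed_prompts.jsonl"
OUTPUT_UC_TABLE_PATH = "main.default.synthetic_training_data"
SYNTHETIC_PROMPT_FILENAME = "main.default.synthetic_prompts"

# Sampling parameters. The endpoint name is an assumed pay-per-token
# Foundation Model APIs name; check the names in your workspace.
TEACHER_MODEL = "databricks-meta-llama-3-1-405b-instruct"
TEMPERATURE = 0.8   # higher -> more varied generations
TOP_P = 0.95        # nucleus sampling cutoff

# Size parameters.
MINIMUM_NUM_PROMPTS: Optional[int] = 1000  # None -> skip prompt expansion
RESPONSES_PER_PROMPT = 1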

Helper functions

The following sections define helper functions that:

  • Load prompts from the provided file and perform validation
  • Read and write data to Unity Catalog
  • Call the teacher model over the API with appropriate prompting

Load and validate the input data
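
A minimal sketch of what this helper might look like, assuming the JSONL messages format shown earlier; the function name is illustrative:

import json

def load_prompts(path: str) -> list[dict]:
    """Read seed prompts from a JSONL file and check their shape."""
    prompts = []
    with open(path) as f:
        for i, line in enumerate(f):
            record = json.loads(line)
            messages = record.get("messages")
            # Each record must end with a user turn for the teacher to answer.
            if not messages or messages[-1].get("role") != "user":
                raise ValueError(f"Line {i}: expected a trailing 'user' message")
            prompts.append(record)
    return prompts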

Read and write data to Unity Catalog
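
A sketch of these helpers using Spark in a notebook context, where spark is predefined. The schema below, an array of role/content structs, is an assumption; check the fine-tuning documentation for the exact schema the training API expects:

def write_to_uc(rows: list[dict], table_name: str) -> None:
    # Store each conversation as an array of (role, content) structs.
    data = [([(m["role"], m["content"]) for m in r["messages"]],) for r in rows]
    df = spark.createDataFrame(data, "messages array<struct<role:string,content:string>>")
    df.write.mode("overwrite").saveAsTable(table_name)

def read_from_uc(table_name: str):
    return spark.table(table_name)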

Call the teacher model
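
Databricks model serving endpoints expose an OpenAI-compatible API, so one way to implement this helper is with the openai client. The environment variable names here are assumptions; adapt them to however you store credentials:

import os
from openai import OpenAI

# Assumes DATABRICKS_HOST looks like https://<workspace>.cloud.databricks.com
client = OpenAI(
    api_key=os.environ["DATABRICKS_TOKEN"],
    base_url=f"{os.environ['DATABRICKS_HOST']}/serving-endpoints",
)

def call_teacher(messages: list[dict], temperature: float = TEMPERATURE, top_p: float = TOP_P) -> str:
    # Illustrative helper: one chat completion against the teacher endpoint.
    response = client.chat.completions.create(
        model=TEACHER_MODEL,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
    )
    return response.choices[0].message.content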

Prompting and XML

When you prompt the model to produce more examples, the model might say things like "Sure, I'd be happy to help you with that!", add formatting, or do any number of things that are hard to parse. For that reason, you can prompt the model to respond using XML syntax, which you can then parse with confidence.
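
For example, if the teacher is asked to wrap each generated question in <prompt> tags, as in the template sketched earlier, a forgiving parser might look like this:

import re

def parse_prompts(teacher_output: str) -> list[str]:
    # A regex tolerates chatter around the tags better than a strict XML parser.
    return [m.strip() for m in re.findall(r"<prompt>(.*?)</prompt>", teacher_output, re.DOTALL)]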

2. Prompt expansion

Load the prompt data to Unity Catalog, and define a single prompt synthesis pipeline that you can parallelize over, then run it.
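
A sketch of that pipeline, reusing the hypothetical helpers defined above (load_prompts, call_teacher, parse_prompts, EXPANSION_PROMPT_TEMPLATE) and a thread pool for parallelism:

import random
from concurrent.futures import ThreadPoolExecutor

def expand_once(pool_of_prompts: list[dict], n: int = 5) -> list[dict]:
    # Sample a few existing prompts as in-context examples for the teacher.
    sample = random.sample(pool_of_prompts, k=min(4, len(pool_of_prompts)))
    examples = "\n".join(p["messages"][-1]["content"] for p in sample)
    instruction = EXPANSION_PROMPT_TEMPLATE.format(seed_examples=examples, n=n)
    raw = call_teacher([{"role": "user", "content": instruction}])
    return [{"messages": [{"role": "user", "content": q}]} for q in parse_prompts(raw)]

prompts = load_prompts(PROMPT_PATH)
if MINIMUM_NUM_PROMPTS is not None:
    with ThreadPoolExecutor(max_workers=8) as pool:
        while len(prompts) < MINIMUM_NUM_PROMPTS:
            # Fan out several expansion calls, then collect the results.
            batches = list(pool.map(lambda _: expand_once(prompts), range(8)))
            for batch in batches:
                prompts.extend(batch)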

3. Generate responses

By this point, you have loaded the provided (seed) prompts into a standardized messages format and performed prompt expansion as needed to satisfy the MINIMUM_NUM_PROMPTS requirement.

In this section, you create the "assistant" response at the end of these message-formatted prompts by again calling the teacher model.
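
A sketch of that labeling step, again under the assumed helpers from the sections above:

labeled = []
for prompt in prompts:
    for _ in range(RESPONSES_PER_PROMPT):
        answer = call_teacher(prompt["messages"], temperature=TEMPERATURE, top_p=TOP_P)
        labeled.append(
            {"messages": prompt["messages"] + [{"role": "assistant", "content": answer}]}
        )

write_to_uc(labeled, OUTPUT_UC_TABLE_PATH)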

Synthetic data generation complete

At this point, your new synthetically generated training dataset is saved in the UC table OUTPUT_UC_TABLE_PATH, which you can use as the training dataset for the Mosaic AI Model Training API.

To actually "distill" this data into your model, you need to launch a fine-tuning run using the dataset you just created.

4. Start a fine-tuning run with your synthetic data

You now have a training dataset to use for fine-tuning your student model with Mosaic AI Model Training. See the documentation (AWS | Azure).

To train a smaller model on this new data (see the sketch after this list):

  • Install the databricks-genai package
  • Specify which model to fine-tune, Llama3-8B-Instruct
  • Specify where to register the fine-tuned model for deployment
  • Specify the ID of the Spark cluster to use to prepare your Unity Catalog table for training
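
A hedged sketch of that run. The register_to path and cluster ID are placeholders, and the argument names follow the databricks-genai foundation model training API; verify them against the documentation linked above:

# %pip install databricks-genai
from databricks.model_training import foundation_model as fm

run = fm.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",   # student model to fine-tune
    train_data_path=OUTPUT_UC_TABLE_PATH,           # the UC table written above
    register_to="main.default",                     # catalog.schema to register the result
    data_prep_cluster_id="0123-456789-abcdefgh",    # placeholder Spark cluster ID
    task_type="CHAT_COMPLETION",
    training_duration="3ep",
)
print(run.name)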