Ajuste fino distribuído do Llama-3.2-3B com Unsloth em múltiplas GPUs A10
Este caderno demonstra como ajustar o Llama-3.2-3B. Modelo de linguagem de grande porte usando treinamento distribuído em várias GPUs A10. Ele combina a biblioteca Unsloth para ajuste fino otimizado e eficiente de parâmetros com a biblioteca serverless_gpu para orquestração de treinamento distribuído.
O livro "The Notebook" aborda os seguintes temas:
- Configurando o treinamento distribuído em 4 GPUs A10
- Carregando e ajustando o Llama-3.2-3B com adaptadores LoRa
- Processamento de dados de treinamento do datasetFineTome-100k
- treinamento com ajuste fino supervisionado (SFT) e acompanhamento MLflow
- Mesclando adaptadores e registrando o modelo no Unity Catalog
O treinamento distribuído reduz significativamente o tempo de treinamento, paralelizando a computação em várias GPUs, mantendo a qualidade do modelo.
Requisitos: compute GPU sem servidor com aceleradores A10
Este notebook requer compute em GPU com aceleradores A10. Selecione A10 como acelerador no painel de ambiente e clique em Aplicar .
Observação: o provisionamento de recursos computacionais pode levar até 8 minutos. O treinamento distribuído provisionará automaticamente 4 GPUs A10 quando a função de treinamento for executada.
Instale a biblioteca necessária.
Instale a biblioteca Unsloth com suporte CUDA 12.4 e PyTorch 2.6.0, juntamente com o accelerate para treinamento distribuído, o unsloth_zoo para utilidades adicionais e MLflow para acompanhamento de experimentos. O ambiente de execução Python é reiniciado após a instalação para carregar o novo pacote.
%pip install unsloth[cu124-torch260]==2025.9.6
%pip install accelerate==1.7.0
%pip install unsloth_zoo==2025.9.8
%pip install mlflow>=3.6
%restart_python
Configure Unity Catalog e as configurações do modelo.
Defina os locais Unity Catalog e a configuração do modelo usando widgets do Notebook para facilitar a personalização. A configuração inclui:
- Espaço de nomes Unity Catalog (catálogo, esquema, nome do modelo, volume)
- Seleção do modelo base (Llama-3.2-3B-Instruct) (de Unsloth)
- Diretório de saída para salvar pontos de verificação em volumes Unity Catalog
- dataset de treinamento (FineTome-100k)
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "llama-3_2-3b")
dbutils.widgets.text("uc_volume", "checkpoints")
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
# Model selection - Choose based on your compute constraints
MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct" # or choose "unsloth/Llama-3.2-1B-Instruct"
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}" # Save checkpoint to UC Volume
DATASET_NAME = "mlabonne/FineTome-100k"
print(f"MODEL_NAME: {MODEL_NAME}")
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
print(f"DATASET_NAME: {DATASET_NAME}")
Defina a função de treinamento distribuído
Crie uma função de treinamento decorada com @distributed(gpus=4, gpu_type='a10', remote=True) para habilitar o treinamento multi-GPU. Esta função engloba todo o fluxo de trabalho de treinamento:
- Carregando a base Llama-3.2-3B modelo e tokenizador
- Aplicação de adaptadores LoRa para ajuste fino com otimização de parâmetros
- Processando o dataset FineTome-100k com o padrão de chat
- Configurando o SFTTrainer com configurações de treinamento distribuído
- treinamento do modelo com acompanhamento MLflow
- Salvando os adaptadores e o tokenizador treinados em volumes Unity Catalog
O decorador @distributed lida automaticamente com o provisionamento de GPU e a orquestração de treinamento distribuído em 4 GPUs A10.
from serverless_gpu import distributed
from serverless_gpu import runtime as rt
@distributed(gpus=4, gpu_type='a10', remote=True)
def run_train():
from datasets import load_dataset
import logging
import mlflow
import torch
# IMPORTANT: import unsloth BEFORE trl
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from transformers.integrations import MLflowCallback
max_seq_length = 2048 # Choose any!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False.
# Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=MODEL_NAME,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
model = FastLanguageModel.get_peft_model(
model,
r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha=16,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
)
# Process data
tokenizer = get_chat_template(
tokenizer,
chat_template="llama-3.1",
)
def formatting_prompts_func(examples):
convos = examples["conversations"]
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
return { "text" : texts, }
dataset = load_dataset(DATASET_NAME, split="train")
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(formatting_prompts_func, batched=True,)
trainer = SFTTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
dataset_num_proc = 6,
packing = False, # Can make training 5x faster for short sequences.
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_steps = 5,
# num_train_epochs = 1, # Set this for 1 full training run.
max_steps = 25,
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = OUTPUT_DIR,
report_to = "mlflow", # Use MLflow to track model metrics,
run_name = f"{MODEL_NAME}-finetune-unsloth",
),
)
trainer = train_on_responses_only(
trainer,
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
num_proc=1
)
trainer.train()
# Save model
if rt.get_global_rank() == 0:
logging.info("\nSaving trained model...")
trainer.save_model(OUTPUT_DIR)
logging.info("✓ LoRA adapters saved - use with base model for inference")
tokenizer.save_pretrained(OUTPUT_DIR)
logging.info("✓ Tokenizer saved with model")
logging.info(f"\n🎉 All artifacts saved to: {OUTPUT_DIR}")
mlflow_run_id = None
if mlflow.last_active_run() is not None:
mlflow_run_id = mlflow.last_active_run().info.run_id
return mlflow_run_id
Executar o treinamento distribuído
Inicie a função de treinamento distribuído em 4 GPUs A10. O método .distributed() provisiona as GPUs, distribui a carga de treinamento e retorna o ID de execução MLflow para acompanhamento. Esta etapa pode levar vários minutos, pois o provisionamento compute o recurso e a execução do ciclo de treinamento.
run_id = run_train.distributed()[0]
Estratégia de registro e implantação de modelos
Após a conclusão do treinamento distribuído, registre o modelo ajustado para uso em produção:
- AcompanhamentoMLflow - modelos de artefatos registrados, treinamentos detalhados e metadados para acompanhamento de experimentos
- Unity Catalog - registro do modelo para governança centralizada, controle de acesso e acompanhamento de linhagem
- Versionamento de Modelos - O versionamento automático permite o gerenciamento do ciclo de vida do modelo e recursos de reversão.
- Metadados - Informações completas do modelo garantem reprodutibilidade e compliance
Mesclar adaptadores e registro no Unity Catalog
Carregue os adaptadores LoRa treinados e merge os com a base Llama-3.2-3B. Pesos do modelo e registro do modelo final no Unity Catalog. Este processo:
- Carrega o modelo base e os adaptadores LoRa treinados a partir do diretório de checkpoints.
- Mescle os pesos do adaptador no modelo base para criar um único modelo implantável.
- Registra o modelo de mesclagem no MLflow com os metadados apropriados.
- Registre o modelo no Unity Catalog para fins de governança e implantação.
O modelo registrado está pronto para ser implantado no endpoint do modelo de serviço.
print("\nRegistering model with MLflow and Unity Catalog...")
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import mlflow
from mlflow import transformers as mlflow_transformers
# Load the trained model for registration
print("Loading LoRA model for registration...")
# For LoRA models, we need both base model and adapter
base_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
trust_remote_code=True
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
adapter_dir = OUTPUT_DIR
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
# Merge LoRA into base and drop PEFT wrappers
merged_model = peft_model.merge_and_unload()
components = {
"model": merged_model,
"tokenizer": tokenizer,
}
# Create Unity Catalog model name
full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"
print(f"Registering model as: {full_model_name}")
# Start MLflow run and log model
task = "llm/v1/chat"
with mlflow.start_run(run_id=run_id):
model_info = mlflow.transformers.log_model(
transformers_model=components,
name="model",
task=task,
registered_model_name=full_model_name,
metadata={
"task": task,
"pretrained_model_name": MODEL_NAME,
"databricks_model_family": "Llama3.2",
},
)
print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
print(f"✓ MLflow model URI: {model_info.model_uri}")
print(f"✓ Model version: {model_info.registered_model_version}")
# Print deployment information
print(f"\n📦 Model Registration Complete!")
print(f"Unity Catalog Path: {full_model_name}")
print(f"Optimization: Liger Kernels + LoRA")
Próximos passos
O modelo ajustado já está registrado no Unity Catalog e pronto para implantação. Saiba mais sobre treinamento distribuído e modelo de instalação:
- modelos implantados para inferência de lotes ou tempo real
- Crie e gerencie endpoints de serviço de modelos
- Documentação Unsloth