メインコンテンツまでスキップ

マルチ GPU サーバレス コンピュートで Axolotl を使用して Olmo3 7B を微調整する

このノートブックでは、 サーバーレス GPU コンピュートでAxolotl を使用してOlmo3 7B Instruct モデルを 微調整する方法を示します。DatabricksAxolotlは、QLoRA(量子化低ランク適応)を用いたLLMのポストトレーニングのための高性能フレームワークを提供し、マルチGPUインフラストラクチャ上での効率的なファインチューニングを可能にします。学習済みのモデルはMLflowに記録され、デプロイのためにUnity Catalogに登録されます。

必要な依存関係をインストールします

Flash アテンションをサポートする Axolotl、エクスペリメント トラッキング用のMLflow 、互換性のあるバージョンのトランスフォーマーと最適化ライブラリをインストールします。 cut-cross-entropyパッケージは、大規模言語モデルのメモリ効率の高い損失計算を提供します。

Python
%pip install -U packaging setuptools wheel ninja
%pip install mlflow>=3.6
%pip install --no-build-isolation axolotl[flash-attn]>=0.12.0
%pip install transformers==4.57.3
%pip uninstall -y awq autoawq
%pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"
dbutils.library.restartPython()

Retrieve HuggingFace ウイルス

DatabricksのシークレットからHuggingFaceの認証トークンを取得します。このトークンは、HuggingFace HubからOlmo3 7Bの基本モデルをダウンロードするために必要です。

Python
HF_TOKEN = dbutils.secrets.get(scope="sgc-nightly-notebook", key="hf_token")

トレーニングの設定

olmo3-7b-qlora.yamlの例に基づいて、Axolotlのトレーニング設定を行います。主な変更点は以下のとおりです。

  • エクスペリメント追跡のためのMLflow統合
  • チェックポイントストレージ用のUnity Catalogボリュームパス
  • SDPA(Scaled Dot より幅広いGPU互換性のためにFlash Attentionの代わりにAttentionを使用)

Unity Catalogパスを定義する

モデルのチェックポイントを保存するUnity Catalog場所を指定するウィジェットを作成します。 出力ディレクトリは、カタログ名、スキーマ名、ボリューム名、モデル名を組み合わせて完全修飾パスを作成します。

Python
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "checkpoints")
dbutils.widgets.text("model", "openai/gpt-oss-20b")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_MODEL_NAME = dbutils.widgets.get("model")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")

OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
Python
import os
os.environ['AXOLOTL_DO_NOT_TRACK'] = '1'

テレメトリを無効にする

環境変数を設定することで、Axolotlの使用状況追跡を無効にします。

Axolotlの設定を作成する

AxolotlのDictDefault形式を使用して、完全なトレーニング構成を定義します。これには、モデル設定 (4 ビット量子化の QLoRA)、データセット構成 (Alpaca 形式)、LoRA ハイパーパラメーター (ランク 32、アルファ 16)、トレーニング ポイント (1 エポック、バッチ サイズ 2、勾配累積 4)、およびエクスペリメント トラッキングのためのMLflow統合が含まれます。

Python
from axolotl.cli.config import load_cfg
from axolotl.utils.dict import DictDefault

# Config is based on with some changes to fit GPU types
# https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/olmo3/olmo3-7b-qlora.yaml

# Axolotl provides full control and transparency over model and training configuration
config = DictDefault(
base_model="allenai/Olmo-3-7B-Instruct-SFT",
plugins=[
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin"
],
load_in_8bit=False,
load_in_4bit=True,
datasets=[
{
"path": "fozziethebeat/alpaca_messages_2k_test",
"type": "chat_template"
}
],
dataset_prepared_path="last_run_prepared",
val_set_size=0.1,
output_dir=OUTPUT_DIR,
adapter="qlora",
lora_model_dir=None,
sequence_len=2048,
sample_packing=True,
lora_r=32,
lora_alpha=16,
lora_dropout=0.05,
lora_target_linear=True,
lora_target_modules=[
"gate_proj",
"down_proj",
"up_proj",
"q_proj",
"v_proj",
"k_proj",
"o_proj"
],
wandb_project=None,
wandb_entity=None,
wandb_watch=None,
wandb_name=None,
wandb_log_model=None,
gradient_accumulation_steps=4,
micro_batch_size=2,
num_epochs=1,
optimizer="adamw_bnb_8bit",
lr_scheduler="cosine",
learning_rate=0.0002,
bf16="auto",
tf32=False,
gradient_checkpointing=True,
resume_from_checkpoint=None,
logging_steps=1,
flash_attention=False,
warmup_ratio=0.1,
evals_per_epoch=1,
saves_per_epoch=1,
# Eval dataset is too small
eval_sample_packing=False,
# Write metrics to MLflow
use_mlflow=True,
mlflow_tracking_uri="databricks",
mlflow_run_name="olmo3-7b-qlora-axolotl",
hf_mlflow_log_artifacts=False,
wandb_mode="disabled",
attn_implementation="sdpa",
sdpa_attention=True,
save_first_step=True,
device_map=None,
)
Python
from axolotl.utils import set_pytorch_cuda_alloc_conf

set_pytorch_cuda_alloc_conf()

PyTorchのCUDAメモリ割り当てを設定する

マルチGPU構成での効率的なトレーニングのために、GPUメモリ管理を最適化します。

サーバレスGPUコンピュートで分散トレーニングを実行

サーバレス GPU APIの@distributedデコレータを使用して、Axolotl トレーニング ジョブを 8 つの H100 GPU に分散します。 デコレーターはマルチ GPU オーケストレーションを処理し、手動でクラスターをセットアップすることなく、分散環境でトレーニング機能を実行できるようにします。

Python
from serverless_gpu.launcher import distributed
from serverless_gpu.compute import GPUType

@distributed(gpus=8, gpu_type=GPUType.H100)
def run_train(cfg: DictDefault):
import os
os.environ['HF_TOKEN'] = HF_TOKEN

from axolotl.common.datasets import load_datasets

# Load, parse and tokenize the datasets to be formatted with qwen3 chat template
# Drop long samples from the dataset that overflow the max sequence length

# validates the configuration
cfg = load_cfg(cfg)
dataset_meta = load_datasets(cfg=cfg)

from axolotl.train import train

# just train the first 16 steps for demo.
# This is sufficient to align the model as we've used packing to maximize the trainable samples per step.
cfg.max_steps = 16
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

import mlflow
mlflow_run_id = None
if mlflow.last_active_run() is not None:
mlflow_run_id = mlflow.last_active_run().info.run_id

return mlflow_run_id
Python
result = run_train.distributed(config)

トレーニング ジョブを実行する

分散型トレーニングジョブを開始します。この関数はデータセットをロードし、構成を検証し、16 ステップのモデルをトレーニングし、追跡用のMLflow実行 ID を返します。

Python
run_id = result[0]
print(run_id)

MLflow実行IDの抽出

モデルの登録と体験追跡のためにトレーニング結果からMLflow実行 ID を取得します。

微調整したモデルをUnity Catalogに登録する

学習済みのLoRAアダプタをロードし、ベースモデルとマージし、結合されたモデルをMLflow経由でUnity Catalogに登録します。 これにより、モデルをデプロイおよび推論に利用できるようになります。

注: このステップでは、モデル チェックポイントをロードするために H100 GPU コンピュートが必要です。 より小型のGPUで実行すると、CUDAメモリ不足エラーが発生する可能性があります。

Python
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
try:
from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
except ImportError:
from transformers.activations import NewGELUActivation, GELUTanh as PytorchGELUTanh, GELUActivation

from peft import PeftModel
import mlflow
from mlflow import transformers as mlflow_transformers
import torch

HF_MODEL_NAME = "allenai/Olmo-3-7B-Instruct-SFT"

torch.cuda.empty_cache()
# Load the trained model for registration
print("Loading LoRA model for registration...")
# For LoRA models, we need both base model and adapter
base_model = AutoModelForCausalLM.from_pretrained(
HF_MODEL_NAME,
trust_remote_code=True
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
adapter_dir = OUTPUT_DIR
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
# Merge LoRA into base and drop PEFT wrappers
merged_model = peft_model.merge_and_unload()

components = {
"model": merged_model,
"tokenizer": tokenizer,
}

# Create Unity Catalog model name
full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"

print(f"Registering model as: {full_model_name}")

text_gen_pipe = pipeline(
task="text-generation",
model=peft_model,
tokenizer=tokenizer,
)

input_example = ["Hello, world!"]

with mlflow.start_run(run_id=run_id):
model_info = mlflow.transformers.log_model(
transformers_model=text_gen_pipe, # 🚨 pass the pipeline, not just the model
artifact_path="model",
input_example=input_example,
# optional: save_pretrained=False for reference-only PEFT logging
# save_pretrained=False,
)
# Start MLflow run and log model
print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
print(f"✓ MLflow model URI: {model_info.model_uri}")
print(f"✓ Model version: {model_info.registered_model_version}")

# Print deployment information
print(f"\n📦 Model Registration Complete!")
print(f"Unity Catalog Path: {full_model_name}")
print(f"Optimization: Liger Kernels + LoRA")

次のステップ

サンプルノートブック

マルチ GPU サーバレス コンピュートで Axolotl を使用して Olmo3 7B を微調整する

ノートブックを新しいタブで開く