メインコンテンツまでスキップ

OpenAI gpt-oss-20b の分散ファインチューニング

このノートブックでは、サーバレス GPU コンピュートでの分散トレーニングを使用してOpenAIのgpt-oss-20bモデルを微調整する方法を示します。 以下の方法を学習します:

  • LoRA (低ランク適応) を適用して 20B 論点モデルを効率的に微調整する
  • MXFP4量子化を 使用してトレーニング中のメモリ要件を削減します
  • 8つのH100 GPUにわたる 分散データ並列 処理を活用
  • デプロイ用に微調整されたモデルをUnity Catalogに登録する

重要な概念:

  • gpt-oss-20b : OpenAIの 200 億件のオープンソース言語モデル
  • LoRA : ベースモデルをフリーズさせながら小さなアダプターを重ねてトレーニングする効率的なファインチューニング
  • MXFP4量子化:メモリ使用量を削減するマイクロスケーリング4ビット浮動小数点形式
  • TRL : 教師ありファインチューニング用Transformer強化学習ライブラリ
  • サーバーレス GPU コンピュート: GPU リソースを自動的にスケーリングするDatabricksマネージド コンピュート

サーバレスGPUコンピュートに接続する

このノートブックはサーバレスGPUコンピュートを必要とします。 接続するには:

  1. 右上にあるノートブックのコンピュートセレクターをクリックし、 「サーバーレス GPU」 を選択します
  2. 右側の環境ボタンをクリックします
  3. アクセラレータ として 8xH100を 選択
  4. 「適用」 をクリック

トレーニング機能は、分散トレーニング用に 8 個の H100 GPU を自動的にプロビジョニングします。

必要なライブラリをインストールする

監視付きファインチューニング用の TRL、LoRA アダプター用の PEFT、モデル追跡用のMLflowなど、分散トレーニングに必要なライブラリをインストールします。

Python
%pip install "trl>=0.20.0" "peft>=0.17.0" "transformers==4.56.1"
%pip install mlflow>=3.6
%pip install hf_transfer==0.1.9
%restart_python
Python
dbutils.library.restartPython()

Unity Catalogとモデルを構成する

Unity Catalog登録とモデルのトレーニングのための設定をセットアップします。 上記のウィジェットを使用して、これらの問題をカスタマイズできます。

  • uc_cataloguc_schemauc_model_name : モデル登録用の Unity Catalog の場所
  • uc_volume : モデルチェックポイントを保存するためのボリューム名
  • model : Hugging Face モデル識別子 (デフォルト: openai/gpt-oss-20b)
  • dataset_path : ファインチューニングに使用するデータセット (勝手: HuggingFaceH4/Multilingual-Thinking)
Python
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "gpt-oss-20b-peft")
dbutils.widgets.text("uc_volume", "checkpoints")
dbutils.widgets.text("model", "openai/gpt-oss-20b")
dbutils.widgets.text("dataset_path", "HuggingFaceH4/Multilingual-Thinking")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
HF_MODEL_NAME = dbutils.widgets.get("model")
DATASET_PATH = dbutils.widgets.get("dataset_path")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"HF_MODEL_NAME: {HF_MODEL_NAME}")
print(f"DATASET_PATH: {DATASET_PATH}")

OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")

データセットを選択する

デフォルトでは、このノートブックは、複数の言語に翻訳された思考の連鎖で特別にキュレーションされた「HuggingFaceH4/Multilingual-Thinking」を使用します。上記の「データセット パス」を編集して、別のデータセットを使用できます。

GPUメモリログユーティリティを定義する

このユーティリティ関数は、分散トレーニング中の GPU メモリ使用量を監視するのに役立ちます。各 GPU ランクに割り当てられたメモリと予約済みのメモリをログに記録します。これは、メモリの問題のデバッグに役立ちます。

Python
import os
import torch
import torch.distributed as dist

def log_gpu_memory(tag=""):
if not torch.cuda.is_available():
return

# rank info (if distributed is initialized)
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1

device = torch.cuda.current_device() # current GPU for this process
torch.cuda.synchronize(device)

allocated = torch.cuda.memory_allocated(device) / 1024**2
reserved = torch.cuda.memory_reserved(device) / 1024**2

print(
f"[{tag}] rank={rank}/{world_size-1}, "
f"device={device}, "
f"allocated={allocated:.1f} MB, reserved={reserved:.1f} MB"
)

分散トレーニング関数を定義する

次のセルは、serverless_gpu ライブラリの@distributedデコレータを使用してトレーニング関数を定義します。このデコレータ:

  • 分散トレーニング用にオンデマンドで 8 基の H100 GPU をプロビジョニング
  • 複数のGPU間でデータの並列処理を自動的に処理します

機能には以下が含まれます。

  • データセットの読み込みとトークン化
  • MXFP4量子化によるモデル初期化
  • LoRAアダプタの構成
  • 勾配チェックポイントと混合精度によるトレーニング
  • Unity Catalogボリュームへのモデルの保存
Python
from serverless_gpu import distributed

@distributed(gpus=8, gpu_type="h100")
def run_train():
import logging
import os
import torch

rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
world_size = int(os.environ.get("WORLD_SIZE", str(torch.cuda.device_count())))

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

is_main = rank == 0
if is_main:
logging.info("DDP environment")
logging.info(f"\tWORLD_SIZE={world_size} RANK={rank} LOCAL_RANK={local_rank}")
logging.info(f"\tCUDA device count (this node): {torch.cuda.device_count()}")

from datasets import load_dataset
dataset = load_dataset(DATASET_PATH, split="train")

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)

from transformers import AutoModelForCausalLM, Mxfp4Config

quantization_config = Mxfp4Config(dequantize=True)
model_kwargs = dict(
attn_implementation="eager", # Use eager attention implementation for better performance
dtype=torch.bfloat16,
quantization_config=quantization_config,
use_cache=False, # Since using gradient checkpointing
)

model = AutoModelForCausalLM.from_pretrained(HF_MODEL_NAME, **model_kwargs)

from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules="all-linear",
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, peft_config)
if is_main:
peft_model.print_trainable_parameters()

from trl import SFTConfig

training_args = SFTConfig(
learning_rate=2e-4,
num_train_epochs=1,
logging_steps=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=2,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
max_length=2048,
warmup_ratio=0.03,
lr_scheduler_type="cosine_with_min_lr",
lr_scheduler_kwargs={"min_lr_rate": 0.1},
output_dir=OUTPUT_DIR,
report_to="mlflow", # No reporting to avoid Gradio issues
push_to_hub=False, # Disable push to hub to avoid authentication issues
logging_dir=None, # Disable tensorboard logging
disable_tqdm=False, # Keep progress bars for monitoring
ddp_find_unused_parameters=False,
)

from trl import SFTTrainer

trainer = SFTTrainer(
model=peft_model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
)
#torch.cuda.empty_cache()
#log_gpu_memory("before model training")
result = trainer.train()
#log_gpu_memory("after model loading")

if is_main:
logging.info("Training complete!")
logging.info(f"Final training loss: {result.training_loss:.4f}")
logging.info(f"Train runtime (s): {result.metrics.get('train_runtime', 'N/A')}")
logging.info(f"Samples/sec: {result.metrics.get('train_samples_per_second', 'N/A')}")
logging.info("\nSaving trained model...")
trainer.save_model(OUTPUT_DIR)
logging.info("✓ LoRA adapters saved - use with base model for inference")
tokenizer.save_pretrained(OUTPUT_DIR)
logging.info("✓ Tokenizer saved with model")
logging.info(f"\n🎉 All artifacts saved to: {OUTPUT_DIR}")

import mlflow
mlflow_run_id = None
if mlflow.last_active_run() is not None:
mlflow_run_id = mlflow.last_active_run().info.run_id

return mlflow_run_id

分散トレーニングを実行する

このセルは 8 つの H100 GPU でトレーニング機能を実行します。データセットのサイズとコンピュートの可用性に応じて、トレーニングには通常 30 ~ 60 分かかります。 この関数は、モデル登録用の MLflow 実行 ID を返します。

Python
run_id = run_train.distributed()[0]

Unity Catalogに登録するモデル

これで、デプロイ用のMLflowとUnity Catalog使用して微調整されたモデルを登録することができます。

重要: モデルのサイズ (20B 問題) を考慮して、登録セルを実行する前にノートブックを H100 アクセラレータに再接続してください。

登録プロセスは次のようになります。

  1. ベースモデルをロードし、微調整されたLoRAアダプタとマージする
  2. テキスト生成パイプラインを作成する
  3. Unity Catalog登録を使用してモデルをMLflowに記録する
Python
dbutils.widgets.dropdown("register_model", "False", ["True", "False"])
register_model = dbutils.widgets.get("register_model")
if register_model == "False":
dbutils.notebook.exit("Skipping model registration...")

登録確認

このセルは、 register_model問題をチェックします。 Falseに設定すると、ノートブックはモデルの登録をスキップします。この問題は、ノートブックの上部にあるウィジェットを使用して変更できます。

Python
print("\nRegistering model with MLflow and Unity Catalog...")

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from peft import PeftModel
import mlflow
from mlflow import transformers as mlflow_transformers
import torch

torch.cuda.empty_cache()
# Load the trained model for registration
print("Loading LoRA model for registration...")
# For LoRA models, we need both base model and adapter
base_model = AutoModelForCausalLM.from_pretrained(
HF_MODEL_NAME,
trust_remote_code=True
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
adapter_dir = OUTPUT_DIR
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
# Merge LoRA into base and drop PEFT wrappers
merged_model = peft_model.merge_and_unload()

components = {
"model": merged_model,
"tokenizer": tokenizer,
}

# Create Unity Catalog model name
full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"

print(f"Registering model as: {full_model_name}")

text_gen_pipe = pipeline(
task="text-generation",
model=peft_model,
tokenizer=tokenizer,
)

input_example = ["Hello, world!"]

with mlflow.start_run():
model_info = mlflow.transformers.log_model(
transformers_model=text_gen_pipe, # 🚨 pass the pipeline, not just the model
artifact_path="model",
input_example=input_example,
# optional: save_pretrained=False for reference-only PEFT logging
# save_pretrained=False,
)
# Start MLflow run and log model
print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
print(f"✓ MLflow model URI: {model_info.model_uri}")
print(f"✓ Model version: {model_info.registered_model_version}")

# Print deployment information
print(f"\n📦 Model Registration Complete!")
print(f"Unity Catalog Path: {full_model_name}")
print(f"Optimization: Liger Kernels + LoRA")

多言語推論能力をテストする

微調整されたモデルは、複数の言語での思考連鎖推論を含む Multilingual-Thinking データセットでトレーニングされています。

次のセルは、この機能を次のように実証しています。

  • 推論言語をドイツ語に設定する
  • スペイン語でプロンプトを提供する (「オーストラリアの首都は何ですか?」)
  • モデルの内部推論がドイツ語で行われていることを観察する
Python
REASONING_LANGUAGE = "German"
SYSTEM_PROMPT = f"reasoning language: {REASONING_LANGUAGE}"
USER_PROMPT = "¿Cuál es el capital de Australia?" # Spanish for "What is the capital of Australia?"

messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT},
]

input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
).to(merged_model.device)

gen_kwargs = {"max_new_tokens": 512, "do_sample": True, "temperature": 0.6, "top_p": None, "top_k": None}

output_ids = merged_model.generate(input_ids, **gen_kwargs)
response = tokenizer.batch_decode(output_ids)[0]
print(response)

次のステップ

モデルを微調整してテストしたので、次のことが可能になります。

サンプルノートブック

OpenAI gpt-oss-20b の分散ファインチューニング

ノートブックを新しいタブで開く