OpenAI gpt-oss-20b の分散ファインチューニング
このノートブックでは、サーバレス GPU コンピュートでの分散トレーニングを使用してOpenAIのgpt-oss-20bモデルを微調整する方法を示します。 以下の方法を学習します:
- LoRA (低ランク適応) を適用して 20B 論点モデルを効率的に微調整する
- MXFP4量子化を 使用してトレーニング中のメモリ要件を削減します
- 8つのH100 GPUにわたる 分散データ並列 処理を活用
- デプロイ用に微調整されたモデルをUnity Catalogに登録する
重要な概念:
- gpt-oss-20b : OpenAIの 200 億件のオープンソース言語モデル
- LoRA : ベースモデルをフリーズさせながら小さなアダプターを重ねてトレーニングする効率的なファインチューニング
- MXFP4量子化:メモリ使用量を削減するマイクロスケーリング4ビット浮動小数点形式
- TRL : 教師ありファインチューニング用Transformer強化学習ライブラリ
- サーバーレス GPU コンピュート: GPU リソースを自動的にスケーリングするDatabricksマネージド コンピュート
サーバレスGPUコンピュートに接続する
このノートブックはサーバレスGPUコンピュートを必要とします。 接続するには:
- 右上にあるノートブックのコンピュートセレクターをクリックし、 「サーバーレス GPU」 を選択します
- 右側の環境ボタンをクリックします
- アクセラレータ として 8xH100を 選択
- 「適用」 をクリック
トレーニング機能は、分散トレーニング用に 8 個の H100 GPU を自動的にプロビジョニングします。
必要なライブラリをインストールする
監視付きファインチューニング用の TRL、LoRA アダプター用の PEFT、モデル追跡用のMLflowなど、分散トレーニングに必要なライブラリをインストールします。
%pip install "trl>=0.20.0" "peft>=0.17.0" "transformers==4.56.1"
%pip install mlflow>=3.6
%pip install hf_transfer==0.1.9
%restart_python
dbutils.library.restartPython()
Unity Catalogとモデルを構成する
Unity Catalog登録とモデルのトレーニングのための設定をセットアップします。 上記のウィジェットを使用して、これらの問題をカスタマイズできます。
- uc_catalog 、 uc_schema 、 uc_model_name : モデル登録用の Unity Catalog の場所
- uc_volume : モデルチェックポイントを保存するためのボリューム名
- model : Hugging Face モデル識別子 (デフォルト: openai/gpt-oss-20b)
- dataset_path : ファインチューニングに使用するデータセット (勝手: HuggingFaceH4/Multilingual-Thinking)
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "gpt-oss-20b-peft")
dbutils.widgets.text("uc_volume", "checkpoints")
dbutils.widgets.text("model", "openai/gpt-oss-20b")
dbutils.widgets.text("dataset_path", "HuggingFaceH4/Multilingual-Thinking")
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
HF_MODEL_NAME = dbutils.widgets.get("model")
DATASET_PATH = dbutils.widgets.get("dataset_path")
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"HF_MODEL_NAME: {HF_MODEL_NAME}")
print(f"DATASET_PATH: {DATASET_PATH}")
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
データセットを選択する
デフォルトでは、このノートブックは、複数の言語に翻訳された思考の連鎖で特別にキュレーションされた「HuggingFaceH4/Multilingual-Thinking」を使用します。上記の「データセット パス」を編集して、別のデータセットを使用できます。
GPUメモリログユーティリティを定義する
このユーティリティ関数は、分散トレーニング中の GPU メモリ使用量を監視するのに役立ちます。各 GPU ランクに割り当てられたメモリと予約済みのメモリをログに記録します。これは、メモリの問題のデバッグに役立ちます。
import os
import torch
import torch.distributed as dist
def log_gpu_memory(tag=""):
if not torch.cuda.is_available():
return
# rank info (if distributed is initialized)
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
device = torch.cuda.current_device() # current GPU for this process
torch.cuda.synchronize(device)
allocated = torch.cuda.memory_allocated(device) / 1024**2
reserved = torch.cuda.memory_reserved(device) / 1024**2
print(
f"[{tag}] rank={rank}/{world_size-1}, "
f"device={device}, "
f"allocated={allocated:.1f} MB, reserved={reserved:.1f} MB"
)
分散トレーニング関数を定義する
次のセルは、serverless_gpu ライブラリの@distributedデコレータを使用してトレーニング関数を定義します。このデコレータ:
- 分散トレーニング用にオンデマンドで 8 基の H100 GPU をプロビジョニング
- 複数のGPU間でデータの並列処理を自動的に処理します
機能には以下が含まれます。
- データセットの読み込みとトークン化
- MXFP4量子化によるモデル初期化
- LoRAアダプタの構成
- 勾配チェックポイントと混合精度によるトレーニング
- Unity Catalogボリュームへのモデルの保存
from serverless_gpu import distributed
@distributed(gpus=8, gpu_type="h100")
def run_train():
import logging
import os
import torch
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
world_size = int(os.environ.get("WORLD_SIZE", str(torch.cuda.device_count())))
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
is_main = rank == 0
if is_main:
logging.info("DDP environment")
logging.info(f"\tWORLD_SIZE={world_size} RANK={rank} LOCAL_RANK={local_rank}")
logging.info(f"\tCUDA device count (this node): {torch.cuda.device_count()}")
from datasets import load_dataset
dataset = load_dataset(DATASET_PATH, split="train")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
from transformers import AutoModelForCausalLM, Mxfp4Config
quantization_config = Mxfp4Config(dequantize=True)
model_kwargs = dict(
attn_implementation="eager", # Use eager attention implementation for better performance
dtype=torch.bfloat16,
quantization_config=quantization_config,
use_cache=False, # Since using gradient checkpointing
)
model = AutoModelForCausalLM.from_pretrained(HF_MODEL_NAME, **model_kwargs)
from peft import LoraConfig, get_peft_model
peft_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules="all-linear",
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, peft_config)
if is_main:
peft_model.print_trainable_parameters()
from trl import SFTConfig
training_args = SFTConfig(
learning_rate=2e-4,
num_train_epochs=1,
logging_steps=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=2,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
max_length=2048,
warmup_ratio=0.03,
lr_scheduler_type="cosine_with_min_lr",
lr_scheduler_kwargs={"min_lr_rate": 0.1},
output_dir=OUTPUT_DIR,
report_to="mlflow", # No reporting to avoid Gradio issues
push_to_hub=False, # Disable push to hub to avoid authentication issues
logging_dir=None, # Disable tensorboard logging
disable_tqdm=False, # Keep progress bars for monitoring
ddp_find_unused_parameters=False,
)
from trl import SFTTrainer
trainer = SFTTrainer(
model=peft_model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
)
#torch.cuda.empty_cache()
#log_gpu_memory("before model training")
result = trainer.train()
#log_gpu_memory("after model loading")
if is_main:
logging.info("Training complete!")
logging.info(f"Final training loss: {result.training_loss:.4f}")
logging.info(f"Train runtime (s): {result.metrics.get('train_runtime', 'N/A')}")
logging.info(f"Samples/sec: {result.metrics.get('train_samples_per_second', 'N/A')}")
logging.info("\nSaving trained model...")
trainer.save_model(OUTPUT_DIR)
logging.info("✓ LoRA adapters saved - use with base model for inference")
tokenizer.save_pretrained(OUTPUT_DIR)
logging.info("✓ Tokenizer saved with model")
logging.info(f"\n🎉 All artifacts saved to: {OUTPUT_DIR}")
import mlflow
mlflow_run_id = None
if mlflow.last_active_run() is not None:
mlflow_run_id = mlflow.last_active_run().info.run_id
return mlflow_run_id
分散トレーニングを実行する
このセルは 8 つの H100 GPU でトレーニング機能を実行します。データセットのサイズとコンピュートの可用性に応じて、トレーニングには通常 30 ~ 60 分かかります。 この関数は、モデル登録用の MLflow 実行 ID を返します。
run_id = run_train.distributed()[0]
Unity Catalogに登録するモデル
これで、デプロイ用のMLflowとUnity Catalog使用して微調整されたモデルを登録することができます。
重要: モデルのサイズ (20B 問題) を考慮して、登録セルを実行する前にノートブックを H100 アクセラレータに再接続してください。
登録プロセスは次のようになります。
- ベースモデルをロードし、微調整されたLoRAアダプタとマージする
- テキスト生成パイプラインを作成する
- Unity Catalog登録を使用してモデルをMLflowに記録する
dbutils.widgets.dropdown("register_model", "False", ["True", "False"])
register_model = dbutils.widgets.get("register_model")
if register_model == "False":
dbutils.notebook.exit("Skipping model registration...")
登録確認
このセルは、 register_model問題をチェックします。 Falseに設定すると、ノートブックはモデルの登録をスキップします。この問題は、ノートブックの上部にあるウィジェットを使用して変更できます。
print("\nRegistering model with MLflow and Unity Catalog...")
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
import mlflow
from mlflow import transformers as mlflow_transformers
import torch
torch.cuda.empty_cache()
# Load the trained model for registration
print("Loading LoRA model for registration...")
# For LoRA models, we need both base model and adapter
base_model = AutoModelForCausalLM.from_pretrained(
HF_MODEL_NAME,
trust_remote_code=True
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
adapter_dir = OUTPUT_DIR
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
# Merge LoRA into base and drop PEFT wrappers
merged_model = peft_model.merge_and_unload()
components = {
"model": merged_model,
"tokenizer": tokenizer,
}
# Create Unity Catalog model name
full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"
print(f"Registering model as: {full_model_name}")
text_gen_pipe = pipeline(
task="text-generation",
model=peft_model,
tokenizer=tokenizer,
)
input_example = ["Hello, world!"]
with mlflow.start_run():
model_info = mlflow.transformers.log_model(
transformers_model=text_gen_pipe, # 🚨 pass the pipeline, not just the model
artifact_path="model",
input_example=input_example,
# optional: save_pretrained=False for reference-only PEFT logging
# save_pretrained=False,
)
# Start MLflow run and log model
print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
print(f"✓ MLflow model URI: {model_info.model_uri}")
print(f"✓ Model version: {model_info.registered_model_version}")
# Print deployment information
print(f"\n📦 Model Registration Complete!")
print(f"Unity Catalog Path: {full_model_name}")
print(f"Optimization: Liger Kernels + LoRA")
多言語推論能力をテストする
微調整されたモデルは、複数の言語での思考連鎖推論を含む Multilingual-Thinking データセットでトレーニングされています。
次のセルは、この機能を次のように実証しています。
- 推論言語をドイツ語に設定する
- スペイン語でプロンプトを提供する (「オーストラリアの首都は何ですか?」)
- モデルの内部推論がドイツ語で行われていることを観察する
REASONING_LANGUAGE = "German"
SYSTEM_PROMPT = f"reasoning language: {REASONING_LANGUAGE}"
USER_PROMPT = "¿Cuál es el capital de Australia?" # Spanish for "What is the capital of Australia?"
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT},
]
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
).to(merged_model.device)
gen_kwargs = {"max_new_tokens": 512, "do_sample": True, "temperature": 0.6, "top_p": None, "top_k": None}
output_ids = merged_model.generate(input_ids, **gen_kwargs)
response = tokenizer.batch_decode(output_ids)[0]
print(response)
次のステップ
モデルを微調整してテストしたので、次のことが可能になります。
- モデルのデプロイ :モデルサービングでモデルを提供します
- 分散トレーニングの詳細 :マルチGPUおよびマルチノード分散トレーニング
- サーバレス GPU の使用を最適化する :サーバレス GPU コンピュートのベスト プラクティス
- 問題のトラブルシューティング :サーバレス GPU コンピュートの問題のトラブルシューティング
- OpenAIの gpt-oss について学ぶ : OpenAIファインチューニング クックブック