対照学習を使用して埋め込みモデルを微調整する
このノートブックでは、対照学習を使用して、サーバレス GPU コンピュート上で BERT スタイルの埋め込みモデルを微調整する方法を示します。 gte-large-en-v1.5モデルを使用し、MosaicML Composer トレーナーを使用して単一の A10G GPU でトレーニングして、チェックポイントを保存し、トレーニングを再開し、結果をMLflowに記録します。
埋め込みモデルは、ベクター データベースや検索拡張生成 (RAG) アプリケーションで広く使用されています。カスタム データに埋め込みモデルをファインチューニングすることは、特定のドメインの検索精度を向上させる強力な方法です。
このノートブックでは、次の方法を学習します。
- 依存関係をインストールして環境を構成する
- Delta テーブルからトレーニング データをロードする
- データをモザイクストリーミングデータセット(MDS)形式に変換する
- モデル、オプティマイザー、トレーニングを構成する
- バッチ内ネガティブデータを用いた対照学習を使用してモデルをトレーニングする
- MLflowと登録するモデルの体験をUnity Catalogに追跡
- パフォーマンスを確認し、微調整されたモデルを提供する
この例では、 MS Marcoデータセットの前処理済みバージョンを使用していますが、独自のデータで動作するように適応させることができます。
対照学習の概要
対照学習で は、対照損失を使用して、類似のインスタンスが潜在空間内でより接近しているデータ表現を学習します。埋め込みモデルの場合、これは、「アフェンピンシャーとは何か」と「かわいいダックスフント」のようなクエリを、「アフェンピンシャーとは何か」と「dbsql の使い方は?」のようなクエリよりも類似したものとして扱うことを意味します。モデルは クエリ を肯定的な文章 (関連するテキスト) と 否定的な文章 (無関係なテキスト) と比較して、意味の違いを学習します。
否定的な文章を選択する 2 つのアプローチ:
-
バッチ内ネガティブ : ネガティブなパッセージはバッチ内からランダムに選択されます。特定のクエリとパッセージのペアの場合、バッチ内の他のすべてのパッセージは負の例になります。バッチ サイズが 8 の場合、クエリごとに 7 つの否定的なパッセージと 1 つの肯定的なパッセージが取得されます。 バッチ サイズが大きいほど、負の例が多くなり 、このアプローチがより効果的になります。
-
ハードネガティブ : 意味的に難しい事前定義されたネガティブな文章。クエリに関連している可能性はありますが、少し間違っているか無関係です。
llm-foundryコードは、より高度なファインチューニングのためのハード ネガをサポートしています。
このノートブックはバッチネガを使用しています。 データに否定的な文章が提供されていない場合、 llm-foundry コードは、他のクエリからの肯定的な文章を否定的な文章として扱うことで、自動的にそれらを推測します。
要件
このノートブックを実行する前に、いくつかの設定を行って、サーバレス GPU コンピュートに接続する必要があります。
ノートブックを構成する
このノートブックはクエリー (ウィジェット) を使用してパスと設定を構成します。 実行する前に、次の点を更新してください。
catalog: Unity Catalog カタログ名 (例:main)schema: Unity Catalogスキーマ名train_delta_table: Deltaテーブル名 (カタログ/スキーマ接頭辞なし)val_delta_table: 検証Deltaテーブル名(オプション)uc_checkpoint_folder: チェックポイントのUnity Catalogボリュームフォルダ名register_to: Unity Catalog登録用のモデル名experiment_name: MLflow体験パス (形式:/Users/<username>/<run_name>)
サーバレスGPUコンピュートに接続する
このノートブックには 1 つの A10G GPU が必要です。
- ノートブックの上部にある [接続 ] ドロップダウンをクリックします。
- サーバレス GPU を選択します。
- ノートブックの右側にある 環境 サイドパネルを開きます。
- このデモでは、 アクセラレータ を A10 に設定します。
- [適用] を選択し、 [確認] をクリックして、この環境をノートブックに適用します。
詳細については、 「サーバレス GPU コンピュート」を参照してください。
依存関係をインストールする
まず、必要なライブラリをすべてインストールし、環境の準備ができていることを確認します。
%pip install llm-foundry[gpu]==0.20.0
%pip uninstall flash_attn -y
%pip install transformers==4.46.0
%pip install hf_transfer
%restart_python
環境変数を設定する
分散トレーニングと一時ファイルストレージの環境変数を構成します。
import os
import tempfile
import mlflow
os.environ["TMPDIR"] = os.path.join(os.getcwd(), tempfile.mkdtemp())
os.environ["NCCL_DEBUG"] = "WARN"
os.environ["WORLD_SIZE"] = "1"
設定ウィジェットを作成する
ノートブックの入力ウィジェットを作成します。 続行する前に、ノートブックの上部にあるウィジェット パネルにこれらの値を入力します。
# Create widgets for configuration
dbutils.widgets.text(
"train_delta_table", "ms_marco_v_1_1_train_processed", "Training Delta Table Name"
)
dbutils.widgets.text(
"val_delta_table", "ms_marco_v_1_1_val_processed", "Validation Delta Table Name"
)
dbutils.widgets.text("register_to", "sgc_ft_embedding", "Model Registry Path")
dbutils.widgets.text("experiment_name", "/Users/<EMAIL>/Embedding_finetuning", "MLflow Experiment Name")
dbutils.widgets.text("uc_checkpoint_folder", "checkpoints", "UC Checkpoint Folder")
dbutils.widgets.text("catalog", "main", "catalog")
dbutils.widgets.text("schema", "default", "schema")
# Validate widget inputs
assert dbutils.widgets.get("train_delta_table")
assert dbutils.widgets.get("register_to")
assert dbutils.widgets.get("experiment_name")
assert dbutils.widgets.get("catalog")
assert dbutils.widgets.get("schema")
assert dbutils.widgets.get("uc_checkpoint_folder")
# Build env paths
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
train_delta_table = dbutils.widgets.get("train_delta_table")
val_delta_table = dbutils.widgets.get("val_delta_table")
uc_checkpoint_folder = dbutils.widgets.get("uc_checkpoint_folder")
register_to = dbutils.widgets.get("register_to")
experiment_name = dbutils.widgets.get("experiment_name")
train_delta_table = f"{catalog}.{schema}.{train_delta_table}"
val_delta_table = f"{catalog}.{schema}.{val_delta_table}" if val_delta_table else None
uc_checkpoint_path = f"{catalog}.{schema}.{uc_checkpoint_folder}"
register_to_path = f"{catalog}.{schema}.{register_to}"
Delta テーブルからトレーニング データをロードする
Delta テーブルからトレーニング データと検証データを読み込みます。データには次の列が必要です。
query_text: クエリまたは質問のテキストpositive_passage: クエリに関連するテキスト部分negative_passages(オプション): 事前に定義されたハードネガティブ文章の配列
この例では、 MS Marcoデータセットの前処理済みバージョンを使用します。negative_passagesを指定しない場合、トレーニング コードはバッチ内のネガティブを自動的に使用します。
df_train = spark.table(train_delta_table)
df_val = None
if val_delta_table:
df_val = spark.table(val_delta_table)
MODEL_REGISTRY_PREFIX = f"{catalog}.{schema}"
REGISTERED_MODEL_NAME = register_to
EXPERIMENT_NAME = experiment_name
UC_CHECKPOINT_PATH = f"{catalog}.{schema}.{uc_checkpoint_folder}"
Databricks の資格情報を構成する
トレーニング中に Unity Catalog および MLflow にアクセスするための Databricks CLI 資格情報を設定します。この構成ファイルにより、トレーニング プロセスが Databricks サービスで認証できるようになります。
%sh echo -e "[DEFAULT]\nhost=$DATABRICKS_HOST\ntoken=$DATABRICKS_TOKEN" > ~/.databrickscfg
データをモザイクストリーミングデータセット形式に変換する
Deltaテーブルをモザイク ストリーミング データセット (MDS) 形式に変換します。これは、サーバレス GPU コンピュートでの分散トレーニング用に最適化されています。 MDS は以下を提供します:
- トレーニング中のデータ読み込みの高速化
- 効率的な圧縮と保存
- Composerトレーナーとのシームレスな統合
変換機能は必要なスキーマ変換を処理し、MDS ファイルをUnity Catalogボリュームに保存します。 データに異なる列名がある場合は、それに応じてconvert_x関数を更新してください。
詳細については、 StreamingDataset のドキュメントを参照してください。
import os
import gc
from streaming import MDSWriter, StreamingDataset
from pyspark.sql import DataFrame
import json
import warnings
warnings.filterwarnings("ignore", module="threadpoolctl")
def process_embedding_data(
df: DataFrame,
output_path: str,
compression: str,
hashes: list[str],
limit: str,
):
def convert_x(x: dict) -> dict:
return {
"query_text": x["query_text"],
"positive_passage": x["positive_passage"],
"negative_passages": (
json.dumps(x["negative_passages"])
if x.get("negative_passages") is not None
else "[]"
),
}
try:
dtypes = {
"query_text": "str",
"positive_passage": "str",
"negative_passages": "str",
}
print(f"Starting conversion to MDS at {output_path}")
# Clear memory before processing
gc.collect()
row_count = 0
with MDSWriter(
out=output_path,
columns=dtypes,
compression=compression,
hashes=hashes,
size_limit=limit,
) as out:
for row in df.toLocalIterator():
record = convert_x(row.asDict())
out.write(record)
row_count += 1
if row_count % 10000 == 0:
print(f"Processed {row_count} records...")
print(f"Successfully wrote {row_count} records to {output_path}")
except Exception as e:
print(f"Error during data conversion: {e}")
raise e
compression = "zstd:7"
hashes = ["sha1"]
limit = "10mb"
import os
# FUSE path for checking existence
uc_train_folder = f"/Volumes/{catalog}/{schema}/embedding_temp_data/train"
uc_val_folder = f"/Volumes/{catalog}/{schema}/embedding_temp_data/val"
# dbfs path for MDSWriter and StreamingDataset
train_folder = f"dbfs:/Volumes/{catalog}/{schema}/embedding_temp_data/train"
val_folder = f"dbfs:/Volumes/{catalog}/{schema}/embedding_temp_data/val"
train_index = os.path.join(uc_train_folder, "index.json")
val_index = os.path.join(uc_val_folder, "index.json")
if os.path.exists(train_index):
print("Train MDS data already exists, skipping conversion.")
else:
process_embedding_data(df_train, train_folder, compression, hashes, limit)
if df_val is not None:
if os.path.exists(val_index):
print("Validation MDS data already exists, skipping conversion.")
else:
process_embedding_data(df_val, val_folder, compression, hashes, limit)
埋め込みモデルを構成する
ファインチューニング用のモデルとトークナイザー構成を定義します。 この例では、Hugging Face のgte-large-en-v1.5モデルを使用します。
主要な構成:
-
temperature: 対照損失における類似度スコアをスケーリングするハイパーパラメータ (0-1)。損失値が極端に高いか低い場合はこれを調整します(デフォルト: 0.5) -
pos_step_size: 負のサンプリングの位置ステップ サイズ。バッチ内のネガティブの場合は2に設定し、事前定義されたネガティブを使用する場合は1 + number of hard negativesに設定します。 -
vector_representation: 埋め込みの表現方法avg: 平均トークン埋め込み(ほとんどのモデルに推奨)eos: シーケンス終了トークン埋め込みを使用する
-
gather_in_batch_negatives: バッチ内ネガティブの場合はtrueに設定し、事前定義されたハードネガティブの場合はfalseに設定します -
pretrained_model_name_or_path: Hugging Faceモデル識別子
トークナイザー構成では、モデルの最大シーケンス長と特殊トークンを設定します。
model_cfg = {
"name": "finetune_embedding_model",
"trust_remote_code": True,
"contrastive_config": {
"temperature": 0.5,
"pos_step_size": 2, # set to 2 when not using predefined hard negatives. Otherwise use 1 + number of hard negatives
"normalize_output": True,
"vector_representation": "avg", # or eos, depending on the model default
"gather_in_batch_negatives": True, # set to true when not using predefined hard negatives
},
"pretrained_model_name_or_path": "Alibaba-NLP/gte-large-en-v1.5",
"loss_fn": "torch_crossentropy"
}
tokenizer_cfg = {
"name": "Alibaba-NLP/gte-large-en-v1.5",
"kwargs": {
"eos_token": "</s>", # this is the standard eos token for gte-large-en-v1.5
"model_max_length": 128,
"trust_remote_code": True,
},
}
MLflow ログを構成する
MLflowログを設定して、トレーニング メトリクス、論点、アーティファクトを追跡します。 ロガー構成では以下を指定します。
- 実行が記録されるMLflowエクスペリメント
- モデルレジストリの保存先としての Unity Catalog
- モデル登録用のカタログとスキーマのプレフィックス
logger_cfg = {
"mlflow": {
"run_name": "finetune_embedding",
"tracking_uri": "databricks",
"experiment_name": EXPERIMENT_NAME,
"model_registry_uri": "databricks-uc",
"model_registry_prefix": MODEL_REGISTRY_PREFIX,
}
}
トレーニングコールバックを構成する
コールバックはトレーニング プロセスのさまざまな側面を制御します。最も重要なコールバックはhf_checkpointerです。
- Hugging Face互換のチェックポイントをUnity Catalogに保存します
- 提供するモデルをUnity Catalogに登録する
- プロビジョニングされたスループットサービングのモデルメタデータを構成する
- 定期的にチェックポイントを保存します(この例では1時間ごと)
その他のコールバックは、学習率やメモリ使用量を監視し、ガベージ コレクションを実行して GPU メモリを最適化します。
callback_cfg = {
"lr_monitor": {},
"scheduled_gc": {"batch_interval": 1000},
"memory_monitor": {},
"hf_checkpointer": {
"precision": "bfloat16",
"save_folder": UC_CHECKPOINT_PATH,
"save_interval": "1h",
"mlflow_logging_config": {
"task": "llm/v1/embeddings",
"metadata": {
"task": "llm/v1/embeddings",
"source": "huggingface",
"pretrained_model_name": "Alibaba-NLP/gte-large-en-v1.5",
"databricks_model_family": "NewModel (gte_v1_5)",
"databricks_model_size_parameters": "434m",
},
},
"mlflow_registered_model_name": REGISTERED_MODEL_NAME,
},
}
トレーニングハイパーパラメータを設定する
オプティマイザー、学習率スケジューラー、精度、およびトレーニング アルゴリズムを定義します。これらは、データと要件に基づいて調整できる標準的な機械学習トレーニングの問題です。
- 最適化装置 : 分離重み減衰を備えたAdamW
- 学習率 : コサインウォームアップスケジュールで3e-5
- 精度 : より高速なトレーニングのために bfloat16 を使用した自動混合精度
- 勾配クリッピング : トレーニング中に勾配が爆発するのを防ぐ
optimizer_cfg = {
"lr": 0.00003,
"eps": 1.0e-08,
"name": "decoupled_adamw",
"betas": [0.9, 0.95],
"weight_decay": 0.0001,
}
precision_cfg = "amp_bf16"
scheduler_cfg = {"name": "cosine_with_warmup", "alpha_f": 0.02, "t_warmup": "0.06dur"}
algorithms_cfg = {
"gradient_clipping": {"clipping_type": "norm", "clipping_threshold": 1}
}
データローダーを構成する
トレーニング中にトレーニング データと評価データをロードする方法を定義します。データローダー:
- Unity Catalogボリューム内のMDS形式のデータを指す
- テキストの前処理を設定する(「query:」および「passage:」プレフィックスを先頭に追加する)
- 最大シーケンス長とバッチ処理を設定する
- トレーニングの収束性を向上させるためにシャッフルを有効にする
remoteパスが変換された MDS データの場所を指していることを確認します。
train_loader = {
"name": "contrastive_pairs",
"dataset": {
"local": None,
"split": None,
"remote": train_folder,
"shuffle": True,
"max_seq_len": 128,
"shuffle_seed": 42,
"prepend_query": "query: ",
"prepend_passage": "passage: ",
"append_eos_token": True,
},
"drop_last": True,
"num_workers": 8,
}
eval_loader = {
"name": "contrastive_pairs",
"dataset": {
"local": None,
"split": None,
"remote": val_folder,
"shuffle": True,
"max_seq_len": 128,
"shuffle_seed": 42,
"prepend_query": "query: ",
"prepend_passage": "passage: ",
"append_eos_token": True,
},
"drop_last": True,
"num_workers": 8,
}
完全なトレーニング構成を組み立てる
すべての構成コンポーネントを 1 つのトレーニング構成オブジェクトに結合します。これには、モデル、トークナイザー、データ ローダー、オプティマイザー、コールバック、トレーニング パラメーターが含まれます。
from omegaconf import DictConfig, OmegaConf
config = {
"seed": 42,
"max_seq_len": 128,
"model": model_cfg,
"tokenizer": tokenizer_cfg,
"loggers": logger_cfg,
"callbacks": callback_cfg,
"run_name": "finetune-BERT",
"optimizer": optimizer_cfg,
"precision": precision_cfg,
"scheduler": scheduler_cfg,
"algorithms": algorithms_cfg,
"train_loader": train_loader,
"eval_loader": eval_loader,
"eval_first": True,
"save_folder": 'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts/{run_name}/checkpoints',
"max_duration": "200ba",
"progress_bar": False,
"eval_interval": "50ba",
"save_interval": "50ba",
"log_to_console": True,
"load_weights_only": True,
"console_log_interval": "1ba",
"device_eval_batch_size": 1,
"eval_subset_num_batches": 4,
"global_train_batch_size": 1,
"device_train_microbatch_size": 1
}
cfg = DictConfig(config)
埋め込みモデルをトレーニングする
埋め込みモデルに最適化されたトレーニングを提供する MosaicML LLM-Foundry ライブラリを使用してトレーニング プロセスを実行します。トレーニング機能:
@distributedデコレータを使用して GPU リソースをプロビジョニングします- 50バッチごとの評価付きで200バッチトレーニングする
- チェックポイントをUnity Catalogに保存します
- メトリクスとアーティファクトをMLflowに記録します
- 最終モデルをUnity Catalogに登録する
データセットのサイズによっては、トレーニングに数分かかる場合があります。MLflowエクスペリメントで進行状況を監視できます。
詳細については、 GitHubのLLM -Foundry を参照してください。
from serverless_gpu import distributed
from llmfoundry.command_utils.train import train
from omegaconf import DictConfig
import mlflow
import torch
@distributed(gpus=1, gpu_type='a10', remote=True)
def run_training():
mlflow.end_run()
trainer = train(cfg)
run = mlflow.active_run()
mlflow_run_id = run.info.run_id
run_name = trainer.state.run_name
del trainer
mlflow.end_run()
return mlflow_run_id, run_name
# Run training
result = run_training.distributed()
mlflow_run_id, run_name = result[0]
print(f"Run ID: {mlflow_run_id}")
# Download and load checkpoint
checkpoint_artifact_path = f"{run_name}/checkpoints/ep0-ba200-rank0.pt"
print(f"Downloading: {checkpoint_artifact_path}")
local_path = mlflow.artifacts.download_artifacts(
artifact_uri=f"runs:/{mlflow_run_id}/{checkpoint_artifact_path}"
)
print(f"Downloaded to: {local_path}")
ckpt = torch.load(local_path, map_location="cpu", weights_only=False)
print("\nTop-level keys:")
print(list(ckpt.keys()))
if "state" in ckpt:
print("\nState keys:")
print(list(ckpt["state"].keys()))
if "model" in ckpt["state"]:
print(f"\nModel state dict: {len(ckpt['state']['model'])} keys")
for k in list(ckpt["state"]["model"].keys())[:10]:
print(f" {k}")
トレーニング結果を確認し、モデルを提供する
トレーニングが完了したら、結果を確認し、微調整したモデルをデプロイします。
トレーニング メトリクスを確認します。
-
で指定されたMLflowエクスペリメントに移動します。
experiment_name- ワークスペース UI のエクスペリメントページでもエクスペリメントを見つけることができます。
-
表示するトレーニング実行を選択してください:
- 「メトリクス」タブのメトリクス の トレーニングと評価
- 問題 タブのモデル問題
- 「 アーティファクト」タブのチェックポイントとアーティファクト
-
モデル詳細 タブには、 Unity Catalogに登録されたモデルが表示されます。
モデルを提供します:
- 指定されたパスに登録されたモデルに移動します。
register_to - 最新のモデルバージョンを選択してください
- プロビジョニングされたスループットを使用してデプロイするには、 このモデルを提供するを クリックします
- 希望するスループット設定でサービスエンドポイントを構成する
次のステップ
対照学習を使用して埋め込みモデルを微調整したので、さらに詳しく知るには次のリソースを参照してください。
- サーバレス GPU コンピュート- サーバレス GPU の特徴と機能について学びます
- サーバレス GPU コンピュートのベスト プラクティス- GPU ワークロードを最適化します。
- 基盤モデルAPIs - プロビジョニング スループットを使用したモデルのデプロイと提供
- LLM-Foundryドキュメント- 高度なトレーニング機能と構成を調べる
- StreamingDataset ドキュメント- MDS 形式と最適化の詳細
- Unity Catalogモデルレジストリ- モデルの管理とバージョン管理