メインコンテンツまでスキップ

対照学習を使用して埋め込みモデルを微調整する

このノートブックでは、対照学習を使用して、サーバレス GPU コンピュート上で BERT スタイルの埋め込みモデルを微調整する方法を示します。 gte-large-en-v1.5モデルを使用し、MosaicML Composer トレーナーを使用して単一の A10G GPU でトレーニングして、チェックポイントを保存し、トレーニングを再開し、結果をMLflowに記録します。

埋め込みモデルは、ベクター データベースや検索拡張生成 (RAG) アプリケーションで広く使用されています。カスタム データに埋め込みモデルをファインチューニングすることは、特定のドメインの検索精度を向上させる強力な方法です。

このノートブックでは、次の方法を学習します。

  • 依存関係をインストールして環境を構成する
  • Delta テーブルからトレーニング データをロードする
  • データをモザイクストリーミングデータセット(MDS)形式に変換する
  • モデル、オプティマイザー、トレーニングを構成する
  • バッチ内ネガティブデータを用いた対照学習を使用してモデルをトレーニングする
  • MLflowと登録するモデルの体験をUnity Catalogに追跡
  • パフォーマンスを確認し、微調整されたモデルを提供する

この例では、 MS Marcoデータセットの前処理済みバージョンを使用していますが、独自のデータで動作するように適応させることができます。

対照学習の概要

対照学習で は、対照損失を使用して、類似のインスタンスが潜在空間内でより接近しているデータ表現を学習します。埋め込みモデルの場合、これは、「アフェンピンシャーとは何か」と「かわいいダックスフント」のようなクエリを、「アフェンピンシャーとは何か」と「dbsql の使い方は?」のようなクエリよりも類似したものとして扱うことを意味します。モデルは クエリ を肯定的な文章 (関連するテキスト) と 否定的な文章 (無関係なテキスト) と比較して、意味の違いを学習します。

否定的な文章を選択する 2 つのアプローチ:

  • バッチ内ネガティブ : ネガティブなパッセージはバッチ内からランダムに選択されます。特定のクエリとパッセージのペアの場合、バッチ内の他のすべてのパッセージは負の例になります。バッチ サイズが 8 の場合、クエリごとに 7 つの否定的なパッセージと 1 つの肯定的なパッセージが取得されます。 バッチ サイズが大きいほど、負の例が多くなり 、このアプローチがより効果的になります。

  • ハードネガティブ : 意味的に難しい事前定義されたネガティブな文章。クエリに関連している可能性はありますが、少し間違っているか無関係です。llm-foundryコードは、より高度なファインチューニングのためのハード ネガをサポートしています。

このノートブックはバッチネガを使用しています。 データに否定的な文章が提供されていない場合、 llm-foundry コードは、他のクエリからの肯定的な文章を否定的な文章として扱うことで、自動的にそれらを推測します。

要件

このノートブックを実行する前に、いくつかの設定を行って、サーバレス GPU コンピュートに接続する必要があります。

ノートブックを構成する

このノートブックはクエリー (ウィジェット) を使用してパスと設定を構成します。 実行する前に、次の点を更新してください。

  • catalog: Unity Catalog カタログ名 (例: main )
  • schema: Unity Catalogスキーマ名
  • train_delta_table: Deltaテーブル名 (カタログ/スキーマ接頭辞なし)
  • val_delta_table: 検証Deltaテーブル名(オプション)
  • uc_checkpoint_folder: チェックポイントのUnity Catalogボリュームフォルダ名
  • register_to: Unity Catalog登録用のモデル名
  • experiment_name: MLflow体験パス (形式: /Users/<username>/<run_name> )

サーバレスGPUコンピュートに接続する

このノートブックには 1 つの A10G GPU が必要です。

  1. ノートブックの上部にある [接続 ] ドロップダウンをクリックします。
  2. サーバレス GPU を選択します。
  3. ノートブックの右側にある 環境 サイドパネルを開きます。
  4. このデモでは、 アクセラレータA10 に設定します。
  5. [適用] を選択し、 [確認] をクリックして、この環境をノートブックに適用します。

詳細については、 「サーバレス GPU コンピュート」を参照してください。

依存関係をインストールする

まず、必要なライブラリをすべてインストールし、環境の準備ができていることを確認します。

Python
%pip install llm-foundry[gpu]==0.20.0
%pip uninstall flash_attn -y
%pip install transformers==4.46.0
%pip install hf_transfer
%restart_python

環境変数を設定する

分散トレーニングと一時ファイルストレージの環境変数を構成します。

Python
import os
import tempfile
import mlflow

os.environ["TMPDIR"] = os.path.join(os.getcwd(), tempfile.mkdtemp())
os.environ["NCCL_DEBUG"] = "WARN"
os.environ["WORLD_SIZE"] = "1"

設定ウィジェットを作成する

ノートブックの入力ウィジェットを作成します。 続行する前に、ノートブックの上部にあるウィジェット パネルにこれらの値を入力します。

Python
# Create widgets for configuration
dbutils.widgets.text(
"train_delta_table", "ms_marco_v_1_1_train_processed", "Training Delta Table Name"
)
dbutils.widgets.text(
"val_delta_table", "ms_marco_v_1_1_val_processed", "Validation Delta Table Name"
)
dbutils.widgets.text("register_to", "sgc_ft_embedding", "Model Registry Path")
dbutils.widgets.text("experiment_name", "/Users/<EMAIL>/Embedding_finetuning", "MLflow Experiment Name")
dbutils.widgets.text("uc_checkpoint_folder", "checkpoints", "UC Checkpoint Folder")
dbutils.widgets.text("catalog", "main", "catalog")
dbutils.widgets.text("schema", "default", "schema")
Python
# Validate widget inputs
assert dbutils.widgets.get("train_delta_table")
assert dbutils.widgets.get("register_to")
assert dbutils.widgets.get("experiment_name")
assert dbutils.widgets.get("catalog")
assert dbutils.widgets.get("schema")
assert dbutils.widgets.get("uc_checkpoint_folder")
Python
# Build env paths
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
train_delta_table = dbutils.widgets.get("train_delta_table")
val_delta_table = dbutils.widgets.get("val_delta_table")
uc_checkpoint_folder = dbutils.widgets.get("uc_checkpoint_folder")
register_to = dbutils.widgets.get("register_to")
experiment_name = dbutils.widgets.get("experiment_name")

train_delta_table = f"{catalog}.{schema}.{train_delta_table}"
val_delta_table = f"{catalog}.{schema}.{val_delta_table}" if val_delta_table else None
uc_checkpoint_path = f"{catalog}.{schema}.{uc_checkpoint_folder}"
register_to_path = f"{catalog}.{schema}.{register_to}"

Delta テーブルからトレーニング データをロードする

Delta テーブルからトレーニング データと検証データを読み込みます。データには次の列が必要です。

  • query_text: クエリまたは質問のテキスト
  • positive_passage: クエリに関連するテキスト部分
  • negative_passages (オプション): 事前に定義されたハードネガティブ文章の配列

この例では、 MS Marcoデータセットの前処理済みバージョンを使用します。negative_passagesを指定しない場合、トレーニング コードはバッチ内のネガティブを自動的に使用します。

Python
df_train = spark.table(train_delta_table)
df_val = None

if val_delta_table:
df_val = spark.table(val_delta_table)

MODEL_REGISTRY_PREFIX = f"{catalog}.{schema}"
REGISTERED_MODEL_NAME = register_to

EXPERIMENT_NAME = experiment_name
UC_CHECKPOINT_PATH = f"{catalog}.{schema}.{uc_checkpoint_folder}"

Databricks の資格情報を構成する

トレーニング中に Unity Catalog および MLflow にアクセスするための Databricks CLI 資格情報を設定します。この構成ファイルにより、トレーニング プロセスが Databricks サービスで認証できるようになります。

Python
%sh echo -e "[DEFAULT]\nhost=$DATABRICKS_HOST\ntoken=$DATABRICKS_TOKEN" > ~/.databrickscfg

データをモザイクストリーミングデータセット形式に変換する

Deltaテーブルをモザイク ストリーミング データセット (MDS) 形式に変換します。これは、サーバレス GPU コンピュートでの分散トレーニング用に最適化されています。 MDS は以下を提供します:

  • トレーニング中のデータ読み込みの高速化
  • 効率的な圧縮と保存
  • Composerトレーナーとのシームレスな統合

変換機能は必要なスキーマ変換を処理し、MDS ファイルをUnity Catalogボリュームに保存します。 データに異なる列名がある場合は、それに応じてconvert_x関数を更新してください。

詳細については、 StreamingDataset のドキュメントを参照してください。

Python
import os
import gc
from streaming import MDSWriter, StreamingDataset
from pyspark.sql import DataFrame
import json
import warnings
warnings.filterwarnings("ignore", module="threadpoolctl")


def process_embedding_data(
df: DataFrame,
output_path: str,
compression: str,
hashes: list[str],
limit: str,
):
def convert_x(x: dict) -> dict:
return {
"query_text": x["query_text"],
"positive_passage": x["positive_passage"],
"negative_passages": (
json.dumps(x["negative_passages"])
if x.get("negative_passages") is not None
else "[]"
),
}

try:
dtypes = {
"query_text": "str",
"positive_passage": "str",
"negative_passages": "str",
}

print(f"Starting conversion to MDS at {output_path}")

# Clear memory before processing
gc.collect()

row_count = 0
with MDSWriter(
out=output_path,
columns=dtypes,
compression=compression,
hashes=hashes,
size_limit=limit,
) as out:
for row in df.toLocalIterator():
record = convert_x(row.asDict())
out.write(record)
row_count += 1
if row_count % 10000 == 0:
print(f"Processed {row_count} records...")

print(f"Successfully wrote {row_count} records to {output_path}")

except Exception as e:
print(f"Error during data conversion: {e}")
raise e
Python
compression = "zstd:7"
hashes = ["sha1"]
limit = "10mb"

import os

# FUSE path for checking existence
uc_train_folder = f"/Volumes/{catalog}/{schema}/embedding_temp_data/train"
uc_val_folder = f"/Volumes/{catalog}/{schema}/embedding_temp_data/val"

# dbfs path for MDSWriter and StreamingDataset
train_folder = f"dbfs:/Volumes/{catalog}/{schema}/embedding_temp_data/train"
val_folder = f"dbfs:/Volumes/{catalog}/{schema}/embedding_temp_data/val"

train_index = os.path.join(uc_train_folder, "index.json")
val_index = os.path.join(uc_val_folder, "index.json")

if os.path.exists(train_index):
print("Train MDS data already exists, skipping conversion.")
else:
process_embedding_data(df_train, train_folder, compression, hashes, limit)

if df_val is not None:
if os.path.exists(val_index):
print("Validation MDS data already exists, skipping conversion.")
else:
process_embedding_data(df_val, val_folder, compression, hashes, limit)

埋め込みモデルを構成する

ファインチューニング用のモデルとトークナイザー構成を定義します。 この例では、Hugging Face のgte-large-en-v1.5モデルを使用します。

主要な構成:

  • temperature: 対照損失における類似度スコアをスケーリングするハイパーパラメータ (0-1)。損失値が極端に高いか低い場合はこれを調整します(デフォルト: 0.5)

  • pos_step_size: 負のサンプリングの位置ステップ サイズ。バッチ内のネガティブの場合は2に設定し、事前定義されたネガティブを使用する場合は1 + number of hard negativesに設定します。

  • vector_representation: 埋め込みの表現方法

    • avg: 平均トークン埋め込み(ほとんどのモデルに推奨)
    • eos: シーケンス終了トークン埋め込みを使用する
  • gather_in_batch_negatives: バッチ内ネガティブの場合はtrueに設定し、事前定義されたハードネガティブの場合はfalseに設定します

  • pretrained_model_name_or_path: Hugging Faceモデル識別子

トークナイザー構成では、モデルの最大シーケンス長と特殊トークンを設定します。

Python
model_cfg = {
"name": "finetune_embedding_model",
"trust_remote_code": True,
"contrastive_config": {
"temperature": 0.5,
"pos_step_size": 2, # set to 2 when not using predefined hard negatives. Otherwise use 1 + number of hard negatives
"normalize_output": True,
"vector_representation": "avg", # or eos, depending on the model default
"gather_in_batch_negatives": True, # set to true when not using predefined hard negatives
},
"pretrained_model_name_or_path": "Alibaba-NLP/gte-large-en-v1.5",
"loss_fn": "torch_crossentropy"
}

tokenizer_cfg = {
"name": "Alibaba-NLP/gte-large-en-v1.5",
"kwargs": {
"eos_token": "</s>", # this is the standard eos token for gte-large-en-v1.5
"model_max_length": 128,
"trust_remote_code": True,
},
}

MLflow ログを構成する

MLflowログを設定して、トレーニング メトリクス、論点、アーティファクトを追跡します。 ロガー構成では以下を指定します。

  • 実行が記録されるMLflowエクスペリメント
  • モデルレジストリの保存先としての Unity Catalog
  • モデル登録用のカタログとスキーマのプレフィックス
Python
logger_cfg = {
"mlflow": {
"run_name": "finetune_embedding",
"tracking_uri": "databricks",
"experiment_name": EXPERIMENT_NAME,
"model_registry_uri": "databricks-uc",
"model_registry_prefix": MODEL_REGISTRY_PREFIX,
}
}

トレーニングコールバックを構成する

コールバックはトレーニング プロセスのさまざまな側面を制御します。最も重要なコールバックはhf_checkpointerです。

  • Hugging Face互換のチェックポイントをUnity Catalogに保存します
  • 提供するモデルをUnity Catalogに登録する
  • プロビジョニングされたスループットサービングのモデルメタデータを構成する
  • 定期的にチェックポイントを保存します(この例では1時間ごと)

その他のコールバックは、学習率やメモリ使用量を監視し、ガベージ コレクションを実行して GPU メモリを最適化します。

Python
callback_cfg = {
"lr_monitor": {},
"scheduled_gc": {"batch_interval": 1000},
"memory_monitor": {},
"hf_checkpointer": {
"precision": "bfloat16",
"save_folder": UC_CHECKPOINT_PATH,
"save_interval": "1h",
"mlflow_logging_config": {
"task": "llm/v1/embeddings",
"metadata": {
"task": "llm/v1/embeddings",
"source": "huggingface",
"pretrained_model_name": "Alibaba-NLP/gte-large-en-v1.5",
"databricks_model_family": "NewModel (gte_v1_5)",
"databricks_model_size_parameters": "434m",
},
},
"mlflow_registered_model_name": REGISTERED_MODEL_NAME,
},
}

トレーニングハイパーパラメータを設定する

オプティマイザー、学習率スケジューラー、精度、およびトレーニング アルゴリズムを定義します。これらは、データと要件に基づいて調整できる標準的な機械学習トレーニングの問題です。

  • 最適化装置 : 分離重み減衰を備えたAdamW
  • 学習率 : コサインウォームアップスケジュールで3e-5
  • 精度 : より高速なトレーニングのために bfloat16 を使用した自動混合精度
  • 勾配クリッピング : トレーニング中に勾配が爆発するのを防ぐ
Python
optimizer_cfg = {
"lr": 0.00003,
"eps": 1.0e-08,
"name": "decoupled_adamw",
"betas": [0.9, 0.95],
"weight_decay": 0.0001,
}

precision_cfg = "amp_bf16"

scheduler_cfg = {"name": "cosine_with_warmup", "alpha_f": 0.02, "t_warmup": "0.06dur"}

algorithms_cfg = {
"gradient_clipping": {"clipping_type": "norm", "clipping_threshold": 1}
}

データローダーを構成する

トレーニング中にトレーニング データと評価データをロードする方法を定義します。データローダー:

  • Unity Catalogボリューム内のMDS形式のデータを指す
  • テキストの前処理を設定する(「query:」および「passage:」プレフィックスを先頭に追加する)
  • 最大シーケンス長とバッチ処理を設定する
  • トレーニングの収束性を向上させるためにシャッフルを有効にする

remoteパスが変換された MDS データの場所を指していることを確認します。

Python
train_loader = {
"name": "contrastive_pairs",
"dataset": {
"local": None,
"split": None,
"remote": train_folder,
"shuffle": True,
"max_seq_len": 128,
"shuffle_seed": 42,
"prepend_query": "query: ",
"prepend_passage": "passage: ",
"append_eos_token": True,
},
"drop_last": True,
"num_workers": 8,
}

eval_loader = {
"name": "contrastive_pairs",
"dataset": {
"local": None,
"split": None,
"remote": val_folder,
"shuffle": True,
"max_seq_len": 128,
"shuffle_seed": 42,
"prepend_query": "query: ",
"prepend_passage": "passage: ",
"append_eos_token": True,
},
"drop_last": True,
"num_workers": 8,
}

完全なトレーニング構成を組み立てる

すべての構成コンポーネントを 1 つのトレーニング構成オブジェクトに結合します。これには、モデル、トークナイザー、データ ローダー、オプティマイザー、コールバック、トレーニング パラメーターが含まれます。

Python
from omegaconf import DictConfig, OmegaConf

config = {
"seed": 42,
"max_seq_len": 128,

"model": model_cfg,
"tokenizer": tokenizer_cfg,

"loggers": logger_cfg,
"callbacks": callback_cfg,

"run_name": "finetune-BERT",

"optimizer": optimizer_cfg,
"precision": precision_cfg,
"scheduler": scheduler_cfg,
"algorithms": algorithms_cfg,

"train_loader": train_loader,
"eval_loader": eval_loader,

"eval_first": True,
"save_folder": 'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts/{run_name}/checkpoints',
"max_duration": "200ba",
"progress_bar": False,

"eval_interval": "50ba",
"save_interval": "50ba",
"log_to_console": True,
"load_weights_only": True,
"console_log_interval": "1ba",
"device_eval_batch_size": 1,
"eval_subset_num_batches": 4,
"global_train_batch_size": 1,
"device_train_microbatch_size": 1
}
cfg = DictConfig(config)

埋め込みモデルをトレーニングする

埋め込みモデルに最適化されたトレーニングを提供する MosaicML LLM-Foundry ライブラリを使用してトレーニング プロセスを実行します。トレーニング機能:

  • @distributedデコレータを使用して GPU リソースをプロビジョニングします
  • 50バッチごとの評価付きで200バッチトレーニングする
  • チェックポイントをUnity Catalogに保存します
  • メトリクスとアーティファクトをMLflowに記録します
  • 最終モデルをUnity Catalogに登録する

データセットのサイズによっては、トレーニングに数分かかる場合があります。MLflowエクスペリメントで進行状況を監視できます。

詳細については、 GitHubのLLM -Foundry を参照してください。

Python
from serverless_gpu import distributed
from llmfoundry.command_utils.train import train
from omegaconf import DictConfig
import mlflow
import torch

@distributed(gpus=1, gpu_type='a10', remote=True)
def run_training():
mlflow.end_run()
trainer = train(cfg)

run = mlflow.active_run()
mlflow_run_id = run.info.run_id
run_name = trainer.state.run_name
del trainer
mlflow.end_run()

return mlflow_run_id, run_name

# Run training
result = run_training.distributed()
mlflow_run_id, run_name = result[0]

print(f"Run ID: {mlflow_run_id}")

# Download and load checkpoint
checkpoint_artifact_path = f"{run_name}/checkpoints/ep0-ba200-rank0.pt"
print(f"Downloading: {checkpoint_artifact_path}")
local_path = mlflow.artifacts.download_artifacts(
artifact_uri=f"runs:/{mlflow_run_id}/{checkpoint_artifact_path}"
)
print(f"Downloaded to: {local_path}")

ckpt = torch.load(local_path, map_location="cpu", weights_only=False)

print("\nTop-level keys:")
print(list(ckpt.keys()))

if "state" in ckpt:
print("\nState keys:")
print(list(ckpt["state"].keys()))

if "model" in ckpt["state"]:
print(f"\nModel state dict: {len(ckpt['state']['model'])} keys")
for k in list(ckpt["state"]["model"].keys())[:10]:
print(f" {k}")

トレーニング結果を確認し、モデルを提供する

トレーニングが完了したら、結果を確認し、微調整したモデルをデプロイします。

トレーニング メトリクスを確認します。

  1. で指定されたMLflowエクスペリメントに移動します。 experiment_name

    • ワークスペース UI のエクスペリメントページでもエクスペリメントを見つけることができます。
  2. 表示するトレーニング実行を選択してください:

    • 「メトリクス」タブのメトリクス トレーニングと評価
    • 問題 タブのモデル問題
    • アーティファクト」タブのチェックポイントとアーティファクト
  3. モデル詳細 タブには、 Unity Catalogに登録されたモデルが表示されます。

モデルを提供します:

  1. 指定されたパスに登録されたモデルに移動します。 register_to
  2. 最新のモデルバージョンを選択してください
  3. プロビジョニングされたスループットを使用してデプロイするには、 このモデルを提供するを クリックします
  4. 希望するスループット設定でサービスエンドポイントを構成する

次のステップ

対照学習を使用して埋め込みモデルを微調整したので、さらに詳しく知るには次のリソースを参照してください。

サンプルノートブック

対照学習を使用して埋め込みモデルを微調整する

ノートブックを新しいタブで開く