OSS 組み込みモデルを登録して提供する
このノートブックでは、ベクトル検索に使用できるモデルサービングエンドポイントに、オープンソースのテキスト埋め込みモデルe5-small-v2を設定します。
- モデルはHugging Faceハブからダウンロードしてください。
- MLflow Model Registryに登録します。
- モデルサービングエンドポイントを開始してモデルを提供します。
モデルe5-small-v2はhttps://huggingface.co/intfloat/e5-small-v2で入手可能です。
- MITライセンス
- バリエーション:
Databricks Runtimeに含まれるライブラリのバージョン一覧については、お使いのDatabricks Runtimeバージョンのリリースノートを参照してください。
Databricks Python SDKをインストールする
このノートブックは、Pythonクライアントを使用してサービス提供エンドポイントを操作します。
Python
%pip install -U databricks-sdk python-snappy
%pip install sentence-transformers
dbutils.library.restartPython()
モデルをダウンロード
Python
# Download model using the sentence_transformers library.
from sentence_transformers import SentenceTransformer
source_model_name = 'intfloat/e5-small-v2' # model name on Hugging Face Hub
model = SentenceTransformer(source_model_name)
Python
# Test the model, just to show it works.
sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)
print(embeddings)
モデルをMLflowに登録する
Python
import mlflow
mlflow.set_registry_uri("databricks-uc")
Python
# Specify the catalog and schema to use. You must have USE_CATALOG privilege on the catalog and USE_SCHEMA and CREATE_TABLE privileges on the schema.
# Change the catalog and schema here if necessary.
catalog = "main"
schema = "default"
model_name = "e5-small-v2"
Python
# MLflow model name. The Model Registry uses this name for the model.
registered_model_name = f"{catalog}.{schema}.{model_name}"
Python
# Compute input and output schema.
signature = mlflow.models.signature.infer_signature(sentences, embeddings)
print(signature)
Python
model_info = mlflow.sentence_transformers.log_model(
model,
artifact_path="model",
signature=signature,
input_example=sentences,
registered_model_name=registered_model_name)
Python
inference_test = ["I enjoy pies of both apple and cherry.", "I prefer cookies."]
# Load the custom model by providing the URI for where the model was logged.
loaded_model_pyfunc = mlflow.pyfunc.load_model(model_info.model_uri)
# Perform a quick test to ensure that the loaded model generates the correct output.
embeddings_test = loaded_model_pyfunc.predict(inference_test)
embeddings_test
Python
# Extract the version of the model you just registered.
mlflow_client = mlflow.MlflowClient()
def get_latest_model_version(model_name):
client = mlflow_client
model_version_infos = client.search_model_versions("name = '%s'" % model_name)
return max([int(model_version_info.version) for model_version_info in model_version_infos])
model_version = get_latest_model_version(registered_model_name)
model_version
モデルサービングエンドポイントの作成
詳細については、 「エンドポイントを提供する基盤モデルの作成」を参照してください。
注 :この例では、CPU を 0 まで縮小できる 小さな CPU エンドポイントを作成します。これは、迅速かつ小規模なテスト用です。より現実的なユースケースでは、埋め込み計算を高速化するためにGPUエンドポイントを使用することを検討してください。また、頻繁なクエリが予想される場合は、モデルサービングエンドポイントにはコールドスタート時のオーバーヘッドがあるため、スケールダウンを0にしないことも検討してください。
Python
endpoint_name = "e5-small-v2" # Name of endpoint to create
Python
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput
w = WorkspaceClient()
Python
endpoint_config_dict = {
"served_entities": [
{
"name": f'{registered_model_name.replace(".", "_")}_{1}',
"entity_name": registered_model_name,
"entity_version": model_version,
"workload_type": "CPU",
"workload_size": "Small",
"scale_to_zero_enabled": True,
}
]
}
endpoint_config = EndpointCoreConfigInput.from_dict(endpoint_config_dict)
# The endpoint may take several minutes to get ready.
w.serving_endpoints.create_and_wait(name=endpoint_name, config=endpoint_config)
クエリエンドポイント
上記のcreate_and_waitコマンドは、エンドポイントの準備が整うまで待機します。DatabricksのUIで、サービス提供エンドポイントのステータスを確認することもできます。
詳細については、 「クエリ プラットフォーム モデル」を参照してください。
Python
# Only run this command after the Model Serving endpoint is in the Ready state.
import time
start = time.time()
# If the endpoint is not yet ready, you might get a timeout error. If so, wait and then rerun the command.
endpoint_response = w.serving_endpoints.query(name=endpoint_name, dataframe_records=['Hello world', 'Good morning'])
end = time.time()
print(endpoint_response)
print(f'Time taken for querying endpoint in seconds: {end-start}')