AutoML Python API を使用して予測モデルをトレーニングする
このサンプルノートブックでは、AutoML Python APIを使用してDatabricks上で時系列予測モデルをトレーニングする方法を示します。COVID-19の症例数データセットを使用して、30日間の日次予測期間でautoml.forecast()を呼び出し、将来の症例数を予測し、次にMLflowで最適なモデルをロードして予測を生成およびプロットします。
要件
Databricks Runtime for Machine Learning 10.0以降。
モデルの予測を保存するには、Databricks Runtime for Machine Learning 10.5以降が必要です。
COVID-19 データセット
このデータセットには、米国におけるCOVID-19ウイルスの感染者数を日付別に記録したデータと、地理情報が含まれています。目標は、今後30日間に米国で発生するウイルス感染者数を予測することである。
import pyspark.pandas as ps
df = ps.read_csv("/databricks-datasets/COVID/covid-19-data")
df["date"] = ps.to_datetime(df['date'], errors='coerce')
df["cases"] = df["cases"].astype(int)
display(df)
AutoML
以下のコマンドはAutoMLの実行を開始します。target_col引数には、モデルが予測する対象となる列と時間列を指定する必要があります。実行が完了したら、最適な試用ノートブックへのリンクをクリックして、トレーニングコードを確認できます。
この例では、以下のことも指定されています。
horizon=30AutoMLが30日先まで予測を行うように指定する。frequency="d"各日ごとに予報を提供するよう指定する。primary_metric="mdape"トレーニング中に最適化するメトリクスを指定します。
automl.forecast() クラシックコンピュートでのみ利用可能です。
import databricks.automl
import logging
# Disable informational messages from fbprophet
logging.getLogger("py4j").setLevel(logging.WARNING)
# Note: If you are running Databricks Runtime for Machine Learning 10.4 or below, use this line instead:
# summary = databricks.automl.forecast(df, target_col="cases", time_col="date", horizon=30, frequency="d", primary_metric="mdape")
summary = databricks.automl.forecast(df, target_col="cases", time_col="date", horizon=30, frequency="d", primary_metric="mdape", output_database="default")
モデルを反復処理する
- 上にリンクされているノートブックと体験を探索してください。
- 最良のトライアルノートブックのメトリクスが良好であれば、次のセルに進むことができます。
- 最良の試行で生成されたモデルを改善したい場合は、次の手順に従ってください。
- 最も良い試行結果が得られたノートブックを開き、それを複製してください。
- モデルを改善するために、必要に応じてノートブックを編集してください。
- モデルに満足したら、トレーニングされたモデルのアーティファクトが記録されている URI をメモします。 次のセルの変数
model_uriにこのURIを代入してください。
最良モデルによる予測結果を表示する
注: このセクションを実行するには、Databricks Runtime for Machine Learning 10.5以降が必要です。
最適なモデルからの予測を読み込む
Databricks Runtime for Machine Learning 10.5以降では、 output_databaseが指定されている場合、AutoMLは最適なモデルからの予測を保存します。
# Load the saved predictions.
forecast_pd = spark.table(summary.output_table_name)
display(forecast_pd)
予測にはモデルを使用する
このセクションのコマンドは、Databricks Runtime for Machine Learning 10.0以降で使用できます。
MLflowでモデルをロードする
MLflow を使用すると、AutoML trial_idを使用してモデルを Python に簡単にインポートできます。
import mlflow.pyfunc
from mlflow.tracking import MlflowClient
run_id = MlflowClient()
trial_id = summary.best_trial.mlflow_run_id
model_uri = "runs:/{run_id}/model".format(run_id=trial_id)
pyfunc_model = mlflow.pyfunc.load_model(model_uri)
モデルを使用して予測を行う
予測を生成するには、 predict_timeseriesモデルメソッドを呼び出してください。
Databricks Runtime for Machine Learning 10.5以降では、 include_history=Falseを設定することで予測データのみを取得できます。
forecasts = pyfunc_model._model_impl.python_model.predict_timeseries()
display(forecasts)
# Option for Databricks Runtime for Machine Learning 10.5 or above
# forecasts = pyfunc_model._model_impl.python_model.predict_timeseries(include_history=False)
予測された地点をプロットする
下のプロットでは、黒い太い線が時系列データセットを示し、青い線がモデルによって作成された予測です。
df_true = df.groupby("date").agg(y=("cases", "avg")).reset_index().to_pandas()
import matplotlib.pyplot as plt
fig = plt.figure(facecolor='w', figsize=(10, 6))
ax = fig.add_subplot(111)
forecasts = pyfunc_model._model_impl.python_model.predict_timeseries(include_history=True)
fcst_t = forecasts['ds'].dt.to_pydatetime()
ax.plot(df_true['date'].dt.to_pydatetime(), df_true['y'], 'k.', label='Observed data points')
ax.plot(fcst_t, forecasts['yhat'], ls='-', c='#0072B2', label='Forecasts')
ax.fill_between(fcst_t, forecasts['yhat_lower'], forecasts['yhat_upper'],
color='#0072B2', alpha=0.2, label='Uncertainty interval')
ax.legend()
plt.show()
モデルを登録してデプロイする
AutoMLでトレーニングされたモデルは、他のモデルと同様にMLflow Model Registryに登録してデプロイできます。 「ログ、ロード、および登録するMLflowモデル」を参照してください。
トラブルシューティング: No module named pandas.core.indexes.numeric
AutoMLでトレーニングされたモデルをモデルサービングで提供する場合、エラーNo module named pandas.core.indexes.numericが表示されることがあります。これは、 AutoMLで使用されるpandasバージョンがモデルサービング エンドポイント環境のものと異なる場合に発生します。 解決するには:
- add-pandas-dependency.py スクリプトをダウンロードしてください。スクリプトは記録済みモデルの
requirements.txtとconda.yamlをピン留めpandas==1.5.3に編集します。 - スクリプトを編集して、モデルがログに記録されたMLflow実行の
run_idを含めるようにします。 - モデルを再登録します。
- 新モデル版を提供してください。