mlflow.sklearn
The mlflow.sklearn module provides an API for logging and loading scikit-learn models. This module exports scikit-learn models with the following flavors:
- Python (native) pickle format
This is the main flavor that can be loaded back into scikit-learn.
- mlflow.pyfunc
Produced for use by generic pyfunc-based deployment tools and batch inference.
mlflow.sklearn.get_default_conda_env(include_cloudpickle=False)
- Returns
The default Conda environment for MLflow Models produced by calls to save_model() and log_model().
mlflow.sklearn.load_model(model_uri)
Load a scikit-learn model from a local file or a run.
- Parameters
model_uri – The location, in URI format, of the MLflow model, for example:
/Users/me/path/to/local/model
relative/path/to/local/model
s3://my_bucket/path/to/model
runs:/<mlflow_run_id>/run-relative/path/to/model
models:/<model_name>/<model_version>
models:/<model_name>/<stage>
For more information about supported URI schemes, see Referencing Artifacts.
- Returns
A scikit-learn model.
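The URI forms listed for model_uri can be told apart by their scheme. The following is a minimal sketch using only the Python standard library. describe_model_uri is a hypothetical helper written for illustration; it is not part of MLflow and does not reproduce how MLflow itself resolves URIs.

```python
from urllib.parse import urlparse


def describe_model_uri(model_uri):
    """Classify a model URI by scheme (illustrative helper, not an MLflow API)."""
    parsed = urlparse(model_uri)
    scheme = parsed.scheme
    if scheme == "runs":
        # runs:/<mlflow_run_id>/run-relative/path/to/model
        run_id, _, rel_path = parsed.path.lstrip("/").partition("/")
        return {"kind": "run", "run_id": run_id, "path": rel_path}
    if scheme == "models":
        # models:/<model_name>/<model_version> or models:/<model_name>/<stage>
        name, _, version_or_stage = parsed.path.lstrip("/").partition("/")
        return {"kind": "registry", "name": name, "version_or_stage": version_or_stage}
    if scheme == "":
        # No scheme: an absolute or relative local path.
        return {"kind": "local", "path": model_uri}
    # Any other scheme (for example s3) is treated as a remote artifact store.
    return {"kind": "remote", "scheme": scheme, "location": parsed.netloc + parsed.path}


for uri in [
    "/Users/me/path/to/local/model",
    "s3://my_bucket/path/to/model",
    "runs:/abc123/run-relative/path/to/model",  # abc123 is a made-up run ID
    "models:/my_model/3",                       # made-up model name and version
]:
    print(uri, "->", describe_model_uri(uri))
```

This sketch only inspects the string. It does not check that the run, registered model, or path exists, and it would likely misread a Windows drive-letter path as a scheme.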
mlflow.sklearn.log_model(sk_model, artifact_path, conda_env=None, serialization_format='cloudpickle', registered_model_name=None)
Log a scikit-learn model as an MLflow artifact for the current run.
- Parameters
sk_model – scikit-learn model to be saved.
artifact_path – Run-relative artifact path.
conda_env – Either a dictionary representation of a Conda environment or the path to a Conda environment YAML file. If provided, this describes the environment this model should be run in. At minimum, it should specify the dependencies contained in get_default_conda_env(). If None, the default get_default_conda_env() environment is added to the model. The following is an example dictionary representation of a Conda environment:
    {
        'name': 'mlflow-env',
        'channels': ['defaults'],
        'dependencies': [
            'python=3.7.0',
            'scikit-learn=0.19.2'
        ]
    }
serialization_format – The format in which to serialize the model. This should be one of the formats listed in mlflow.sklearn.SUPPORTED_SERIALIZATION_FORMATS. The Cloudpickle format, mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE, provides better cross-system compatibility by identifying and packaging code dependencies with the serialized model.
registered_model_name – Note: Experimental: This argument may change or be removed in a future release without warning. If given, create a model version under registered_model_name, also creating a registered model if one with the given name does not exist.
    import mlflow
    import mlflow.sklearn
    from sklearn.datasets import load_iris
    from sklearn import tree

    iris = load_iris()
    sk_model = tree.DecisionTreeClassifier()
    sk_model = sk_model.fit(iris.data, iris.target)

    # log model params
    mlflow.log_param("criterion", sk_model.criterion)
    mlflow.log_param("splitter", sk_model.splitter)

    # log model; "sk_models" is the run-relative artifact_path
    # where the model artifacts will be saved
    mlflow.sklearn.log_model(sk_model, "sk_models")
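As a sketch of how a conda_env dictionary of the shape described under log_model might be assembled, the snippet below builds one pinned to the running Python interpreter. build_conda_env and its sklearn_version and extra_deps arguments are hypothetical names introduced here for illustration; they are not part of MLflow, and the resulting dictionary would still need to cover the dependencies listed by get_default_conda_env().

```python
import platform


def build_conda_env(sklearn_version, name="mlflow-env", extra_deps=()):
    """Return a dict in the Conda environment shape (illustrative only)."""
    dependencies = [
        # Pin Python to the interpreter currently running this code.
        "python={}".format(platform.python_version()),
        # Pin scikit-learn to the version supplied by the caller.
        "scikit-learn={}".format(sklearn_version),
    ]
    dependencies.extend(extra_deps)
    return {"name": name, "channels": ["defaults"], "dependencies": dependencies}


env = build_conda_env("0.19.2")
print(env)
```

A dictionary built this way could then be passed as conda_env=env to log_model() or save_model().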
mlflow.sklearn.save_model(sk_model, path, conda_env=None, mlflow_model=<mlflow.models.Model object>, serialization_format='cloudpickle')
Save a scikit-learn model to a path on the local file system.
- Parameters
sk_model – scikit-learn model to be saved.
path – Local path where the model is to be saved.
conda_env – Either a dictionary representation of a Conda environment or the path to a Conda environment YAML file. If provided, this describes the environment this model should be run in. At minimum, it should specify the dependencies contained in get_default_conda_env(). If None, the default get_default_conda_env() environment is added to the model. The following is an example dictionary representation of a Conda environment:
    {
        'name': 'mlflow-env',
        'channels': ['defaults'],
        'dependencies': [
            'python=3.7.0',
            'scikit-learn=0.19.2'
        ]
    }
mlflow_model – mlflow.models.Model this flavor is being added to.
serialization_format – The format in which to serialize the model. This should be one of the formats listed in mlflow.sklearn.SUPPORTED_SERIALIZATION_FORMATS. The Cloudpickle format, mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE, provides better cross-system compatibility by identifying and packaging code dependencies with the serialized model.
    import mlflow.sklearn
    from sklearn.datasets import load_iris
    from sklearn import tree

    iris = load_iris()
    sk_model = tree.DecisionTreeClassifier()
    sk_model = sk_model.fit(iris.data, iris.target)

    # Save the model in cloudpickle format
    # set path to location for persistence
    sk_path_dir_1 = ...
    mlflow.sklearn.save_model(
        sk_model, sk_path_dir_1,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE)

    # save the model in pickle format
    # set path to location for persistence
    sk_path_dir_2 = ...
    mlflow.sklearn.save_model(
        sk_model, sk_path_dir_2,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE)
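To illustrate why the choice of serialization_format can matter, the sketch below uses only the standard library to show one limitation of plain pickle: functions are stored by reference (module and name), so an object holding a lambda cannot be pickled. The can_pickle helper and the two dictionaries are hypothetical stand-ins for a model with and without custom code attached; they are not MLflow or scikit-learn objects. cloudpickle is designed to serialize such functions by value instead, which is the kind of code packaging the Cloudpickle format description refers to.

```python
import pickle


def can_pickle(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# Plain data serializes with the standard pickle module.
plain_params = {"criterion": "gini", "max_depth": 3}

# A lambda has no importable name, so pickle cannot store a reference to it.
params_with_code = {"criterion": "gini", "preprocess": lambda x: x * 2}

print(can_pickle(plain_params))      # True
print(can_pickle(params_with_code))  # False
```

A plain scikit-learn estimator normally falls in the first category, so both formats can work for it. The difference tends to show up when a model carries user-defined functions or classes.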