mlflow.pytorch
The mlflow.pytorch module provides an API for logging and loading PyTorch models. This module exports PyTorch models with the following flavors:
- PyTorch (native) format: This is the main flavor that can be loaded back into PyTorch.
- mlflow.pyfunc: Produced for use by generic pyfunc-based deployment tools and batch inference.

mlflow.pytorch.get_default_conda_env()
- Returns
The default Conda environment for MLflow Models produced by calls to save_model() and log_model().
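Because get_default_conda_env() returns a plain Python dictionary, a custom environment can be built by copying and extending it rather than written from scratch. The sketch below uses a hypothetical helper, add_dependencies, which is not part of MLflow. It assumes the standard Conda environment layout: a "dependencies" list of package strings that may also contain one {"pip": [...]} dictionary. The exact contents of the default environment vary between MLflow versions, so a literal dictionary stands in for it here.

```python
import copy


def add_dependencies(conda_env, conda_deps=(), pip_deps=()):
    """Return a copy of a Conda environment dict with extra dependencies.

    Hypothetical helper (not an MLflow API). Assumes the standard Conda
    layout: "dependencies" is a list of package strings that may also
    contain a single {"pip": [...]} dictionary.
    """
    env = copy.deepcopy(conda_env)  # never mutate the caller's dict
    deps = env.setdefault("dependencies", [])

    # Conda packages are plain strings; skip ones that are already listed.
    for dep in conda_deps:
        if dep not in deps:
            deps.append(dep)

    if pip_deps:
        # Reuse an existing {"pip": [...]} entry, or create one.
        pip_section = next(
            (d for d in deps if isinstance(d, dict) and "pip" in d), None
        )
        if pip_section is None:
            if "pip" not in deps:
                deps.append("pip")  # list pip itself so conda can install the pip section
            pip_section = {"pip": []}
            deps.append(pip_section)
        for dep in pip_deps:
            if dep not in pip_section["pip"]:
                pip_section["pip"].append(dep)

    return env


# Stand-in for mlflow.pytorch.get_default_conda_env(), whose contents vary by version.
base_env = {
    "name": "mlflow-env",
    "channels": ["defaults"],
    "dependencies": ["python=3.7.0", "pytorch=0.4.1", "torchvision=0.2.1"],
}
custom_env = add_dependencies(base_env, conda_deps=["numpy"], pip_deps=["mlflow"])
print(custom_env["dependencies"])
```

In real use you would start from mlflow.pytorch.get_default_conda_env() instead of the literal dictionary and pass the result as the conda_env argument of save_model() or log_model(). Copying first keeps the default environment untouched, and the membership checks make repeated calls harmless.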

mlflow.pytorch.load_model(model_uri, **kwargs)
Load a PyTorch model from a local file or a run.
- Parameters
model_uri – The location, in URI format, of the MLflow model, for example:
/Users/me/path/to/local/model
relative/path/to/local/model
s3://my_bucket/path/to/model
runs:/<mlflow_run_id>/run-relative/path/to/model
models:/<model_name>/<model_version>
models:/<model_name>/<stage>
For more information about supported URI schemes, see Referencing Artifacts.
kwargs – Keyword arguments to pass to the torch.load method.
- Returns
A PyTorch model.
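The model_uri formats listed for load_model() differ mainly in their scheme prefix. The sketch below is a hypothetical describe_model_uri helper built on the standard library's urllib.parse; it is not part of MLflow, which performs its own URI resolution inside load_model(). It splits each documented format into its parts, assuming POSIX-style local paths and treating a final models:/ segment made only of digits as a version and anything else as a stage.

```python
from urllib.parse import urlparse


def describe_model_uri(model_uri):
    """Split a model URI into its parts (hypothetical helper, not an MLflow API)."""
    parsed = urlparse(model_uri)
    if parsed.scheme == "":
        # No scheme: an absolute or relative local filesystem path.
        return {"kind": "local", "path": model_uri}
    if parsed.scheme == "runs":
        # runs:/<mlflow_run_id>/run-relative/path/to/model
        run_id, _, artifact_path = parsed.path.lstrip("/").partition("/")
        return {"kind": "run", "run_id": run_id, "artifact_path": artifact_path}
    if parsed.scheme == "models":
        # models:/<model_name>/<model_version> or models:/<model_name>/<stage>
        name, _, ref = parsed.path.lstrip("/").partition("/")
        key = "version" if ref.isdigit() else "stage"
        return {"kind": "registry", "name": name, key: ref}
    # Anything else (s3://, ...) is treated as a remote artifact location.
    return {"kind": "remote", "scheme": parsed.scheme, "location": model_uri}


for uri in [
    "/Users/me/path/to/local/model",
    "relative/path/to/local/model",
    "s3://my_bucket/path/to/model",
    "runs:/abc123/models",
    "models:/my_model/3",
    "models:/my_model/Production",
]:
    print(uri, "->", describe_model_uri(uri))
```

Any of these strings can be passed unchanged to mlflow.pytorch.load_model(); the helper only makes the structure of each scheme explicit.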

mlflow.pytorch.log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None, pickle_module=None, registered_model_name=None, **kwargs)
Log a PyTorch model as an MLflow artifact for the current run.
- Parameters
pytorch_model – PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor. Any code dependencies of the model’s class, including the class definition itself, should be included in one of the following locations:
  - The package(s) listed in the model’s Conda environment, specified by the conda_env parameter.
  - One or more of the files specified by the code_paths parameter.
artifact_path – Run-relative artifact path.
conda_env – Either a dictionary representation of a Conda environment or the path to a Conda environment YAML file. If provided, this describes the environment this model should be run in. At minimum, it should specify the dependencies contained in get_default_conda_env(). If None, the default get_default_conda_env() environment is added to the model. The following is an example dictionary representation of a Conda environment:
{
    'name': 'mlflow-env',
    'channels': ['defaults'],
    'dependencies': [
        'python=3.7.0',
        'pytorch=0.4.1',
        'torchvision=0.2.1'
    ]
}
code_paths – A list of local filesystem paths to Python file dependencies (or directories containing file dependencies). These files are prepended to the system path when the model is loaded.
pickle_module – The module that PyTorch should use to serialize (“pickle”) the specified pytorch_model. This is passed as the pickle_module parameter to torch.save(). By default, this module is also used to deserialize (“unpickle”) the PyTorch model at load time.
registered_model_name – Experimental: This argument may change or be removed in a future release without warning. If given, create a model version under registered_model_name, also creating a registered model if one with the given name does not exist.
kwargs – Keyword arguments to pass to the torch.save method.
import torch
import mlflow
import mlflow.pytorch

# X data
x_data = torch.tensor([[1.0], [2.0], [3.0]])
# Y data with its expected value: labels
y_data = torch.tensor([[2.0], [4.0], [6.0]])

# Partial Model example modified from Sung Kim
# https://github.com/hunkim/PyTorchZeroToAll
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)  # One in and one out

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

# our model
model = Model()
criterion = torch.nn.MSELoss(reduction="sum")  # sum of squared errors over the batch
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Training loop
for epoch in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x_data)

    # Compute and print loss
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# After training: predict without tracking gradients
with torch.no_grad():
    for hv in [4.0, 5.0, 6.0]:
        hour_var = torch.tensor([[hv]])
        y_pred = model(hour_var)
        print("predict (after training)", hv, y_pred.item())

# log the model
with mlflow.start_run() as run:
    mlflow.log_param("epochs", 500)
    mlflow.pytorch.log_model(model, "models")
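The code_paths parameter matters because unpickling a saved model requires its class definition to be importable in the loading process. The parameter description says the listed files are prepended to the system path when the model is loaded; the stdlib-only sketch below imitates that single step with a throwaway module. The helper import_from_code_path and the module my_model_def are hypothetical, and this illustrates the idea rather than reproducing MLflow's loader.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def import_from_code_path(code_dir, module_name):
    """Prepend code_dir to sys.path, then import module_name from it.

    Hypothetical helper illustrating the code_paths idea; not an MLflow API.
    """
    code_dir = str(code_dir)
    if code_dir not in sys.path:
        sys.path.insert(0, code_dir)  # prepend, so these files take priority
    importlib.invalidate_caches()  # make the import finders notice the new file
    return importlib.import_module(module_name)


with tempfile.TemporaryDirectory() as tmp:
    # A stand-in for a file that holds the model's class definition.
    Path(tmp, "my_model_def.py").write_text(
        "class Model:\n"
        "    def describe(self):\n"
        "        return 'Model class found via code_paths'\n"
    )
    try:
        module = import_from_code_path(tmp, "my_model_def")
        message = module.Model().describe()
    finally:
        sys.path.remove(tmp)  # undo the sys.path change once the demo is done

print(message)
```

Without the prepended directory, the import would fail in a fresh process, which is the same failure a loader hits when a pickled model's class cannot be found.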

mlflow.pytorch.save_model(pytorch_model, path, conda_env=None, mlflow_model=<mlflow.models.Model object>, code_paths=None, pickle_module=None, **kwargs)
Save a PyTorch model to a path on the local file system.
- Parameters
pytorch_model – PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor. Any code dependencies of the model’s class, including the class definition itself, should be included in one of the following locations:
  - The package(s) listed in the model’s Conda environment, specified by the conda_env parameter.
  - One or more of the files specified by the code_paths parameter.
path – Local path where the model is to be saved.
conda_env – Either a dictionary representation of a Conda environment or the path to a Conda environment YAML file. If provided, this describes the environment this model should be run in. At minimum, it should specify the dependencies contained in get_default_conda_env(). If None, the default get_default_conda_env() environment is added to the model. The following is an example dictionary representation of a Conda environment:
{
    'name': 'mlflow-env',
    'channels': ['defaults'],
    'dependencies': [
        'python=3.7.0',
        'pytorch=0.4.1',
        'torchvision=0.2.1'
    ]
}
mlflow_model – The mlflow.models.Model to which this flavor is being added.
code_paths – A list of local filesystem paths to Python file dependencies (or directories containing file dependencies). These files are prepended to the system path when the model is loaded.
pickle_module – The module that PyTorch should use to serialize (“pickle”) the specified pytorch_model. This is passed as the pickle_module parameter to torch.save(). By default, this module is also used to deserialize (“unpickle”) the PyTorch model at load time.
kwargs – Keyword arguments to pass to the torch.save method.
import torch
import mlflow
import mlflow.pytorch

# Model and x_data are defined as in the log_model example above.
# Create model and set values
pytorch_model = Model()
pytorch_model_path = ...

# train our model
for epoch in range(500):
    y_pred = pytorch_model(x_data)
    ...

# Save the model
with mlflow.start_run() as run:
    mlflow.log_param("epochs", 500)
    mlflow.pytorch.save_model(pytorch_model, pytorch_model_path)