mlflow.pytorch

The mlflow.pytorch module provides an API for logging and loading PyTorch models. This module exports PyTorch models with the following flavors:

PyTorch (native) format

This is the main flavor that can be loaded back into PyTorch.

mlflow.pyfunc

Produced for use by generic pyfunc-based deployment tools and batch inference.

mlflow.pytorch.get_default_conda_env()
Returns

The default Conda environment for MLflow Models produced by calls to save_model() and log_model().

mlflow.pytorch.load_model(model_uri, **kwargs)

Load a PyTorch model from a local file or a run.

Parameters
  • model_uri

    The location, in URI format, of the MLflow model, for example:

    • /Users/me/path/to/local/model

    • relative/path/to/local/model

    • s3://my_bucket/path/to/model

    • runs:/<mlflow_run_id>/run-relative/path/to/model

    • models:/<model_name>/<model_version>

    • models:/<model_name>/<stage>

    For more information about supported URI schemes, see Referencing Artifacts.

  • kwargs – Keyword arguments to pass to the torch.load method.

Returns

A PyTorch model.

Example
import torch
import mlflow
import mlflow.pytorch
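# Set values
model_path_dir = ...
run_id = "96771d893a5e46159d9f3b49bf9013e2"
# Load the model that was logged under the given run
pytorch_model = mlflow.pytorch.load_model("runs:/" + run_id + "/" + model_path_dir)
# x_new_data is a torch.FloatTensor of inputs, defined elsewhere
y_pred = pytorch_model(x_new_data)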
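
The runs:/ and models:/ URI forms listed above are plain strings, so they can be assembled with small helpers. The following is a minimal sketch; the helper names runs_uri and models_uri and the model name "my_model" are hypothetical and are not part of the MLflow API.

```python
def runs_uri(run_id: str, artifact_path: str) -> str:
    """Build a runs:/ URI for an artifact path logged under a run."""
    # Strip stray slashes so the run ID and path are joined by exactly one "/".
    return f"runs:/{run_id}/{artifact_path.strip('/')}"


def models_uri(model_name: str, version_or_stage) -> str:
    """Build a models:/ URI for a registered model version or stage."""
    return f"models:/{model_name}/{version_or_stage}"


# Illustrative values only: "models" is an example artifact path and
# "my_model" is a hypothetical registered model name.
print(runs_uri("96771d893a5e46159d9f3b49bf9013e2", "models"))
print(models_uri("my_model", 1))
print(models_uri("my_model", "Production"))
```

Any of the resulting strings can then be passed as model_uri to mlflow.pytorch.load_model().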

mlflow.pytorch.log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None, pickle_module=None, registered_model_name=None, **kwargs)

Log a PyTorch model as an MLflow artifact for the current run.

Parameters
  • pytorch_model

    PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor. Any code dependencies of the model’s class, including the class definition itself, should be included in one of the following locations:

    • The package(s) listed in the model’s Conda environment, specified by the conda_env parameter.

    • One or more of the files specified by the code_paths parameter.

  • artifact_path – Run-relative artifact path.

  • conda_env

    Either a dictionary representation of a Conda environment or the path to a Conda environment YAML file. If provided, this describes the environment this model should be run in. At minimum, it should specify the dependencies contained in get_default_conda_env(). If None, the default get_default_conda_env() environment is added to the model. The following is an example dictionary representation of a Conda environment:

    {
        'name': 'mlflow-env',
        'channels': ['defaults'],
        'dependencies': [
            'python=3.7.0',
            'pytorch=0.4.1',
            'torchvision=0.2.1'
        ]
    }
    

  • code_paths – A list of local filesystem paths to Python file dependencies (or directories containing file dependencies). These files are prepended to the system path when the model is loaded.

  • pickle_module – The module that PyTorch should use to serialize (“pickle”) the specified pytorch_model. This is passed as the pickle_module parameter to torch.save(). By default, this module is also used to deserialize (“unpickle”) the PyTorch model at load time.

  • registered_model_name – (Experimental: this argument may change or be removed in a future release without warning.) If given, create a model version under registered_model_name, also creating a registered model if one with the given name does not exist.

  • kwargs – Keyword arguments to pass to the torch.save method.

Example
import torch
import mlflow
import mlflow.pytorch
# X data
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
# Y data with its expected value: labels
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
# Partial Model example modified from Sung Kim
# https://github.com/hunkim/PyTorchZeroToAll
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # One in and one out
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
# our model
model = Model()
# Summed squared error (reduction="sum" replaces the deprecated size_average=False)
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Training loop
for epoch in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x_data)
    # Compute and print loss
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())
    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# After training
for hv in [4.0, 5.0, 6.0]:
    hour_var = torch.Tensor([[hv]])
    y_pred = model(hour_var)
    print("predict (after training)", hv, y_pred.item())
# log the model
with mlflow.start_run() as run:
    mlflow.log_param("epochs", 500)
    mlflow.pytorch.log_model(model, "models")
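
The conda_env description above says a custom environment should at minimum contain the default dependencies. One way to satisfy this is to start from a base dictionary and append extra packages. The following is a minimal sketch: the helper with_extra_dependencies and the extra package "numpy" are hypothetical, and the literal dictionary stands in for whatever get_default_conda_env() returns in a given installation.

```python
import copy


def with_extra_dependencies(base_env: dict, extra: list) -> dict:
    """Return a copy of base_env with extra dependencies appended, skipping duplicates."""
    env = copy.deepcopy(base_env)  # leave the caller's dictionary untouched
    deps = env.setdefault("dependencies", [])
    for dep in extra:
        if dep not in deps:
            deps.append(dep)
    return env


# Stand-in for the default environment; the values are copied from the
# example dictionary in the conda_env description, not queried from MLflow.
base_env = {
    "name": "mlflow-env",
    "channels": ["defaults"],
    "dependencies": ["python=3.7.0", "pytorch=0.4.1", "torchvision=0.2.1"],
}

# "numpy" is a hypothetical extra dependency; the duplicate entry is ignored.
custom_env = with_extra_dependencies(base_env, ["numpy", "pytorch=0.4.1"])
print(custom_env["dependencies"])
```

The resulting dictionary can be supplied as the conda_env argument wherever a dictionary representation is accepted.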

mlflow.pytorch.save_model(pytorch_model, path, conda_env=None, mlflow_model=<mlflow.models.Model object>, code_paths=None, pickle_module=None, **kwargs)

Save a PyTorch model to a path on the local file system.

Parameters
  • pytorch_model

    PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor. Any code dependencies of the model’s class, including the class definition itself, should be included in one of the following locations:

    • The package(s) listed in the model’s Conda environment, specified by the conda_env parameter.

    • One or more of the files specified by the code_paths parameter.

  • path – Local path where the model is to be saved.

  • conda_env

    Either a dictionary representation of a Conda environment or the path to a Conda environment YAML file. If provided, this describes the environment this model should be run in. At minimum, it should specify the dependencies contained in get_default_conda_env(). If None, the default get_default_conda_env() environment is added to the model. The following is an example dictionary representation of a Conda environment:

    {
        'name': 'mlflow-env',
        'channels': ['defaults'],
        'dependencies': [
            'python=3.7.0',
            'pytorch=0.4.1',
            'torchvision=0.2.1'
        ]
    }
    

  • mlflow_model – The mlflow.models.Model that this flavor is being added to.

  • code_paths – A list of local filesystem paths to Python file dependencies (or directories containing file dependencies). These files are prepended to the system path when the model is loaded.

  • pickle_module – The module that PyTorch should use to serialize (“pickle”) the specified pytorch_model. This is passed as the pickle_module parameter to torch.save(). By default, this module is also used to deserialize (“unpickle”) the PyTorch model at load time.

  • kwargs – Keyword arguments to pass to the torch.save method.

Example
import torch
import mlflow
import mlflow.pytorch
# Create model and set values
# (Model and x_data are defined as in the log_model() example)
pytorch_model = Model()
pytorch_model_path = ...
# train our model
for epoch in range(500):
    y_pred = pytorch_model(x_data)
    ...
# Save the model
with mlflow.start_run() as run:
    mlflow.log_param("epochs", 500)
    mlflow.pytorch.save_model(pytorch_model, pytorch_model_path)
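
The code_paths description above says the listed files are prepended to the system path when the model is loaded. The following sketch illustrates that effect using only the standard library. It is not MLflow's loader, and the module name my_model_defs and its contents are hypothetical.

```python
import importlib
import os
import sys
import tempfile

# Create a throwaway directory holding a hypothetical dependency module,
# standing in for a directory that would be listed in code_paths.
code_dir = tempfile.mkdtemp()
with open(os.path.join(code_dir, "my_model_defs.py"), "w") as f:
    f.write("SCALE = 2.0\n")

# Prepending the directory makes its modules importable and gives them
# precedence over same-named modules found later on the path.
sys.path.insert(0, code_dir)
importlib.invalidate_caches()

my_model_defs = importlib.import_module("my_model_defs")
print(my_model_defs.SCALE)
```

Because the directory is placed at the front of the path, a module shipped through code_paths can shadow an installed module of the same name, so distinctive module names are a safer choice.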