"""
The ``mlflow.lightgbm`` module provides an API for logging and loading LightGBM models.
This module exports LightGBM models with the following flavors:
LightGBM (native) format
This is the main flavor that can be loaded back into LightGBM.
:py:mod:`mlflow.pyfunc`
Produced for use by generic pyfunc-based deployment tools and batch inference.
.. _lightgbm.Booster:
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html#lightgbm.Booster
.. _lightgbm.Booster.save_model:
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html
#lightgbm.Booster.save_model
.. _lightgbm.train:
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.train.html#lightgbm-train
.. _scikit-learn API:
https://lightgbm.readthedocs.io/en/latest/Python-API.html#scikit-learn-api
"""
from __future__ import absolute_import
import os
import yaml
import json
import tempfile
import shutil
import inspect
import logging
import gorilla
import mlflow
from mlflow import pyfunc
from mlflow.models import Model
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.model_utils import _get_flavor_configuration
from mlflow.exceptions import MlflowException
from mlflow.utils.annotations import experimental
from mlflow.utils.autologging_utils import try_mlflow_log, log_fn_args_as_params
FLAVOR_NAME = "lightgbm"
_logger = logging.getLogger(__name__)
def get_default_conda_env():
"""
:return: The default Conda environment for MLflow Models produced by calls to
:func:`save_model()` and :func:`log_model()`.
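
    Example (illustrative)::

        import mlflow.lightgbm

        # A dict pinning the installed LightGBM version; usable as the
        # ``conda_env`` argument to ``save_model`` / ``log_model``.
        conda_env = mlflow.lightgbm.get_default_conda_env()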
"""
import lightgbm as lgb
return _mlflow_conda_env(
additional_conda_deps=None,
# LightGBM is not yet available via the default conda channels, so we install it via pip
additional_pip_deps=[
"lightgbm=={}".format(lgb.__version__),
],
additional_conda_channels=None)
def save_model(lgb_model, path, conda_env=None, mlflow_model=None):
"""
Save a LightGBM model to a path on the local file system.
:param lgb_model: LightGBM model (an instance of `lightgbm.Booster`_) to be saved.
Note that models that implement the `scikit-learn API`_ are not supported.
:param path: Local path where the model is to be saved.
:param conda_env: Either a dictionary representation of a Conda environment or the path to a
Conda environment yaml file. If provided, this describes the environment
this model should be run in. At minimum, it should specify the dependencies
contained in :func:`get_default_conda_env()`. If ``None``, the default
:func:`get_default_conda_env()` environment is added to the model.
The following is an *example* dictionary representation of a Conda
environment::
                          {
                              'name': 'mlflow-env',
                              'channels': ['defaults'],
                              'dependencies': [
                                  'python=3.7.0',
                                  {
                                      'pip': [
                                          'lightgbm==2.3.0'
                                      ]
                                  }
                              ]
                          }
:param mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
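
    Example (illustrative; assumes ``lgb_model`` is a trained `lightgbm.Booster`_)::

        import mlflow.lightgbm

        # Save the model to a new local directory.
        mlflow.lightgbm.save_model(lgb_model, path="my_lgb_model")

        # Load it back for inference.
        loaded = mlflow.lightgbm.load_model("my_lgb_model")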
"""
    import lightgbm as lgb
    if mlflow_model is None:
        mlflow_model = Model()
    path = os.path.abspath(path)
if os.path.exists(path):
raise MlflowException("Path '{}' already exists".format(path))
model_data_subpath = "model.lgb"
model_data_path = os.path.join(path, model_data_subpath)
os.makedirs(path)
# Save a LightGBM model
lgb_model.save_model(model_data_path)
conda_env_subpath = "conda.yaml"
if conda_env is None:
conda_env = get_default_conda_env()
elif not isinstance(conda_env, dict):
with open(conda_env, "r") as f:
conda_env = yaml.safe_load(f)
with open(os.path.join(path, conda_env_subpath), "w") as f:
yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
pyfunc.add_to_model(mlflow_model, loader_module="mlflow.lightgbm",
data=model_data_subpath, env=conda_env_subpath)
mlflow_model.add_flavor(FLAVOR_NAME, lgb_version=lgb.__version__, data=model_data_subpath)
mlflow_model.save(os.path.join(path, "MLmodel"))
def log_model(lgb_model, artifact_path, conda_env=None, registered_model_name=None, **kwargs):
"""
Log a LightGBM model as an MLflow artifact for the current run.
:param lgb_model: LightGBM model (an instance of `lightgbm.Booster`_) to be saved.
Note that models that implement the `scikit-learn API`_ are not supported.
:param artifact_path: Run-relative artifact path.
:param conda_env: Either a dictionary representation of a Conda environment or the path to a
Conda environment yaml file. If provided, this describes the environment
this model should be run in. At minimum, it should specify the dependencies
contained in :func:`get_default_conda_env()`. If ``None``, the default
:func:`get_default_conda_env()` environment is added to the model.
The following is an *example* dictionary representation of a Conda
environment::
                          {
                              'name': 'mlflow-env',
                              'channels': ['defaults'],
                              'dependencies': [
                                  'python=3.7.0',
                                  {
                                      'pip': [
                                          'lightgbm==2.3.0'
                                      ]
                                  }
                              ]
                          }
    :param registered_model_name: (Experimental) This argument may change or be removed in a
                                  future release without warning. If given, create a model
                                  version under ``registered_model_name``, also creating a
                                  registered model if one with the given name does not exist.
    :param kwargs: kwargs to pass to the `lightgbm.Booster.save_model`_ method.
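
    Example (illustrative; assumes ``lgb_model`` is a trained `lightgbm.Booster`_)::

        import mlflow
        import mlflow.lightgbm

        with mlflow.start_run():
            # Log the model under the current run's artifact directory.
            mlflow.lightgbm.log_model(lgb_model, artifact_path="model")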
"""
Model.log(artifact_path=artifact_path, flavor=mlflow.lightgbm,
registered_model_name=registered_model_name,
lgb_model=lgb_model, conda_env=conda_env, **kwargs)
def _load_model(path):
import lightgbm as lgb
return lgb.Booster(model_file=path)
def _load_pyfunc(path):
"""
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
:param path: Local filesystem path to the MLflow Model with the ``lightgbm`` flavor.
"""
return _LGBModelWrapper(_load_model(path))
def load_model(model_uri):
"""
Load a LightGBM model from a local file or a run.
:param model_uri: The location, in URI format, of the MLflow model. For example:
- ``/Users/me/path/to/local/model``
- ``relative/path/to/local/model``
- ``s3://my_bucket/path/to/model``
- ``runs:/<mlflow_run_id>/run-relative/path/to/model``
                      For more information about supported URI schemes, see
                      `Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#artifact-locations>`_.
:return: A LightGBM model (an instance of `lightgbm.Booster`_).
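
    Example (illustrative; ``<mlflow_run_id>`` and ``data`` are placeholders)::

        import mlflow.lightgbm

        model = mlflow.lightgbm.load_model("runs:/<mlflow_run_id>/model")
        predictions = model.predict(data)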
"""
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
lgb_model_file_path = os.path.join(local_model_path, flavor_conf.get("data", "model.lgb"))
return _load_model(path=lgb_model_file_path)
class _LGBModelWrapper:
def __init__(self, lgb_model):
self.lgb_model = lgb_model
def predict(self, dataframe):
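        # Used by the pyfunc flavor: ``dataframe`` is typically a pandas DataFrame;
        # returns the booster's raw predictions as a numpy array.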
return self.lgb_model.predict(dataframe)
@experimental
def autolog():
"""
    Enables automatic logging from LightGBM to MLflow. Logs the following:

    - parameters specified in `lightgbm.train`_.
    - metrics on each iteration (if ``valid_sets`` is specified).
    - metrics at the best iteration (if ``early_stopping_rounds`` is specified).
    - feature importance (both "split" and "gain") as JSON files and plots.
    - the trained model.
Note that the `scikit-learn API`_ is not supported.
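
    Example (illustrative; assumes ``params``, ``train_set``, and ``valid_set`` are
    defined as in a typical `lightgbm.train`_ workflow)::

        import lightgbm
        import mlflow.lightgbm

        mlflow.lightgbm.autolog()
        # Parameters, per-iteration metrics, feature importance, and the trained
        # model are now logged automatically to the active (or a new) MLflow run.
        model = lightgbm.train(params, train_set, valid_sets=[valid_set])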
"""
import lightgbm
import numpy as np
@gorilla.patch(lightgbm)
def train(*args, **kwargs):
def record_eval_results(eval_results):
"""
Create a callback function that records evaluation results.
"""
def callback(env):
res = {}
for data_name, eval_name, value, _ in env.evaluation_result_list:
key = data_name + '-' + eval_name
res[key] = value
eval_results.append(res)
return callback
def log_feature_importance_plot(features, importance, importance_type):
"""
Log feature importance plot.
"""
import matplotlib.pyplot as plt
indices = np.argsort(importance)
features = np.array(features)[indices]
importance = importance[indices]
num_features = len(features)
# If num_features > 10, increase the figure height to prevent the plot
# from being too dense.
w, h = [6.4, 4.8] # matplotlib's default figure size
h = h + 0.1 * num_features if num_features > 10 else h
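            # e.g. 20 features -> h = 4.8 + 0.1 * 20 = 6.8 inches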
fig, ax = plt.subplots(figsize=(w, h))
yloc = np.arange(num_features)
ax.barh(yloc, importance, align='center', height=0.5)
ax.set_yticks(yloc)
ax.set_yticklabels(features)
ax.set_xlabel('Importance')
ax.set_title('Feature Importance ({})'.format(importance_type))
fig.tight_layout()
tmpdir = tempfile.mkdtemp()
try:
                filepath = os.path.join(
                    tmpdir, 'feature_importance_{}.png'.format(importance_type))
fig.savefig(filepath)
try_mlflow_log(mlflow.log_artifact, filepath)
finally:
plt.close(fig)
shutil.rmtree(tmpdir)
if not mlflow.active_run():
try_mlflow_log(mlflow.start_run)
auto_end_run = True
else:
auto_end_run = False
original = gorilla.get_original_attribute(lightgbm, 'train')
# logging booster params separately via mlflow.log_params to extract key/value pairs
# and make it easier to compare them across runs.
params = args[0] if len(args) > 0 else kwargs['params']
try_mlflow_log(mlflow.log_params, params)
unlogged_params = ['params', 'train_set', 'valid_sets', 'valid_names', 'fobj', 'feval',
'init_model', 'evals_result', 'learning_rates', 'callbacks']
log_fn_args_as_params(original, args, kwargs, unlogged_params)
        # ``inspect.getfullargspec`` replaces the deprecated ``getargspec``.
        all_arg_names = inspect.getfullargspec(original).args
num_pos_args = len(args)
# adding a callback that records evaluation results.
eval_results = []
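        # Illustrative shape: with a single validation set named 'valid_0' and
        # metric 'l2', ``eval_results`` accumulates entries like
        # [{'valid_0-l2': 0.25}, {'valid_0-l2': 0.18}, ...] (hypothetical values).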
callbacks_index = all_arg_names.index('callbacks')
callback = record_eval_results(eval_results)
if num_pos_args >= callbacks_index + 1:
tmp_list = list(args)
tmp_list[callbacks_index] += [callback]
args = tuple(tmp_list)
elif 'callbacks' in kwargs and kwargs['callbacks'] is not None:
kwargs['callbacks'] += [callback]
else:
kwargs['callbacks'] = [callback]
# training model
model = original(*args, **kwargs)
# logging metrics on each iteration.
for idx, metrics in enumerate(eval_results):
try_mlflow_log(mlflow.log_metrics, metrics, step=idx)
# If early_stopping_rounds is present, logging metrics at the best iteration
# as extra metrics with the max step + 1.
early_stopping_index = all_arg_names.index('early_stopping_rounds')
early_stopping = (num_pos_args >= early_stopping_index + 1 or
'early_stopping_rounds' in kwargs)
if early_stopping:
extra_step = len(eval_results)
try_mlflow_log(mlflow.log_metric, 'stopped_iteration', len(eval_results))
# best_iteration is set even if training does not stop early.
try_mlflow_log(mlflow.log_metric, 'best_iteration', model.best_iteration)
# iteration starts from 1 in LightGBM.
try_mlflow_log(mlflow.log_metrics, eval_results[model.best_iteration - 1],
step=extra_step)
# logging feature importance as artifacts.
for imp_type in ['split', 'gain']:
features = model.feature_name()
importance = model.feature_importance(importance_type=imp_type)
try:
log_feature_importance_plot(features, importance, imp_type)
except Exception: # pylint: disable=broad-except
_logger.exception('Failed to log feature importance plot. LightGBM autologging '
'will ignore the failure and continue. Exception: ')
            imp = dict(zip(features, importance.tolist()))
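            # Illustrative JSON content for importance_type='split':
            # {"Column_0": 120, "Column_1": 87} (hypothetical names/values;
            # 'gain' importances are floats).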
tmpdir = tempfile.mkdtemp()
try:
filepath = os.path.join(tmpdir, 'feature_importance_{}.json'.format(imp_type))
with open(filepath, 'w') as f:
json.dump(imp, f, indent=2)
try_mlflow_log(mlflow.log_artifact, filepath)
finally:
shutil.rmtree(tmpdir)
try_mlflow_log(log_model, model, artifact_path='model')
if auto_end_run:
try_mlflow_log(mlflow.end_run)
return model
settings = gorilla.Settings(allow_hit=True, store_hit=True)
gorilla.apply(gorilla.Patch(lightgbm, 'train', train, settings=settings))