tf.keras.wrappers.scikit_learn.KerasRegressor

Implementation of the scikit-learn regressor API for Keras.

tf.keras.wrappers.scikit_learn.KerasRegressor(
    build_fn=None, **sk_params
)
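Per the TensorFlow docstring for this wrapper, build_fn may be a plain function, an instance of a class that implements __call__, or None. When it is None, you subclass KerasRegressor and its own __call__ method is used as the default builder. The sketch below shows that three-way dispatch in pure Python, with no TensorFlow dependency. It assumes that resolve_build_fn, ModelFactory, WrapperSketch and build_model are hypothetical names, not part of the Keras API. The toy builders also return a dict where a real build_fn would return a compiled Keras model.

```python
import types


def resolve_build_fn(wrapper, build_fn):
    """Return the callable that builds the model (hypothetical helper).

    Mirrors the three documented options for build_fn: a plain function,
    a callable instance, or None (fall back to the wrapper's own __call__).
    """
    if build_fn is None:
        # A subclass of the wrapper provides __call__ as the default builder.
        return wrapper.__call__
    if isinstance(build_fn, (types.FunctionType, types.MethodType)):
        return build_fn
    # Any other object is treated as an instance with a __call__ method.
    return build_fn.__call__


def build_model(units=8):
    # Hypothetical stand-in for a function that builds and compiles a model.
    return {"kind": "function", "units": units}


class ModelFactory:
    # Hypothetical stand-in for a class instance that implements __call__.
    def __call__(self, units=8):
        return {"kind": "instance", "units": units}


class WrapperSketch:
    # Hypothetical stand-in for a KerasRegressor subclass with its own builder.
    def __call__(self, units=8):
        return {"kind": "subclass", "units": units}


w = WrapperSketch()
print(resolve_build_fn(w, build_model)(units=4))     # {'kind': 'function', 'units': 4}
print(resolve_build_fn(w, ModelFactory())(units=4))  # {'kind': 'instance', 'units': 4}
print(resolve_build_fn(w, None)(units=4))            # {'kind': 'subclass', 'units': 4}
```

The remaining keyword arguments (sk_params) are shared between this builder and the fit, predict and evaluate calls. That is why build_fn should give every argument a default value, like any scikit-learn estimator.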

Methods

check_params

check_params(
    params
)

Checks for user typos in params.

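Arguments:

params: dictionary; the parameters to be checked. A name is legal if it is an argument of build_fn (or of the builder's __call__) or of Sequential.fit, Sequential.predict or Sequential.evaluate.

Raises:

ValueError: if any member of params is not a legal argument.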
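The typo check can be sketched with inspect.signature: each key must be an argument of at least one "legal" function. This is a simplified illustration, not the TensorFlow source. It assumes that build_model and fit_stub are hypothetical stand-ins for a real build_fn and for the signature of Sequential.fit.

```python
import inspect


def build_model(units=8, learning_rate=0.01):
    # Hypothetical build_fn; a real one would return a compiled Keras model.
    return {"units": units, "learning_rate": learning_rate}


def fit_stub(x, y, batch_size=None, epochs=1, verbose=1):
    # Hypothetical stand-in for the signature of Sequential.fit.
    return None


def check_params(params, legal_fns):
    """Raise ValueError for any key that no function in legal_fns accepts."""
    for name in params:
        if not any(name in inspect.signature(fn).parameters for fn in legal_fns):
            raise ValueError(f"{name} is not a legal parameter")


legal = [build_model, fit_stub]
check_params({"units": 16, "epochs": 5}, legal)  # passes silently
try:
    check_params({"unitz": 16}, legal)  # typo for "units"
except ValueError as err:
    print(err)  # unitz is not a legal parameter
```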

filter_sk_params

filter_sk_params(
    fn, override=None
)

Filters sk_params and returns those in fn's arguments.

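Arguments:

fn: arbitrary function.
override: dictionary; values that override those in sk_params.

Returns:

res: dictionary of the sk_params entries whose names are arguments of fn, updated with override.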
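A minimal sketch of the filtering, assuming fn's accepted argument names can be read with inspect.signature. Here fit_stub is a hypothetical stand-in for Sequential.fit, and the values are illustrative only.

```python
import inspect


def filter_sk_params(sk_params, fn, override=None):
    """Keep the sk_params entries that fn accepts, then apply override."""
    override = override or {}
    accepted = inspect.signature(fn).parameters
    res = {name: value for name, value in sk_params.items() if name in accepted}
    # override entries are merged in as-is, even if fn does not accept them.
    res.update(override)
    return res


def fit_stub(x, y, batch_size=None, epochs=1, verbose=1):
    # Hypothetical stand-in for the signature of Sequential.fit.
    return None


sk_params = {"units": 16, "epochs": 5, "batch_size": 32}
print(filter_sk_params(sk_params, fit_stub))
# {'epochs': 5, 'batch_size': 32}
print(filter_sk_params(sk_params, fit_stub, override={"epochs": 10}))
# {'epochs': 10, 'batch_size': 32}
```

The model-building argument units is dropped because fit_stub does not accept it. This is how a single sk_params dictionary can carry both model and fitting parameters.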

fit

fit(
    x, y, **kwargs
)

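Constructs a new model with build_fn and fits it to (x, y).

Arguments:

x: array-like, shape (n_samples, n_features). Training samples, where n_samples is the number of samples and n_features is the number of features.
y: array-like, shape (n_samples,) or (n_samples, n_outputs). True labels for x.
**kwargs: dictionary arguments. Legal arguments are the arguments of Sequential.fit.

Returns:

history: object; details about the training history at each epoch.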
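The fit flow can be sketched in three steps: build a fresh model from the sk_params that build_fn accepts, collect the sk_params that fit accepts, then let the explicit kwargs win. Per the wrapper's docstring, the precedence is: values passed to fit, then sk_params, then fit's own defaults. ToyModel, build_model, filter_params and fit_sketch are hypothetical names. ToyModel.fit also only reports the arguments it received and does no training.

```python
import inspect


def filter_params(params, fn):
    # Keep only the entries whose names are arguments of fn.
    accepted = inspect.signature(fn).parameters
    return {k: v for k, v in params.items() if k in accepted}


class ToyModel:
    # Hypothetical stand-in for a compiled Keras model.
    def __init__(self, units):
        self.units = units

    def fit(self, x, y, batch_size=None, epochs=1):
        # Instead of training, report which fitting arguments were used.
        return {"n_samples": len(x), "batch_size": batch_size, "epochs": epochs}


def build_model(units=8):
    # Hypothetical build_fn.
    return ToyModel(units)


def fit_sketch(build_fn, sk_params, x, y, **kwargs):
    """Build a fresh model, then fit it; kwargs override sk_params."""
    model = build_fn(**filter_params(sk_params, build_fn))
    fit_args = filter_params(sk_params, model.fit)
    fit_args.update(kwargs)
    return model, model.fit(x, y, **fit_args)


sk_params = {"units": 16, "epochs": 5}
model, history = fit_sketch(build_model, sk_params, [[0.0], [1.0]], [0.0, 1.0], epochs=2)
print(model.units)  # 16
print(history)      # {'n_samples': 2, 'batch_size': None, 'epochs': 2}
```

Because a new model is built on every call, repeated fits (for example across cross-validation folds) start from fresh weights instead of continuing earlier training.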

get_params

get_params(
    **params
)

Gets parameters for this estimator.

Arguments:

**params: ignored (exists for API compatibility).

Returns:

Dictionary of parameter names mapped to their values: a copy of sk_params plus the build_fn entry.
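scikit-learn utilities such as sklearn.base.clone call get_params(deep=False) to copy an estimator. The **params catch-all is what lets the wrapper accept that deep keyword. The sketch below assumes that WrapperSketch and build_model are hypothetical and model only the parameter handling, not training.

```python
import copy


class WrapperSketch:
    # Hypothetical stand-in for KerasRegressor's parameter handling only.
    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def get_params(self, **params):
        # Extra keywords such as deep=False are accepted and ignored.
        res = copy.deepcopy(self.sk_params)
        res.update({"build_fn": self.build_fn})
        return res


def build_model(units=8):
    # Hypothetical build_fn.
    return {"units": units}


reg = WrapperSketch(build_fn=build_model, units=16, epochs=5)
print(sorted(reg.get_params(deep=False)))  # ['build_fn', 'epochs', 'units']
```

Returning a deep copy means a caller can modify the result without changing the estimator's own sk_params.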

predict

predict(
    x, **kwargs
)

Returns predictions for the given test data.

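Arguments:

x: array-like, shape (n_samples, n_features). Test samples, where n_samples is the number of samples and n_features is the number of features.
**kwargs: dictionary arguments. Legal arguments are the arguments of Sequential.predict.

Returns:

preds: array-like, shape (n_samples,). Predictions.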
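A single-output Keras model predicts an array of shape (n_samples, 1), while the documented return shape is (n_samples,). The wrapper therefore drops the trailing size-1 axis. In the TensorFlow 2.x source this is done with NumPy squeeze on the last axis; treat that detail as an assumption. The sketch below shows the same reshaping on plain lists, so it needs no NumPy. squeeze_last_axis is a hypothetical helper.

```python
def squeeze_last_axis(rows):
    """Turn [[p0], [p1], ...] into [p0, p1, ...] (a 1-D list)."""
    if any(len(row) != 1 for row in rows):
        # Like squeezing an axis that is not size 1, this is an error.
        raise ValueError("last axis must have size 1 to be squeezed")
    return [row[0] for row in rows]


raw = [[0.5], [1.25], [2.0]]       # shape (3, 1), as a one-output model returns
print(squeeze_last_axis(raw))      # [0.5, 1.25, 2.0]
print(squeeze_last_axis([[7.0]]))  # [7.0]: a single sample stays 1-D
```

Squeezing only the last axis, rather than every size-1 axis, keeps a one-sample prediction as a length-1 sequence instead of collapsing it to a scalar.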

score

score(
    x, y, **kwargs
)

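Returns the mean loss on the given test data and labels, negated so that a higher score means a better model (scikit-learn's convention for score).

Arguments:

x: array-like, shape (n_samples, n_features). Test samples, where n_samples is the number of samples and n_features is the number of features.
y: array-like, shape (n_samples,). True labels for x.
**kwargs: dictionary arguments. Legal arguments are the arguments of Sequential.evaluate.

Returns:

score: float. The negated mean loss of the predictions on x with respect to y.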
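Keras evaluate returns either a scalar loss or, when the model was compiled with metrics, a list whose first element is the loss. The score is that loss with its sign flipped. As a result, a regressor trained on mean squared error reports negative scores in tools such as cross_val_score or GridSearchCV, and values closer to zero are better. score_from_evaluate is a hypothetical helper showing only this conversion step.

```python
def score_from_evaluate(evaluate_result):
    """Convert a Keras-style evaluate() result into a scikit-learn score."""
    # evaluate() returns a scalar loss, or [loss, metric1, ...] when the
    # model was compiled with metrics; the loss always comes first.
    if isinstance(evaluate_result, list):
        return -evaluate_result[0]
    return -evaluate_result


print(score_from_evaluate(0.25))         # -0.25
print(score_from_evaluate([0.5, 0.12]))  # -0.5
```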

set_params

set_params(
    **params
)

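Sets the parameters of this estimator. The names are first validated with check_params, then merged into sk_params.

Arguments:

**params: dictionary of parameter names mapped to their values.

Returns:

self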
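Returning self lets calls be chained, and validating before updating means a typo leaves the estimator unchanged. The sketch below is a simplified illustration. It assumes that WrapperSketch and build_model are hypothetical, and its check_params accepts only build_fn arguments, whereas the real check also accepts the arguments of Sequential.fit, predict and evaluate.

```python
import inspect


def build_model(units=8, learning_rate=0.01):
    # Hypothetical build_fn.
    return {"units": units, "learning_rate": learning_rate}


class WrapperSketch:
    # Hypothetical stand-in for KerasRegressor's parameter handling only.
    def __init__(self, build_fn, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def check_params(self, params):
        # Simplified: only build_fn's arguments count as legal here.
        accepted = inspect.signature(self.build_fn).parameters
        for name in params:
            if name not in accepted:
                raise ValueError(f"{name} is not a legal parameter")

    def set_params(self, **params):
        self.check_params(params)      # reject typos before changing anything
        self.sk_params.update(params)
        return self                    # returning self allows chaining


reg = WrapperSketch(build_model, units=16)
reg.set_params(learning_rate=0.001).set_params(units=32)
print(reg.sk_params)  # {'units': 32, 'learning_rate': 0.001}
```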