tf.keras.wrappers.scikit_learn.KerasClassifier

Class KerasClassifier

Defined in tensorflow/python/keras/wrappers/scikit_learn.py.

Implementation of the scikit-learn classifier API for Keras.

__init__

__init__(
    build_fn=None,
    **sk_params
)

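Initializes the wrapper.

Arguments:

  • build_fn: callable function or class instance. It should construct, compile, and return a Keras model, which is then used by fit and predict. It may be a function, an instance of a class that implements __call__, or None, in which case the __call__ method of a subclass of KerasClassifier is used as the default build_fn.
  • **sk_params: model parameters and fitting parameters. Legal model parameters are the arguments of build_fn, which should give each argument a default value so the estimator can be created without any sk_params. Fitting parameters (e.g. epochs, batch_size) are used by fit, predict, predict_proba, and score; values passed directly to those methods take precedence over sk_params.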
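The wrapper routes each key in sk_params to whichever function accepts it: model parameters go to build_fn, fitting parameters such as epochs and batch_size go to fit. The sketch below mimics that routing with the standard-library inspect module. It is an illustration under assumptions, not the wrapper's code: route, build_model, and fake_fit are hypothetical names, and fake_fit only stands in for the real Sequential.fit signature.

```python
import inspect


def accepts(fn, name):
    """Return True if `fn` declares a parameter called `name`."""
    return name in inspect.signature(fn).parameters


def route(fn, sk_params):
    """Keep only the sk_params entries that `fn` accepts (hypothetical helper)."""
    return {k: v for k, v in sk_params.items() if accepts(fn, k)}


# Hypothetical build function: every argument has a default, so the
# estimator could be created without passing any sk_params.
def build_model(hidden_units=32, dropout=0.0):
    return {"hidden_units": hidden_units, "dropout": dropout}


# Hypothetical stand-in for a fit signature (not the real Sequential.fit).
def fake_fit(x, y, epochs=1, batch_size=32):
    return {"epochs": epochs, "batch_size": batch_size}


sk_params = {"hidden_units": 64, "epochs": 10, "batch_size": 16}
model_args = route(build_model, sk_params)  # {'hidden_units': 64}
fit_args = route(fake_fit, sk_params)       # {'epochs': 10, 'batch_size': 16}
```

A single flat dictionary keeps the estimator compatible with scikit-learn's get_params/set_params conventions, at the cost of requiring unique names across build_fn and the fitting methods.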

Methods

tf.keras.wrappers.scikit_learn.KerasClassifier.check_params

check_params(params)

Checks for user typos in params.

Arguments:

  • params: dictionary; the parameters to be checked

Raises:

  • ValueError: if any member of params is not a valid argument.
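Conceptually, the check asks whether at least one legal function (build_fn or one of the Sequential methods the wrapper forwards to) declares each key as a parameter. A minimal sketch assuming plain signature inspection; has_arg, build_model, and fake_fit are hypothetical, and the error text is modeled on, not guaranteed to match, the wrapper's message.

```python
import inspect


def has_arg(fn, name):
    """True if `fn` declares a parameter called `name`."""
    return name in inspect.signature(fn).parameters


def check_params(params, legal_fns):
    """Raise ValueError for any key that no function in `legal_fns` accepts."""
    for name in params:
        if not any(has_arg(fn, name) for fn in legal_fns):
            raise ValueError("{} is not a legal parameter".format(name))


# Hypothetical stand-ins for build_fn and a fitting method.
def build_model(hidden_units=32):
    ...


def fake_fit(x, y, epochs=1, batch_size=32):
    ...


legal = [build_model, fake_fit]
check_params({"hidden_units": 16, "epochs": 5}, legal)  # passes silently

message = None
try:
    check_params({"epoch": 5}, legal)  # typo: "epoch" instead of "epochs"
except ValueError as err:
    message = str(err)
```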

tf.keras.wrappers.scikit_learn.KerasClassifier.filter_sk_params

filter_sk_params(
    fn,
    override=None
)

Filters sk_params and returns those in fn's arguments.

Arguments:

  • fn: arbitrary function
  • override: dictionary, values to override sk_params

Returns:

  • res: dictionary containing variables in both sk_params and fn's arguments.
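The filtering step can be sketched as follows, assuming that override entries are applied last and unconditionally, as the description "values to override sk_params" suggests. This standalone function takes sk_params explicitly instead of reading it from the wrapper, and fake_predict is a hypothetical signature.

```python
import inspect


def filter_sk_params(fn, sk_params, override=None):
    """Return sk_params entries accepted by `fn`, then apply `override`.

    Keys in `override` are copied in last, so they win over sk_params
    (and are kept even if `fn` does not declare them).
    """
    override = override or {}
    accepted = inspect.signature(fn).parameters
    res = {k: v for k, v in sk_params.items() if k in accepted}
    res.update(override)
    return res


def fake_predict(x, batch_size=32, verbose=0):  # hypothetical signature
    ...


sk_params = {"batch_size": 64, "hidden_units": 128}
defaults_only = filter_sk_params(fake_predict, sk_params)
# {'batch_size': 64}  (hidden_units is dropped: fake_predict lacks it)
overridden = filter_sk_params(fake_predict, sk_params, override={"batch_size": 8})
# {'batch_size': 8}
```

This is the mechanism behind the precedence rule: per-call **kwargs are passed as the override, so they beat values stored in sk_params.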

tf.keras.wrappers.scikit_learn.KerasClassifier.fit

fit(
    x,
    y,
    **kwargs
)

Constructs a new model with build_fn and fits it to (x, y).

Arguments:

  • x: array-like, shape (n_samples, n_features). Training samples, where n_samples is the number of samples and n_features is the number of features.
  • y: array-like, shape (n_samples,) or (n_samples, n_outputs). True labels for x.
  • **kwargs: dictionary arguments. Legal arguments are the arguments of Sequential.fit.

Returns:

  • history: object. Details about the training history at each epoch.

Raises:

  • ValueError: In case of invalid shape for y argument.
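Before training, a classifier wrapper has to turn arbitrary labels into 0-based class indices so that predictions can later be mapped back. The sketch below does this for 1-D labels or an (n_samples, 1) column, assuming the sorted-unique encoding that np.unique produces; rows with several columns are treated as one-hot targets whose classes are the column indices. encode_labels is a hypothetical helper, and its ValueError only echoes the "invalid shape for y" case documented above.

```python
def encode_labels(y):
    """Map labels to 0-based class indices; return (classes, encoded).

    Accepts a flat list or a list of equal-length rows. Single-column
    rows are flattened; multi-column rows are treated as one-hot and
    passed through unchanged. Ragged or empty rows raise ValueError.
    """
    if y and isinstance(y[0], (list, tuple)):
        widths = {len(row) for row in y}
        if len(widths) != 1 or 0 in widths:
            raise ValueError("Invalid shape for y")
        width = widths.pop()
        if width > 1:
            return list(range(width)), y  # one-hot: classes are columns
        y = [row[0] for row in y]  # (n_samples, 1) -> (n_samples,)
    classes = sorted(set(y))
    index = {label: i for i, label in enumerate(classes)}
    return classes, [index[label] for label in y]


classes, encoded = encode_labels(["dog", "cat", "dog", "bird"])
# classes == ['bird', 'cat', 'dog'], encoded == [2, 1, 2, 0]
```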

tf.keras.wrappers.scikit_learn.KerasClassifier.get_params

get_params(**params)

Gets parameters for this estimator.

Arguments:

  • **params: ignored (exists for API compatibility).

Returns:

Dictionary of parameter names mapped to their values.
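scikit-learn utilities such as clone and grid search read an estimator's configuration through get_params. A minimal sketch of what the returned dictionary plausibly holds, assuming it is a deep copy of sk_params with a build_fn entry added; ParamsSketch and build_model are hypothetical, not part of Keras.

```python
import copy


class ParamsSketch:
    """Hypothetical stand-in for the wrapper's parameter bookkeeping."""

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def get_params(self, **params):
        # **params (e.g. deep=True) is accepted and ignored, as documented.
        res = copy.deepcopy(self.sk_params)
        res["build_fn"] = self.build_fn
        return res


def build_model(hidden_units=32):  # hypothetical build function
    return hidden_units


est = ParamsSketch(build_fn=build_model, hidden_units=64, epochs=5)
params = est.get_params()
params["epochs"] = 99  # mutating the copy leaves the estimator untouched
```

Returning a copy rather than the live dictionary means callers cannot accidentally reconfigure the estimator; set_params is the intended way to do that.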

tf.keras.wrappers.scikit_learn.KerasClassifier.predict

predict(
    x,
    **kwargs
)

Returns the class predictions for the given test data.

Arguments:

  • x: array-like, shape (n_samples, n_features). Test samples, where n_samples is the number of samples and n_features is the number of features.
  • **kwargs: dictionary arguments. Legal arguments are the arguments of Sequential.predict_classes.

Returns:

  • preds: array-like, shape (n_samples,). Class predictions.
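So that predictions come back as the original labels rather than integer indices, each predicted index can be looked up in the classes seen during fit. A sketch of that lookup, assuming argmax for multi-column outputs and a 0.5 threshold for a single sigmoid column (the usual Sequential.predict_classes behavior); predict_labels is hypothetical and works on plain lists instead of a Keras model.

```python
def predict_labels(proba_rows, classes):
    """Turn per-sample model outputs into original class labels.

    Multi-column rows use argmax (first maximum wins on ties);
    single-column rows are thresholded at 0.5. The resulting index
    selects a label from `classes`.
    """
    labels = []
    for row in proba_rows:
        if len(row) > 1:
            idx = max(range(len(row)), key=row.__getitem__)
        else:
            idx = int(row[0] > 0.5)
        labels.append(classes[idx])
    return labels


classes = ["bird", "cat", "dog"]
preds = predict_labels([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], classes)
# preds == ['cat', 'bird']
```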

tf.keras.wrappers.scikit_learn.KerasClassifier.predict_proba

predict_proba(
    x,
    **kwargs
)

Returns class probability estimates for the given test data.

Arguments:

  • x: array-like, shape (n_samples, n_features). Test samples, where n_samples is the number of samples and n_features is the number of features.
  • **kwargs: dictionary arguments. Legal arguments are the arguments of Sequential.predict_proba.

Returns:

  • proba: array-like, shape (n_samples, n_outputs). Class probability estimates. For binary classification, to match the scikit-learn API, an array of shape (n_samples, 2) is returned instead of the (n_samples, 1) array Keras produces.
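The binary-case reshaping described above amounts to pairing each sigmoid output p with its complement. A stdlib sketch, assuming the single column holds P(class 1); to_sklearn_proba is a hypothetical name.

```python
def to_sklearn_proba(probs):
    """Expand single-column (sigmoid) outputs to two columns.

    Each row [p] becomes [1 - p, p], i.e. [P(class 0), P(class 1)];
    multi-column input is returned unchanged.
    """
    if probs and len(probs[0]) == 1:
        return [[1.0 - row[0], row[0]] for row in probs]
    return probs


expanded = to_sklearn_proba([[0.2], [0.9]])
# approximately [[0.8, 0.2], [0.1, 0.9]] (subject to float rounding)
```

With two columns, column j always corresponds to class j, which is what scikit-learn tools such as roc_auc_score or CalibratedClassifierCV expect from predict_proba.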

tf.keras.wrappers.scikit_learn.KerasClassifier.score

score(
    x,
    y,
    **kwargs
)

Returns the mean accuracy on the given test data and labels.

Arguments:

  • x: array-like, shape (n_samples, n_features). Test samples, where n_samples is the number of samples and n_features is the number of features.
  • y: array-like, shape (n_samples,) or (n_samples, n_outputs). True labels for x.
  • **kwargs: dictionary arguments. Legal arguments are the arguments of Sequential.evaluate.

Returns:

  • score: float. Mean accuracy of predictions on x with respect to y.

Raises:

  • ValueError: If the underlying model isn't configured to compute accuracy. You should pass metrics=["accuracy"] to the .compile() method of the model.
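The ValueError follows from where the score comes from: since **kwargs are forwarded to Sequential.evaluate, the accuracy has to be picked out of the metrics that evaluate reports. A sketch of that lookup, assuming evaluate returns either a scalar loss or a list aligned with the model's metric names, and accepting both "accuracy" and "acc" because the metric's name has varied across Keras versions; pick_accuracy is hypothetical.

```python
def pick_accuracy(metrics_names, outputs):
    """Return the accuracy entry from an evaluate()-style result.

    `outputs` may be a single scalar (loss only) or a list aligned with
    `metrics_names`. Raises ValueError when no accuracy metric is present.
    """
    if not isinstance(outputs, list):
        outputs = [outputs]
    for name, value in zip(metrics_names, outputs):
        if name in ("accuracy", "acc"):
            return value
    raise ValueError(
        "The model is not configured to compute accuracy. "
        'Pass metrics=["accuracy"] to compile().'
    )


score = pick_accuracy(["loss", "accuracy"], [0.35, 0.9])  # 0.9
```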

tf.keras.wrappers.scikit_learn.KerasClassifier.set_params

set_params(**params)

Sets the parameters of this estimator.

Arguments:

  • **params: Dictionary of parameter names mapped to their values.

Returns:

self
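Because set_params returns self, calls can be chained, and scikit-learn's search utilities rely on set_params to apply each candidate's parameters. The sketch assumes validation happens before the update (in the spirit of check_params, but checking only build_fn's arguments for brevity) and that accepted values are merged into sk_params; SetParamsSketch and build_model are hypothetical.

```python
import inspect


class SetParamsSketch:
    """Hypothetical wrapper showing validate-then-update semantics."""

    def __init__(self, build_fn, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def set_params(self, **params):
        # Validate every key first so a typo leaves sk_params untouched.
        legal = inspect.signature(self.build_fn).parameters
        for name in params:
            if name not in legal:
                raise ValueError("{} is not a legal parameter".format(name))
        self.sk_params.update(params)
        return self  # returning self allows chaining


def build_model(hidden_units=32, dropout=0.0):  # hypothetical build function
    return hidden_units, dropout


est = SetParamsSketch(build_model, hidden_units=32)
est.set_params(hidden_units=128).set_params(dropout=0.5)
```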