tf.nn.max_pool_with_argmax

Performs max pooling on the input and outputs both max values and indices.

tf.nn.max_pool_with_argmax(
    input, ksize, strides, padding, data_format='NHWC',
    output_dtype=tf.dtypes.int64, include_batch_in_index=False, name=None
)

The indices in argmax are flattened, so a maximum value at position [b, y, x, c] becomes the flattened index (y * width + x) * channels + c if include_batch_in_index is False, or ((b * height + y) * width + x) * channels + c if include_batch_in_index is True.
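
As a minimal sketch of this formula, the helpers below compute a flattened index from an NHWC position and recover the position from an index with repeated divmod. flatten_index and unflatten_index are hypothetical names, not TensorFlow functions. They work on plain Python ints and assume height, width, and channels are the input's spatial and channel sizes.

```python
def flatten_index(b, y, x, c, height, width, channels, include_batch_in_index=False):
    """Flatten an NHWC position [b, y, x, c] using the formula described for argmax."""
    if include_batch_in_index:
        return ((b * height + y) * width + x) * channels + c
    return (y * width + x) * channels + c


def unflatten_index(index, height, width, channels, include_batch_in_index=False):
    """Recover (b, y, x, c) from a flattened index; b is None if the batch was not encoded."""
    index, c = divmod(index, channels)  # innermost dimension: channels
    index, x = divmod(index, width)     # then width
    b, y = divmod(index, height)        # then height; the remaining quotient is the batch
    return (b if include_batch_in_index else None, y, x, c)


# Input of shape [2, 4, 5, 3]; a maximum found at position [1, 2, 3, 1].
print(flatten_index(1, 2, 3, 1, height=4, width=5, channels=3))  # 40
print(flatten_index(1, 2, 3, 1, height=4, width=5, channels=3, include_batch_in_index=True))  # 100
print(unflatten_index(100, height=4, width=5, channels=3, include_batch_in_index=True))  # (1, 2, 3, 1)
```

When include_batch_in_index is False, the same spatial position in every batch element maps to the same index. A caller that needs the batch must therefore track it separately, for example by the index's position in the argmax tensor.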

The returned indices always lie in [0, height) x [0, width) before flattening. This holds even when padding is involved and the mathematically correct position lies outside that range (either negative or too large). This is a known bug, but fixing it in a safe, backwards-compatible way is difficult, especially because the indices are flattened.

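Args:

input: A 4-D Tensor of shape [batch, height, width, channels].
ksize: An int or list of ints of length 1, 2, or 4. The size of the pooling window for each dimension of input.
strides: An int or list of ints of length 1, 2, or 4. The stride of the sliding window for each dimension of input.
padding: A string, either 'VALID' or 'SAME'. The padding algorithm to use.
data_format: An optional string that must be 'NHWC' (the default).
output_dtype: An optional tf.DType, either tf.int32 or tf.int64 (default tf.int64). The dtype of the returned argmax tensor.
include_batch_in_index: An optional bool (default False). Whether to include the batch dimension in the flattened index of argmax.
name: A name for the operation (optional).
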
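Returns:

A tuple of Tensor objects (output, argmax):

output: A Tensor of the same type as input, containing the max-pooled values.
argmax: A Tensor of type output_dtype, containing the flattened index of each maximum.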
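
The pure-Python reference sketch below shows what the (output, argmax) pair contains in the simplest case. It is not TensorFlow's implementation. It assumes VALID padding, takes ksize and strides as (height, width) pairs, and uses nested lists in place of Tensors. max_pool_with_argmax_ref is a hypothetical name. Its argmax follows the flattening formula described for this op. How ties between equal values are broken is this sketch's own choice.

```python
def max_pool_with_argmax_ref(inputs, ksize, strides, include_batch_in_index=False):
    """Reference max pooling with argmax on a nested NHWC list (VALID padding only).

    ksize and strides are (height, width) pairs. Returns (output, argmax), both nested
    lists of shape [batch, out_height, out_width, channels].
    """
    batch, height = len(inputs), len(inputs[0])
    width, channels = len(inputs[0][0]), len(inputs[0][0][0])
    kh, kw = ksize
    sh, sw = strides
    # VALID padding: only windows that fit entirely inside the input are used.
    out_h = (height - kh) // sh + 1
    out_w = (width - kw) // sw + 1

    output, argmax = [], []
    for b in range(batch):
        out_img, arg_img = [], []
        for i in range(out_h):
            out_row, arg_row = [], []
            for j in range(out_w):
                out_px, arg_px = [], []
                for c in range(channels):
                    best = None  # (value, y, x) of the current maximum
                    for y in range(i * sh, i * sh + kh):
                        for x in range(j * sw, j * sw + kw):
                            v = inputs[b][y][x][c]
                            # Strict '>' keeps the first maximum in row-major scan order;
                            # this tie-breaking rule is a choice of this sketch.
                            if best is None or v > best[0]:
                                best = (v, y, x)
                    v, y, x = best
                    idx = (y * width + x) * channels + c
                    if include_batch_in_index:
                        # Equivalent to ((b * height + y) * width + x) * channels + c
                        idx += b * height * width * channels
                    out_px.append(v)
                    arg_px.append(idx)
                out_row.append(out_px)
                arg_row.append(arg_px)
            out_img.append(out_row)
            arg_img.append(arg_row)
        output.append(out_img)
        argmax.append(arg_img)
    return output, argmax


# Two 4x4 single-channel images whose values increase (img_a) or decrease (img_b) row by row.
img_a = [[[y * 4 + x] for x in range(4)] for y in range(4)]
img_b = [[[15 - (y * 4 + x)] for x in range(4)] for y in range(4)]
output, argmax = max_pool_with_argmax_ref([img_a, img_b], ksize=(2, 2), strides=(2, 2))
print(argmax[1])  # [[[0], [2]], [[8], [10]]]
```

Passing include_batch_in_index=True offsets every index of the second image by height * width * channels = 16, which gives [[[16], [18]], [[24], [26]]].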