tf.compat.v1.tpu.experimental.shared_embedding_columns

TPU version of tf.compat.v1.feature_column.shared_embedding_columns.

tf.compat.v1.tpu.experimental.shared_embedding_columns(
    categorical_columns, dimension, combiner='mean', initializer=None,
    shared_embedding_collection_name=None, max_sequence_lengths=None,
    learning_rate_fn=None, embedding_lookup_device=None, tensor_core_shape=None
)

Note that the interface for tf.tpu.experimental.shared_embedding_columns differs from that of tf.compat.v1.feature_column.shared_embedding_columns. The following arguments are NOT supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm, and trainable.
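When migrating an existing call site, it can help to see which keyword arguments would have to go. The sketch below is plain Python and does not call TensorFlow; the helper name split_tpu_kwargs and the constant UNSUPPORTED_ON_TPU are hypothetical and exist only for illustration.

```python
# Hypothetical migration helper; NOT part of TensorFlow.
# The tuple lists the arguments the note above says are not supported.
UNSUPPORTED_ON_TPU = (
    "ckpt_to_load_from",
    "tensor_name_in_ckpt",
    "max_norm",
    "trainable",
)


def split_tpu_kwargs(kwargs):
    """Split kwargs into those to keep and the names of those to drop."""
    supported = {k: v for k, v in kwargs.items() if k not in UNSUPPORTED_ON_TPU}
    dropped = sorted(k for k in kwargs if k in UNSUPPORTED_ON_TPU)
    return supported, dropped


# Example: keyword arguments taken from a non-TPU call site.
supported, dropped = split_tpu_kwargs(
    {"dimension": 10, "combiner": "mean", "max_norm": 1.0, "trainable": True}
)
print(supported)
print(dropped)
```

Dropping an argument such as trainable or max_norm changes behavior, so the returned list of dropped names is meant to be reviewed by hand rather than discarded silently.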

Use this function in place of tf.compat.v1.feature_column.shared_embedding_columns when you want to use the TPU to accelerate your embedding lookups via TPU embeddings.

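import tensorflow as tf

# Two categorical columns that will share a single embedding table.
column_a = tf.feature_column.categorical_column_with_identity(...)
column_b = tf.feature_column.categorical_column_with_identity(...)
# The second positional argument is `dimension`, the embedding width.
tpu_columns = tf.tpu.experimental.shared_embedding_columns(
    [column_a, column_b], 10)
...
def model_fn(features):
  # Turn the raw features into dense embedded tensors.
  dense_feature = tf.keras.layers.DenseFeatures(tpu_columns)
  embedded_feature = dense_feature(features)
  ...

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,
    ...
    # Pass the same columns to the estimator's embedding configuration.
    embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
        feature_columns=tpu_columns,
        ...))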
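
To make "shared" and combiner='mean' more concrete, here is a conceptual sketch in plain Python. It is not the TPU implementation and does not use TensorFlow. The function names, the random initialization, and the zero vector for an empty input are all assumptions made for illustration.

```python
# Conceptual sketch only: plain Python, not TensorFlow or TPU embeddings.
import random


def make_shared_table(vocab_size, dimension, seed=0):
    """Build one (vocab_size x dimension) table that several columns share."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dimension)]
            for _ in range(vocab_size)]


def lookup_mean(table, ids):
    """Look up each id in the table and average the rows ('mean' combiner)."""
    dimension = len(table[0])
    if not ids:
        # Assumption for this sketch: an empty input maps to a zero vector.
        return [0.0] * dimension
    rows = [table[i] for i in ids]
    return [sum(row[d] for row in rows) / len(rows) for d in range(dimension)]


# One table, two feature columns: both look up into the same weights.
table = make_shared_table(vocab_size=5, dimension=10)
features = {"column_a": [1, 3], "column_b": [3]}
embedded = {name: lookup_mean(table, ids) for name, ids in features.items()}

print(len(embedded["column_a"]))
```

Because both columns index the same table, id 3 maps to the same row whether it arrives through column_a or column_b. Sharing the weights is the point of using shared embedding columns rather than one independent embedding column per feature.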

Args:

Returns:

A list of _TPUSharedEmbeddingColumnV2.

Raises: