
Customized Feature Column

brightcoder01 edited this page Nov 24, 2019 · 15 revisions


TF 2.0 users generally define models using Keras. tf.keras.layers.Embedding is the native embedding layer in Keras, but it can only handle dense input. For sparse inputs, users often use tf.feature_column.embedding_column to convert them to a dense representation that can be fed to a DNN.
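To make the dense/sparse distinction concrete, here is a minimal plain-Python sketch that does not use TensorFlow. The table values and the helper names `lookup_dense` and `lookup_sparse_mean` are made up for illustration. A dense input has the same number of ids in every example, so a lookup returns one vector per id. A sparse input has a variable number of ids per example, so the looked-up vectors are combined into one fixed-size vector; the sketch averages them, which corresponds to the default `mean` combiner of embedding_column.

```python
# Toy embedding table: 4 ids, embedding dimension 2 (values are made up).
TABLE = [
    [0.0, 1.0],
    [2.0, 3.0],
    [4.0, 5.0],
    [6.0, 7.0],
]


def lookup_dense(ids):
    """Dense input: every example has the same number of ids."""
    return [[TABLE[i] for i in example] for example in ids]


def lookup_sparse_mean(ids):
    """Sparse input: each example has a variable number of ids.

    The looked-up vectors are averaged so that every example ends up
    with one fixed-size vector; an empty example maps to zeros.
    """
    dim = len(TABLE[0])
    result = []
    for example in ids:
        if not example:
            result.append([0.0] * dim)
            continue
        vectors = [TABLE[i] for i in example]
        result.append([sum(col) / len(vectors) for col in zip(*vectors)])
    return result


dense_out = lookup_dense([[0, 1], [2, 3]])
sparse_out = lookup_sparse_mean([[0, 2], [3], []])
print(sparse_out)
```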

In ElasticDL, users also define the model using Keras. We provide elastic.layers.Embedding, which interacts with the ElasticDL parameter server (PS) and partitions the embedding table among multiple PS instances. It can replace the native Keras embedding layer, but it cannot replace embedding_column.
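This page does not describe how ElasticDL assigns embedding rows to PS instances. Purely to illustrate what partitioning a table means, the sketch below assumes a simple modulo scheme (`id % num_ps`). The function names `ps_for_id` and `group_ids_by_ps` are hypothetical and are not ElasticDL APIs.

```python
from collections import defaultdict


def ps_for_id(embedding_id, num_ps):
    """Assumed scheme: row `embedding_id` lives on PS `embedding_id % num_ps`."""
    return embedding_id % num_ps


def group_ids_by_ps(ids, num_ps):
    """Group the distinct ids of a batch by the PS instance that would own them,
    so that each PS instance receives a single lookup request."""
    groups = defaultdict(list)
    for embedding_id in sorted(set(ids)):
        groups[ps_for_id(embedding_id, num_ps)].append(embedding_id)
    return dict(groups)


requests = group_ids_by_ps([7, 2, 9, 2, 4], num_ps=3)
print(requests)
```

Under this assumption, a lookup for a batch becomes one request per PS instance, and the results would then be stitched back together in the original id order.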

This doc focuses on how to write a customized feature column that interacts with the parameter server, and on how to replace the native feature column with ours.

How to replace a feature column

The following sample code shows how to replace embedding_column with indicator_column in a feature layer.

import tensorflow as tf
from tensorflow.python.feature_column import feature_column_v2 as fc_lib


def replace_embedding_column_with_indicator_column_in_feature_layer(feature_layer):
    """Swap every EmbeddingColumn in the layer for an indicator_column."""
    new_feature_columns = []
    # _feature_columns is a private attribute of the feature layer.
    for column in feature_layer._feature_columns:
        if isinstance(column, fc_lib.EmbeddingColumn):
            # Reuse the wrapped categorical column as the indicator's source.
            new_column = tf.feature_column.indicator_column(column.categorical_column)
            new_feature_columns.append(new_column)
        else:
            new_feature_columns.append(column)

    feature_layer._feature_columns = new_feature_columns

    return feature_layer

How to write a customized feature column

  1. Define a new class that inherits from FeatureColumn. Because we want to customize an embedding column, it needs to inherit from DenseColumn.
  2. Implement all the abstract methods. We focus on the following two:
     - create_state: Creates the variables for this FeatureColumn, such as the embedding variables, and associates them with the DenseFeatures layer.
     - get_dense_tensor: While executing DenseFeatures.call, the layer iterates over all of its feature columns and calls get_dense_tensor on each one to get the transformed dense tensor. Taking the native EmbeddingColumn as an example, it calls embedding_ops.safe_embedding_lookup_sparse to get the embedding vectors from the sparse input.
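The call order of create_state and get_dense_tensor can be sketched without TensorFlow. The classes below are plain-Python stand-ins, not the real tf.feature_column classes: `ToyStateManager`, `ToyEmbeddingColumn`, and `ToyDenseFeatures` are hypothetical names, the real abstract base classes require more methods than shown, and the real get_dense_tensor receives a transformation cache rather than a raw feature dict. The sketch only illustrates that the layer first asks each column to create its state and later asks it for a dense tensor on every call.

```python
class ToyStateManager:
    """Stand-in for the object that owns the variables of each column."""

    def __init__(self):
        self._tables = {}

    def create_variable(self, column, shape):
        rows, dim = shape
        # Deterministic made-up initial values: row r is [r, r, ...].
        self._tables[column.name] = [[float(r)] * dim for r in range(rows)]

    def get_variable(self, column):
        return self._tables[column.name]


class ToyEmbeddingColumn:
    """Stand-in for a customized embedding column (a DenseColumn)."""

    def __init__(self, name, num_buckets, dimension):
        self.name = name
        self.num_buckets = num_buckets
        self.dimension = dimension

    def create_state(self, state_manager):
        # Called once by the layer: create the embedding table.
        state_manager.create_variable(self, (self.num_buckets, self.dimension))

    def get_dense_tensor(self, features, state_manager):
        # Called on every forward pass: sparse ids -> one dense row per example.
        table = state_manager.get_variable(self)
        output = []
        for ids in features[self.name]:
            if not ids:
                output.append([0.0] * self.dimension)
                continue
            rows = [table[i] for i in ids]
            output.append([sum(col) / len(rows) for col in zip(*rows)])
        return output


class ToyDenseFeatures:
    """Stand-in for the DenseFeatures layer."""

    def __init__(self, columns):
        self.columns = columns
        self.state_manager = ToyStateManager()
        # The toy layer creates state eagerly, in its constructor.
        for column in self.columns:
            column.create_state(self.state_manager)

    def call(self, features):
        # Concatenate the dense output of every column, example by example.
        per_column = [
            c.get_dense_tensor(features, self.state_manager) for c in self.columns
        ]
        return [sum(parts, []) for parts in zip(*per_column)]


layer = ToyDenseFeatures([ToyEmbeddingColumn("item_id", num_buckets=5, dimension=2)])
output = layer.call({"item_id": [[1, 3], [4], []]})
print(output)
```

In a customized column for ElasticDL, these two methods are presumably where the column would talk to the parameter server instead of reading a local table.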