Customized Feature Column
TensorFlow 2.0 users generally define models with Keras. tf.keras.layers.Embedding is the native embedding layer in Keras, but it can only handle dense input. For sparse inputs, users often use tf.feature_column.embedding_column to convert them to a dense representation that can be fed to a DNN.
In ElasticDL, users also define models with Keras. We provide elasticdl.layers.Embedding, which interacts with the ElasticDL parameter server (PS) and partitions the embedding table among multiple PS instances. It can replace the native Keras embedding layer, but it cannot replace embedding_column.
In this doc, we focus on how to write a customized feature column that interacts with the parameter server and how to replace the native feature column with ours.
The following sample code shows how to replace each embedding_column in a feature layer with an indicator_column.
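To make the idea of partitioning an embedding table concrete, here is a minimal sketch in plain Python. It assumes a simple modulo rule that assigns each embedding id to one PS instance. The helper names `ps_index_for_id` and `group_ids_by_ps` and the modulo scheme itself are hypothetical illustrations, not the ElasticDL implementation.

```python
from collections import defaultdict


def ps_index_for_id(embedding_id, num_ps):
    """Pick the PS instance that owns a row (assumption: id modulo num_ps)."""
    return embedding_id % num_ps


def group_ids_by_ps(embedding_ids, num_ps):
    """Group the unique ids of a batch by the PS instance that owns them."""
    groups = defaultdict(list)
    for embedding_id in sorted(set(embedding_ids)):
        groups[ps_index_for_id(embedding_id, num_ps)].append(embedding_id)
    return dict(groups)


# A batch may repeat ids; each PS only needs to be asked once per unique id.
batch_ids = [7, 2, 7, 10, 3]
requests = group_ids_by_ps(batch_ids, num_ps=3)
```

With a rule like this, a lookup for one batch turns into at most one request per PS instance, and each instance stores only its own slice of the table.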
import tensorflow as tf
from tensorflow import feature_column
from tensorflow.python.feature_column import feature_column_v2 as fc_lib


def replace_embedding_column_with_indicator_column_in_feature_layer(feature_layer):
    # Build a new column list, swapping every EmbeddingColumn for an
    # indicator_column over the same categorical column.
    new_feature_columns = []
    for column in feature_layer._feature_columns:
        if isinstance(column, fc_lib.EmbeddingColumn):
            new_column = tf.feature_column.indicator_column(column.categorical_column)
            new_feature_columns.append(new_column)
        else:
            new_feature_columns.append(column)
    # Note: _feature_columns is a private attribute of the feature layer.
    feature_layer._feature_columns = new_feature_columns
    return feature_layer
- Define a new class that inherits from FeatureColumn. Because we want to customize an embedding column, it needs to inherit from DenseColumn.
- Implement all the abstract methods. We focus on the following two methods in particular:
create_state
Creates the variables for this FeatureColumn that are associated with the DenseFeatures layer, such as the embedding variables.
get_dense_tensor
While executing DenseFeatures.call, the layer iterates over all of its feature columns and calls get_dense_tensor on each one to get the transformed dense tensor. Take the native EmbeddingColumn as an example: it calls embedding_ops.safe_embedding_lookup_sparse to get the embedding vectors from the sparse input.
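The division of work between create_state and get_dense_tensor can be sketched in plain Python without TensorFlow. The classes `ToyEmbeddingColumn` and `ToyDenseFeatures` below are hypothetical stand-ins, not TensorFlow or ElasticDL classes. The sketch assumes a mean combiner, drops negative ids, and returns a zero vector for rows with no ids, which is similar in spirit to a safe sparse lookup.

```python
import random


class ToyEmbeddingColumn:
    """Toy stand-in for a dense feature column (not a TensorFlow class)."""

    def __init__(self, key, num_buckets, dimension):
        self.key = key
        self.num_buckets = num_buckets
        self.dimension = dimension

    def create_state(self, state):
        # Called once by the layer: create the embedding table for this column.
        rng = random.Random(0)
        state[self.key] = [
            [rng.uniform(-0.05, 0.05) for _ in range(self.dimension)]
            for _ in range(self.num_buckets)
        ]

    def get_dense_tensor(self, features, state):
        # Called on every forward pass: map each variable-length id list
        # to one fixed-size vector by averaging the looked-up rows.
        table = state[self.key]
        output = []
        for ids in features[self.key]:
            valid = [i for i in ids if i >= 0]  # drop negative ids
            if not valid:
                output.append([0.0] * self.dimension)  # empty row -> zeros
                continue
            vectors = [table[i] for i in valid]
            output.append([sum(col) / len(valid) for col in zip(*vectors)])
        return output


class ToyDenseFeatures:
    """Toy stand-in for the feature layer that drives the columns."""

    def __init__(self, columns):
        self.columns = columns
        self.state = {}
        for column in columns:
            column.create_state(self.state)

    def call(self, features):
        per_column = [c.get_dense_tensor(features, self.state) for c in self.columns]
        # Concatenate the outputs of all columns row by row.
        return [sum(parts, []) for parts in zip(*per_column)]


layer = ToyDenseFeatures([ToyEmbeddingColumn("item_id", num_buckets=4, dimension=2)])
dense = layer.call({"item_id": [[0, 2], [], [3]]})
```

In a customized column that talks to a parameter server, get_dense_tensor is the natural place to fetch embedding rows remotely instead of reading a local table, while create_state would set up whatever local handle the column needs.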