ElasticDL Serving Solution Exploration
Besides training, model serving is an essential part of the end-to-end machine learning lifecycle. Publishing the trained model as a service in production is what makes the model valuable in the real world.
At the current stage, ElasticDL focuses on training. We have neither our own serving infrastructure nor an existing one we can reuse to serve our trained models. (Why?) Our goal is to work out a serving solution.
Store the ElasticDL model in the SavedModel format.
SavedModel is the universal serialization format for TensorFlow models. It is language-neutral and can be loaded by multiple frameworks (such as TF Serving, TFLite, and TensorFlow.js). We choose to store ElasticDL models in the SavedModel format, so that we can leverage various mature solutions to serve our models in different scenarios.
Model sizes range from several kilobytes to several terabytes. We divide models into two categories by size: small or medium, and large. A small or medium model can be loaded by a single process, while a large model cannot fit in a single process. Training and serving strategies differ between these two cases, as the following table shows:
| Model size / Training strategy | Master Central Storage | AllReduce | Parameter Server |
|---|---|---|---|
| Small or Medium Size Model | SavedModel | SavedModel | SavedModel |
| Large Size Model | N/A | N/A | Distributed Parameter Server for Serving |
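To make the two size categories concrete, a back-of-the-envelope estimate of an embedding table's memory footprint (rows × dimension × 4 bytes per float32 value) is often enough to tell which row of the table a model falls into. This is a minimal sketch: the 64 GiB per-process budget and the example vocabulary sizes are hypothetical assumptions for illustration, not ElasticDL limits.

```python
# Rough memory estimate for a dense float32 embedding table.
BYTES_PER_FLOAT32 = 4


def embedding_table_bytes(vocab_size: int, dim: int) -> int:
    """Bytes needed to hold a vocab_size x dim float32 table."""
    return vocab_size * dim * BYTES_PER_FLOAT32


def size_category(total_bytes: int, process_budget_bytes: int = 64 * 2**30) -> str:
    """Classify a model by whether it fits in a single process.

    The default 64 GiB budget is a hypothetical assumption.
    """
    if total_bytes <= process_budget_bytes:
        return "small or medium"
    return "large"


if __name__ == "__main__":
    small = embedding_table_bytes(10_000, 16)         # 640,000 bytes
    large = embedding_table_bytes(1_000_000_000, 64)  # 256,000,000,000 bytes
    print(small, size_category(small))
    print(large, size_category(large))
```

A table with a billion ids and 64-dimensional vectors already needs about 256 GB for the values alone, before optimizer slots, which is why such models end up in the "Large Size Model" row.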
Distributed Parameter Server for Serving
This covers the case where the model cannot fit in a single process. We partition the model variables into multiple shards and store them in a distributed parameter server for serving. In the serving stage, the inference engine executes the serving graph, queries variable values from the distributed parameter server as needed, and completes the computation.
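The sharding and lookup described above can be sketched in plain Python. This is a minimal illustration under assumptions: in-memory dicts stand in for remote parameter server instances, rows are placed with a hypothetical `id % num_shards` rule (the design does not fix a partitioning scheme), and ids missing from the table fall back to zero vectors. The lookup groups ids so that each shard receives at most one request, then restores the caller's id order.

```python
from typing import Dict, List, Sequence


class ShardedEmbeddingStore:
    """Toy stand-in for a distributed parameter server holding one embedding table."""

    def __init__(self, num_shards: int, dim: int):
        self.num_shards = num_shards
        self.dim = dim
        # One dict per shard: embedding id -> vector.
        self.shards: List[Dict[int, List[float]]] = [{} for _ in range(num_shards)]

    def shard_of(self, embedding_id: int) -> int:
        # Hypothetical placement rule: modulo partitioning.
        return embedding_id % self.num_shards

    def put(self, embedding_id: int, vector: Sequence[float]) -> None:
        if len(vector) != self.dim:
            raise ValueError("vector has the wrong dimension")
        self.shards[self.shard_of(embedding_id)][embedding_id] = list(vector)

    def lookup(self, ids: Sequence[int]) -> List[List[float]]:
        """Fetch vectors for ids, issuing at most one request per shard."""
        by_shard: Dict[int, List[int]] = {}
        for i in set(ids):
            by_shard.setdefault(self.shard_of(i), []).append(i)
        fetched: Dict[int, List[float]] = {}
        for shard_index, shard_ids in by_shard.items():
            shard = self.shards[shard_index]  # in a real system: one RPC per shard
            for i in shard_ids:
                # Serving reads a static table; unknown ids map to zeros (assumption).
                fetched[i] = shard.get(i, [0.0] * self.dim)
        # Return vectors in the same order as the requested ids.
        return [fetched[i] for i in ids]
```

For example, with `store = ShardedEmbeddingStore(num_shards=3, dim=2)` and rows `0..5` stored as `[i, 10 * i]`, `store.lookup([4, 1, 4])` returns `[[4.0, 40.0], [1.0, 10.0], [4.0, 40.0]]`. Grouping by shard bounds the number of round trips by the shard count rather than the batch size, which matters for serving latency.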
Serving has stricter latency and SLA requirements than training. The number of parameter server instances grows in proportion to the QPS of the inference traffic. On the other hand, for serving the parameter server only needs to look up a static embedding table, which is simpler than in training. Therefore, we will deploy separate parameter servers for training and serving.
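The statement that the parameter server instance count scales with inference QPS amounts to a simple capacity formula: replicas = ceil(peak QPS × headroom / QPS one replica sustains within the latency SLA). The sketch below assumes the per-replica throughput is measured offline; both it and the 1.2 headroom factor are hypothetical inputs, not ElasticDL numbers.

```python
import math


def serving_ps_replicas(peak_qps: float, qps_per_replica: float, headroom: float = 1.2) -> int:
    """Estimate how many serving parameter server replicas are needed.

    qps_per_replica: throughput one replica sustains within the latency SLA
                     (assumed to be measured offline).
    headroom: hypothetical safety factor for traffic spikes.
    """
    if qps_per_replica <= 0:
        raise ValueError("qps_per_replica must be positive")
    return max(1, math.ceil(peak_qps * headroom / qps_per_replica))
```

With a peak of 50,000 QPS and 8,000 QPS per replica, `serving_ps_replicas(50_000, 8_000)` returns 8 (50,000 × 1.2 / 8,000 = 7.5, rounded up). Doubling the traffic roughly doubles the replica count, while the training parameter servers are sized independently.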
We will address this solution in a separate design as the next step.
How to save a model trained with the parameter server as a SavedModel?
For large models, we are designing a parameter server to store the variables and embeddings. Currently we use Redis as a temporary solution. In our model definition, we use ElasticDL.Embedding instead of tf.keras.layers.Embedding to interact with the parameter server. ElasticDL.Embedding uses tf.py_function to invoke RPCs to the parameter server.
However, when saving the model, the customized ElasticDL.Embedding layer does not map to any native TensorFlow op and cannot be saved into the SavedModel. The embedding vectors stored in the parameter server are lost, so embedding lookup cannot work in the serving process.
1. Customize an embedding layer that trains with ElasticDL.Embedding and exports the model in SavedModel format using keras.layers.Embedding.
To verify the feasibility, we define a custom layer as follows:
import os
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Embedding, Dense, Flatten
from elasticdl.python.elasticdl.layers.embedding import Embedding as elasticDL_Embedding
from tensorflow.keras import layers
from tensorflow.python.keras.utils import tf_utils


class TestCustomEmbedding(layers.Layer):
    def __init__(self, input_dim, output_dim, **kwargs):
        super(TestCustomEmbedding, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Used for training in ElasticDL, backed by the parameter server.
        self.edl_embedding_layer = elasticDL_Embedding(self.output_dim)
        # Native Keras layer, used for local runs and for SavedModel export.
        self.keras_embedding_layer = Embedding(self.input_dim, self.output_dim)

    def call(self, inputs):
        is_exporting_saved_model = os.getenv('SAVED_MODEL') == 'True'
        is_elastic = os.getenv('FRAMEWORK') == 'ElasticDL'

        def _replace_weights_with_edl():
            # variable.csv mocks the embedding values held by ElasticDL.Embedding.
            import pandas as pd
            var_values = pd.read_csv('variable.csv')
            custom_param = var_values.values
            for var in self.keras_embedding_layer.trainable_variables:
                var.assign(custom_param)

        def _true_fn(inputs):
            # Exporting: copy the ElasticDL values into the Keras layer first.
            _replace_weights_with_edl()
            return self.keras_embedding_layer(inputs)

        def _false_fn(inputs):
            if is_elastic:
                return self.edl_embedding_layer(inputs)
            return self.keras_embedding_layer(inputs)

        return tf_utils.smart_cond(is_exporting_saved_model,
                                   lambda: _true_fn(inputs),
                                   lambda: _false_fn(inputs))
In TestCustomEmbedding, the variables of the keras.layers.Embedding instance are replaced by the values in variable.csv. This file mocks the variable values of the ElasticDL.Embedding instance, which in practice would be fetched from the parameter server via gRPC.
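The example only mocks variable.csv. One way to produce a file that `_replace_weights_with_edl` can consume is to pull every row of the table from the parameter server and write a dense `input_dim × output_dim` matrix, which is the shape of the Keras Embedding weight. Note that `pd.read_csv` treats the first line as a header by default, so the file needs a header row; otherwise the first embedding row is consumed as column names and `var.assign` fails on a shape mismatch. In this sketch `fetch_rows` is a hypothetical stand-in for the gRPC lookup, and filling ids never seen in training with zeros is an assumption.

```python
import csv
from typing import Callable, Dict, List, Sequence

# Hypothetical parameter server client call: ids -> {id: vector}.
FetchRows = Callable[[Sequence[int]], Dict[int, List[float]]]


def dense_embedding_rows(input_dim: int, output_dim: int, fetch_rows: FetchRows) -> List[List]:
    """Header row followed by one row per embedding id 0..input_dim-1."""
    ids = list(range(input_dim))
    fetched = fetch_rows(ids)  # stands in for the gRPC lookup (hypothetical)
    rows: List[List] = [[f"dim_{j}" for j in range(output_dim)]]
    for i in ids:
        vector = fetched.get(i, [0.0] * output_dim)  # untrained ids -> zeros (assumption)
        if len(vector) != output_dim:
            raise ValueError(f"row {i} has {len(vector)} values, expected {output_dim}")
        rows.append(list(vector))
    return rows


def write_embedding_csv(path: str, rows: List[List]) -> None:
    """Write rows so that pd.read_csv(path).values has shape (input_dim, output_dim)."""
    with open(path, "w", newline="") as f:
        csv.writer(f).writerows(rows)
```

For the model below, `write_embedding_csv("variable.csv", dense_embedding_rows(10, 4, fetch_rows))` would yield a 10 × 4 matrix after `pd.read_csv`, matching `Embedding(10, 4)`.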
Then we define a Keras model with TestCustomEmbedding as below:
inputs = Input(shape=(10,))
embedding = TestCustomEmbedding(10,4)(inputs)
flatten = Flatten()(embedding)
output = Dense(1, activation='sigmoid')(flatten)
model = tf.keras.Model(inputs=[inputs], outputs=[output])
os.environ['SAVED_MODEL'] = 'False'
input_array = tf.constant([[1,2,3,4,1,1,1,1,1,0]])
output = model.call(input_array, training=True)
print('training output : ', output)
output = model.call(input_array)
print('predict output : ',output)
The output
training output : tf.Tensor([[0.48767245]], shape=(1, 1), dtype=float32)
predict output : tf.Tensor([[0.48767245]], shape=(1, 1), dtype=float32)
Then we set SAVED_MODEL to True and check the model output with the same input.
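os.environ['SAVED_MODEL'] = 'True'
# This call copies the values from variable.csv into the Keras embedding.
output = model.call(input_array)
print('predict output in saved_model : ', output)
# Save only after the weights are replaced, so the SavedModel holds the ElasticDL values.
tf.saved_model.save(model, "./tmp/custom_embedding/123")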
The output
predict output in saved_model : tf.Tensor([[0.99985003]], shape=(1, 1), dtype=float32)
Next, we serve the SavedModel with TensorFlow Serving and send a request with the same input values.
curl -d '{"instances": [[1,2,3,4,1,1,1,1,1,0]]}' -X POST http://localhost:8501/v1/models/model:predict
The response
{
"predictions": [[0.999850035]]
}
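The same request can be issued from Python with only the standard library. Host, port, and the model name `model` are copied from the curl command above; `build_predict_request`, `parse_predictions`, and `predict` are illustrative helpers, not part of TensorFlow Serving. The body uses the REST API's row format (`{"instances": [...]}`), and the reply carries a `predictions` field.

```python
import json
import urllib.request
from typing import List


def build_predict_request(instances: List[List[float]], host: str = "localhost",
                          port: int = 8501, model_name: str = "model") -> urllib.request.Request:
    """Build a TensorFlow Serving REST predict request in row format."""
    url = f"http://{host}:{port}/v1/models/{model_name}:predict"
    body = json.dumps({"instances": instances}).encode("utf-8")
    return urllib.request.Request(url, data=body, method="POST",
                                  headers={"Content-Type": "application/json"})


def parse_predictions(response_body: bytes) -> List[List[float]]:
    """Extract the "predictions" field from a predict response."""
    return json.loads(response_body)["predictions"]


def predict(instances: List[List[float]], **kwargs) -> List[List[float]]:
    """Send the request; this needs a running TensorFlow Serving instance."""
    with urllib.request.urlopen(build_predict_request(instances, **kwargs), timeout=10) as resp:
        return parse_predictions(resp.read())
```

With the server from above running, `predict([[1, 2, 3, 4, 1, 1, 1, 1, 1, 0]])` should return the same single prediction as the curl call.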
So we have verified that the custom layer can use ElasticDL.Embedding during training, and can use keras.layers.Embedding, loaded with the variables from ElasticDL.Embedding, to save the model in SavedModel format.
2. For Sequential models and models built with the functional API, we can clone a new model with keras.models.clone_model and replace keras.layers.Embedding with ElasticDL.Embedding.
import tensorflow as tf
from tensorflow import keras
from elasticdl.python.elasticdl.layers.embedding import Embedding as edl_Embedding
from tensorflow.keras.layers import Input, Embedding, Flatten, Dense


def clone_function(layer):
    # Swap each Keras Embedding for an ElasticDL Embedding with the same output_dim;
    # all other layers are shared with the original model.
    if isinstance(layer, keras.layers.Embedding):
        return edl_Embedding(layer.output_dim)
    return layer


inputs = Input(shape=(10,))
embedding = Embedding(10, 4)(inputs)
flatten = Flatten()(embedding)
output = Dense(1, activation='sigmoid')(flatten)
model = tf.keras.Model(inputs=[inputs], outputs=[output])
new_model = keras.models.clone_model(model, clone_function=clone_function)
Layers in the model:
for layer in model.layers:
    print(layer)
<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x13ea36b70>
<tensorflow.python.keras.layers.embeddings.Embedding object at 0x13ea36ac8>
<tensorflow.python.keras.layers.core.Flatten object at 0x13ea36278>
<tensorflow.python.keras.layers.core.Dense object at 0x13ea36fd0>
Layers in the new model:
for layer in new_model.layers:
    print(layer)
<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x13e74d2b0>
<elasticdl.python.elasticdl.layers.embedding.Embedding object at 0x13ea7e198>
<tensorflow.python.keras.layers.core.Flatten object at 0x13ea36278>
<tensorflow.python.keras.layers.core.Dense object at 0x13ea36fd0>
As shown, we succeed in replacing keras.layers.Embedding with ElasticDL.Embedding in the new model. We can therefore train the new model in ElasticDL and use the original model, loaded with the embedding variables from ElasticDL.Embedding, to export the SavedModel. However, this method does not work if the embedding is created inside a custom layer class.
3. For a subclassed model, we can replace the keras.layers.Embedding attribute with ElasticDL.Embedding.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Embedding
from elasticdl.python.elasticdl.layers.embedding import Embedding as edl_Embedding


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.embedding = Embedding(10, 4)
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        embedding = self.embedding(inputs)
        x = self.dense1(embedding)
        return self.dense2(x)


model = MyModel()
Layers in the original model:
for layer in model.layers:
    print(layer)
<tensorflow.python.keras.layers.embeddings.Embedding object at 0x13b1e6b70>
<tensorflow.python.keras.layers.core.Dense object at 0x13b1e6da0>
<tensorflow.python.keras.layers.core.Dense object at 0x13b1d76a0>
Now we replace model.embedding with ElasticDL.Embedding.
# Iterate over a snapshot so that setattr cannot change the dict while it is being iterated.
for attr_name, attr_value in list(model.__dict__.items()):
    if isinstance(attr_value, keras.layers.Embedding):
        setattr(model, attr_name, edl_Embedding(attr_value.output_dim))

for layer in model.layers:
    print(layer)
<tensorflow.python.keras.layers.core.Dense object at 0x13b1e6da0>
<tensorflow.python.keras.layers.core.Dense object at 0x13b1d76a0>
<elasticdl.python.elasticdl.layers.embedding.Embedding object at 0x13b1d2908>
Like the second solution, this one does not work when the embedding is created inside a custom layer class.
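To see why attribute replacement misses embeddings created inside a custom layer, the plain-Python sketch below mirrors the replacement loop above using hypothetical stand-in classes (not the Keras or ElasticDL classes). Only the model's own attributes are inspected, so an Embedding held by a nested layer survives unchanged.

```python
class Embedding:  # hypothetical stand-in for keras.layers.Embedding
    def __init__(self, output_dim):
        self.output_dim = output_dim


class PSEmbedding:  # hypothetical stand-in for ElasticDL.Embedding
    def __init__(self, output_dim):
        self.output_dim = output_dim


class CustomBlock:  # user-defined layer that creates its own embedding internally
    def __init__(self):
        self.inner = Embedding(4)


class SubclassedModel:
    def __init__(self):
        self.embedding = Embedding(4)
        self.block = CustomBlock()


def replace_top_level_embeddings(model):
    """Same idea as the loop above: only the model's own attributes are checked."""
    for name, value in list(vars(model).items()):
        if isinstance(value, Embedding):
            setattr(model, name, PSEmbedding(value.output_dim))
    return model


model = replace_top_level_embeddings(SubclassedModel())
print(type(model.embedding).__name__)    # prints PSEmbedding
print(type(model.block.inner).__name__)  # prints Embedding: the nested layer is missed
```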
- Is the following scenario possible? The user writes tf.keras.layers.Embedding in the model definition. When the model runs in ElasticDL with the parameter server enabled, the native Keras Embedding layer is replaced with an ElasticDL.Embedding layer that interacts with the parameter server. In this way, users can write models with the native TensorFlow API and still execute them in a distributed way in ElasticDL, which is more user-friendly.