Skip to content

Adding support for TensorFlow 2.x #372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 41 additions & 19 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import six


from six.moves import urllib
from firebase_admin import _http_client
from firebase_admin import _utils
from firebase_admin import exceptions
Expand Down Expand Up @@ -200,6 +201,7 @@ def from_dict(cls, data, app=None):
data_copy = dict(data)
tflite_format = None
tflite_format_data = data_copy.pop('tfliteModel', None)
data_copy.pop('@type', None) # Returned by Operations. (Not needed)
if tflite_format_data:
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
model = Model(model_format=tflite_format)
Expand Down Expand Up @@ -495,12 +497,31 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app)

@staticmethod
def _assert_tf_version_1_enabled():
def _assert_tf_enabled():
if not _TF_ENABLED:
raise ImportError('Failed to import the tensorflow library for Python. Make sure '
'to install the tensorflow module.')
if not tf.VERSION.startswith('1.'):
raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION))
if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'):
raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}'
.format(tf.version.VERSION))

@staticmethod
def _tf_convert_from_saved_model(saved_model_dir):
# Same for both v1.x and v2.x
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
return converter.convert()

@staticmethod
def _tf_convert_from_keras_model(keras_model):
# Version 1.x conversion function takes a model file. Version 2.x takes the model itself.
if tf.version.VERSION.startswith('1.'):
keras_file = 'firebase_keras_model.h5'
tf.keras.models.save_model(keras_model, keras_file)
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
return converter.convert()
else:
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
return converter.convert()

@classmethod
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
Expand All @@ -518,9 +539,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
Raises:
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
"""
TFLiteGCSModelSource._assert_tf_version_1_enabled()
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
TFLiteGCSModelSource._assert_tf_enabled()
tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir)
open('firebase_mlkit_model.tflite', 'wb').write(tflite_model)
return TFLiteGCSModelSource.from_tflite_model_file(
'firebase_mlkit_model.tflite', bucket_name, app)
Expand All @@ -541,11 +561,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None):
Raises:
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
"""
TFLiteGCSModelSource._assert_tf_version_1_enabled()
keras_file = 'keras_model.h5'
tf.keras.models.save_model(keras_model, keras_file)
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
TFLiteGCSModelSource._assert_tf_enabled()
tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model)
open('firebase_mlkit_model.tflite', 'wb').write(tflite_model)
return TFLiteGCSModelSource.from_tflite_model_file(
'firebase_mlkit_model.tflite', bucket_name, app)
Expand Down Expand Up @@ -852,12 +869,12 @@ def create_model(self, model):

def update_model(self, model, update_mask=None):
_validate_model(model, update_mask)
data = {'model': model.as_dict(for_upload=True)}
path = 'models/{0}'.format(model.model_id)
if update_mask is not None:
data['updateMask'] = update_mask
path = path + '?updateMask={0}'.format(update_mask)
try:
return self.handle_operation(
self._client.body('patch', url='models/{0}'.format(model.model_id), json=data))
self._client.body('patch', url=path, json=model.as_dict(for_upload=True)))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

Expand All @@ -884,15 +901,20 @@ def list_models(self, list_filter, page_size, page_token):
_validate_list_filter(list_filter)
_validate_page_size(page_size)
_validate_page_token(page_token)
payload = {}
params = {}
if list_filter:
payload['list_filter'] = list_filter
params['filter'] = list_filter
if page_size:
payload['page_size'] = page_size
params['page_size'] = page_size
if page_token:
payload['page_token'] = page_token
params['page_token'] = page_token
path = 'models'
if params:
# pylint: disable=too-many-function-args
param_str = urllib.parse.urlencode(sorted(params.items()), True)
path = path + '?' + param_str
try:
return self._client.body('get', url='models', json=payload)
return self._client.body('get', url=path)
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

Expand Down
27 changes: 15 additions & 12 deletions tests/test_mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,13 @@ def teardown_class(cls):
testutils.cleanup_apps()

@staticmethod
def _url(project_id, model_id):
def _update_url(project_id, model_id):
update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format(
project_id, model_id)
return BASE_URL + update_url

@staticmethod
def _get_url(project_id, model_id):
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)

@staticmethod
Expand All @@ -778,10 +784,9 @@ def test_immediate_done(self, publish_function, published):
assert model == CREATED_UPDATED_MODEL_1
assert len(recorder) == 1
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
body = json.loads(recorder[0].body.decode())
assert body.get('model', {}).get('state', {}).get('published', None) is published
assert body.get('updateMask', {}) == 'state.published'
assert body.get('state', {}).get('published', None) is published

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_returns_locked(self, publish_function):
Expand All @@ -794,9 +799,9 @@ def test_returns_locked(self, publish_function):
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1)

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_operation_error(self, publish_function):
Expand Down Expand Up @@ -973,12 +978,10 @@ def test_list_models_with_all_args(self):
page_token=PAGE_TOKEN)
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == TestListModels._url(PROJECT_ID)
assert json.loads(recorder[0].body.decode()) == {
'list_filter': 'display_name=displayName3',
'page_size': 10,
'page_token': PAGE_TOKEN
}
assert recorder[0].url == (
TestListModels._url(PROJECT_ID) +
'?filter=display_name%3DdisplayName3&page_size=10&page_token={0}'
.format(PAGE_TOKEN))
assert isinstance(models_page, mlkit.ListModelsPage)
assert len(models_page.models) == 1
assert models_page.models[0] == MODEL_3
Expand Down