-
Notifications
You must be signed in to change notification settings - Fork 339
Implementation of Model, ModelFormat, TFLiteModelSource and subclasses #335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 1 commit
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
573e7cb
Implementation of Model, ModelFormat, ModelSource and subclasses
ifielker e67bceb
review fixes
ifielker 1f018fe
more review fixes
ifielker dfe0a37
review fixes 2
ifielker 7704c44
review fixes 3
ifielker 8381ac5
review fixes 4
ifielker a2e7544
review fixes 5
ifielker cadd6c6
fixed lint
ifielker b02ea22
review comments
ifielker a0a2411
more review changes
ifielker fc63db8
fixed lint
ifielker File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,8 @@ | |
deleting, publishing and unpublishing Firebase ML Kit models. | ||
""" | ||
|
||
import datetime | ||
import numbers | ||
import re | ||
import requests | ||
import six | ||
|
@@ -63,9 +65,25 @@ def delete_model(model_id, app=None): | |
|
||
class Model(object): | ||
"""A Firebase ML Kit Model object.""" | ||
def __init__(self, data): | ||
def __init__(self, data=None, display_name=None, tags=None, model_format=None): | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Created from a data dictionary.""" | ||
self._data = data | ||
if data is not None and isinstance(data, dict): | ||
self._data = data | ||
else: | ||
self._data = {} | ||
if display_name is not None: | ||
_validate_display_name(display_name) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can save a few lines if you get these validators to return the validated value.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
self._data['displayName'] = display_name | ||
if tags is not None: | ||
_validate_tags(tags) | ||
self._data['tags'] = tags | ||
if model_format is not None: | ||
_validate_model_format(model_format) | ||
if isinstance(model_format, TFLiteFormat): | ||
self._data['tfliteModel'] = model_format.get_json() | ||
else: | ||
raise TypeError('Unsupported model format type.') | ||
|
||
|
||
def __eq__(self, other): | ||
if isinstance(other, self.__class__): | ||
|
@@ -77,15 +95,181 @@ def __ne__(self, other): | |
return not self.__eq__(other) | ||
|
||
@property | ||
def name(self): | ||
return self._data['name'] | ||
def model_id(self): | ||
if not self._data.get('name'): | ||
return None | ||
_, model_id = _validate_and_parse_name(self._data.get('name')) | ||
return model_id | ||
|
||
@property | ||
def display_name(self): | ||
return self._data['displayName'] | ||
return self._data.get('displayName') | ||
|
||
@display_name.setter | ||
def display_name(self, display_name): | ||
_validate_display_name(display_name) | ||
self._data['displayName'] = display_name | ||
return self | ||
|
||
@property | ||
def create_time(self): | ||
if self._data.get('createTime') and \ | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._data.get('createTime').get('seconds') and \ | ||
isinstance(self._data.get('createTime').get('seconds'), numbers.Number): | ||
return datetime.datetime.fromtimestamp( | ||
float(self._data.get('createTime').get('seconds'))) | ||
return None | ||
|
||
@property | ||
def update_time(self): | ||
if self._data.get('updateTime') and \ | ||
self._data.get('updateTime').get('seconds') and \ | ||
isinstance(self._data.get('updateTime').get('seconds'), numbers.Number): | ||
return datetime.datetime.fromtimestamp( | ||
float(self._data.get('updateTime').get('seconds'))) | ||
return None | ||
|
||
@property | ||
def validation_error(self): | ||
return self._data.get('state') and \ | ||
self._data.get('state').get('validationError') and \ | ||
self._data.get('state').get('validationError').get('message') | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@property | ||
def published(self): | ||
return bool(self._data.get('state') and | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._data.get('state').get('published')) | ||
|
||
@property | ||
def etag(self): | ||
return self._data.get('etag') | ||
|
||
@property | ||
def model_hash(self): | ||
return self._data.get('modelHash') | ||
|
||
@property | ||
def tags(self): | ||
return self._data.get('tags') | ||
|
||
@tags.setter | ||
def tags(self, tags): | ||
_validate_tags(tags) | ||
self._data['tags'] = tags | ||
return self | ||
|
||
@property | ||
def locked(self): | ||
return bool(self._data.get('activeOperations') and | ||
len(self._data.get('activeOperations')) > 0) | ||
|
||
@property | ||
def model_format(self): | ||
if self._data.get('tfliteModel'): | ||
return TFLiteFormat(self._data.get('tfliteModel')) | ||
return None | ||
|
||
@model_format.setter | ||
def model_format(self, model_format): | ||
if not isinstance(model_format, TFLiteFormat): | ||
raise TypeError('Unsupported model format type.') | ||
self._data['tfliteModel'] = model_format.get_json() | ||
return self | ||
|
||
def get_json(self): | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return self._data | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class ModelFormat(object): | ||
"""Abstract base class representing a Model Format such as TFLite.""" | ||
def get_json(self): | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise NotImplementedError | ||
|
||
|
||
class TFLiteFormat(ModelFormat): | ||
"""Model format representing a TFLite model.""" | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__(self, data=None, model_source=None): | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (data is not None) and isinstance(data, dict): | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._data = data | ||
else: | ||
self._data = {} | ||
if model_source is not None: | ||
# Check for correct base type | ||
if not isinstance(model_source, TFLiteModelSource): | ||
raise TypeError('Model source must be a ModelSource object.') | ||
# Set based on specific sub type | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(model_source, TFLiteGCSModelSource): | ||
self._data['gcsTfliteUri'] = model_source.get_json() | ||
else: | ||
raise TypeError('Unsupported model source type.') | ||
|
||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __eq__(self, other): | ||
if isinstance(other, self.__class__): | ||
return self._data == other._data # pylint: disable=protected-access | ||
else: | ||
return False | ||
|
||
def __ne__(self, other): | ||
return not self.__eq__(other) | ||
|
||
@property | ||
def model_source(self): | ||
if self._data.get('gcsTfliteUri'): | ||
return TFLiteGCSModelSource(self._data.get('gcsTfliteUri')) | ||
return None | ||
|
||
@model_source.setter | ||
def model_source(self, model_source): | ||
if model_source is not None: | ||
if isinstance(model_source, TFLiteGCSModelSource): | ||
self._data['gcsTfliteUri'] = model_source.get_json() | ||
else: | ||
raise TypeError('Unsupported model source type.') | ||
|
||
|
||
@property | ||
def size_bytes(self): | ||
return self._data.get('sizeBytes') | ||
|
||
def get_json(self): | ||
return self._data | ||
|
||
|
||
class TFLiteModelSource(object): | ||
"""Abstract base class representing a model source for TFLite format models.""" | ||
def get_json(self): | ||
raise NotImplementedError | ||
|
||
|
||
class TFLiteGCSModelSource(TFLiteModelSource): | ||
"""TFLite model source representing a tflite model file stored in GCS.""" | ||
def __init__(self, gcs_tflite_uri): | ||
_validate_gcs_tflite_uri(gcs_tflite_uri) | ||
self._gcs_tflite_uri = gcs_tflite_uri | ||
|
||
#TODO(ifielker): define the rest of the Model properties etc | ||
def __eq__(self, other): | ||
if isinstance(other, self.__class__): | ||
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access | ||
else: | ||
return False | ||
|
||
def __ne__(self, other): | ||
return not self.__eq__(other) | ||
|
||
@property | ||
def gcs_tflite_uri(self): | ||
return self._gcs_tflite_uri | ||
|
||
@gcs_tflite_uri.setter | ||
def gcs_tflite_uri(self, gcs_tflite_uri): | ||
_validate_gcs_tflite_uri(gcs_tflite_uri) | ||
self._gcs_tflite_uri = gcs_tflite_uri | ||
|
||
def get_json(self): | ||
return self._gcs_tflite_uri | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
#TODO(ifielker): implement from_saved_model etc. | ||
|
||
class ListModelsPage(object): | ||
"""Represents a page of models in a firebase project. | ||
|
@@ -179,13 +363,55 @@ def __iter__(self): | |
return self | ||
|
||
|
||
def _validate_and_parse_name(name): | ||
# The resource name is added automatically from API call responses. | ||
# The only way it could be invalid is if someone tries to | ||
# create a model from a dictionary manually and does it incorrectly. | ||
if not isinstance(name, six.string_types): | ||
raise TypeError('Model resource name must be a string.') | ||
matcher = re.match( | ||
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$', | ||
name) | ||
if not matcher: | ||
raise ValueError('Model resource name format is invalid.') | ||
return matcher.group('project_id'), matcher.group('model_id') | ||
|
||
|
||
def _validate_model_id(model_id): | ||
if not isinstance(model_id, six.string_types): | ||
raise TypeError('Model ID must be a string.') | ||
if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): | ||
raise ValueError('Model ID format is invalid.') | ||
|
||
|
||
def _validate_display_name(display_name): | ||
if not isinstance(display_name, six.string_types): | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise TypeError('Display name must be a string.') | ||
if not re.match(r'^[A-Za-z0-9_-]{1,60}$', display_name): | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError('Display name format is invalid.') | ||
|
||
|
||
def _validate_tags(tags): | ||
if not isinstance(tags, list) or not \ | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
all(isinstance(tag, six.string_types) for tag in tags): | ||
raise TypeError('Tags must be a list of strings.') | ||
if not all(re.match(r'^[A-Za-z0-9_-]{1,60}$', tag) for tag in tags): | ||
raise ValueError('Tag format is invalid.') | ||
|
||
|
||
def _validate_gcs_tflite_uri(uri): | ||
if not isinstance(uri, six.string_types): | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise TypeError('Gcs TFLite URI must be a string.') | ||
# GCS Bucket naming rules are complex. The regex is not comprehensive. | ||
# See https://cloud.google.com/storage/docs/naming for full details. | ||
if not re.match(r'^gs://[a-z0-9_.-]{3,63}/.+', uri): | ||
raise ValueError('GCS TFLite URI format is invalid.') | ||
|
||
def _validate_model_format(model_format): | ||
if model_format is not None: | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if not isinstance(model_format, ModelFormat): | ||
raise TypeError('Model format must be a ModelFormat object.') | ||
|
||
def _validate_list_filter(list_filter): | ||
if list_filter is not None: | ||
if not isinstance(list_filter, six.string_types): | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.