-
Notifications
You must be signed in to change notification settings - Fork 340
Implementation of Model, ModelFormat, TFLiteModelSource and subclasses #335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
573e7cb
e67bceb
1f018fe
dfe0a37
7704c44
8381ac5
a2e7544
cadd6c6
b02ea22
a0a2411
fc63db8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,8 @@ | |
deleting, publishing and unpublishing Firebase ML Kit models. | ||
""" | ||
|
||
import datetime | ||
import numbers | ||
import re | ||
import requests | ||
import six | ||
|
@@ -28,6 +30,12 @@ | |
|
||
_MLKIT_ATTRIBUTE = '_mlkit' | ||
_MAX_PAGE_SIZE = 100 | ||
_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') | ||
_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') | ||
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') | ||
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+') | ||
_RESOURCE_NAME_PATTERN = re.compile( | ||
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$') | ||
|
||
|
||
def _get_mlkit_service(app): | ||
|
@@ -47,7 +55,7 @@ def _get_mlkit_service(app): | |
|
||
def get_model(model_id, app=None): | ||
mlkit_service = _get_mlkit_service(app) | ||
return Model(mlkit_service.get_model(model_id)) | ||
return Model(**mlkit_service.get_model(model_id)) | ||
|
||
|
||
def list_models(list_filter=None, page_size=None, page_token=None, app=None): | ||
|
@@ -63,9 +71,18 @@ def delete_model(model_id, app=None): | |
|
||
class Model(object): | ||
"""A Firebase ML Kit Model object.""" | ||
def __init__(self, data): | ||
"""Created from a data dictionary.""" | ||
self._data = data | ||
def __init__(self, display_name=None, tags=None, model_format=None, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This constructor seems to be doing a lot. Can we simplify as follows?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
self._data = kwargs | ||
if display_name is not None: | ||
self._data['displayName'] = _validate_display_name(display_name) | ||
if tags is not None: | ||
self._data['tags'] = _validate_tags(tags) | ||
if model_format is not None: | ||
_validate_model_format(model_format) | ||
if isinstance(model_format, TFLiteFormat): | ||
self._data['tfliteModel'] = model_format.as_dict() | ||
else: | ||
raise TypeError('Unsupported model format type.') | ||
|
||
def __eq__(self, other): | ||
if isinstance(other, self.__class__): | ||
|
@@ -77,14 +94,184 @@ def __ne__(self, other): | |
return not self.__eq__(other) | ||
|
||
@property | ||
def name(self): | ||
return self._data['name'] | ||
def model_id(self): | ||
if not self._data.get('name'): | ||
return None | ||
_, model_id = _validate_and_parse_name(self._data.get('name')) | ||
return model_id | ||
|
||
@property | ||
def display_name(self): | ||
return self._data['displayName'] | ||
return self._data.get('displayName') | ||
|
||
#TODO(ifielker): define the rest of the Model properties etc | ||
@display_name.setter | ||
def display_name(self, display_name): | ||
self._data['displayName'] = _validate_display_name(display_name) | ||
return self | ||
|
||
@property | ||
def create_time(self): | ||
"""Returns the creation timestamp""" | ||
create_time = self._data.get('createTime') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. self._data.get('createTime', {}).get('seconds') There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
if not create_time: | ||
return None | ||
|
||
seconds = create_time.get('seconds') | ||
if not seconds: | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return None | ||
if not isinstance(seconds, numbers.Number): | ||
return None | ||
|
||
return datetime.datetime.fromtimestamp(float(seconds)) | ||
|
||
@property | ||
def update_time(self): | ||
"""Returns the last update timestamp""" | ||
update_time = self._data.get('updateTime') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
if not update_time: | ||
return None | ||
|
||
seconds = update_time.get('seconds') | ||
if not seconds: | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return None | ||
if not isinstance(seconds, numbers.Number): | ||
return None | ||
|
||
return datetime.datetime.fromtimestamp(float(seconds)) | ||
|
||
@property | ||
def validation_error(self): | ||
return self._data.get('state') and \ | ||
self._data.get('state').get('validationError') and \ | ||
self._data.get('state').get('validationError').get('message') | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@property | ||
def published(self): | ||
return bool(self._data.get('state') and | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._data.get('state').get('published')) | ||
|
||
@property | ||
def etag(self): | ||
return self._data.get('etag') | ||
|
||
@property | ||
def model_hash(self): | ||
return self._data.get('modelHash') | ||
|
||
@property | ||
def tags(self): | ||
return self._data.get('tags') | ||
|
||
@tags.setter | ||
def tags(self, tags): | ||
self._data['tags'] = _validate_tags(tags) | ||
return self | ||
|
||
@property | ||
def locked(self): | ||
return bool(self._data.get('activeOperations') and | ||
len(self._data.get('activeOperations')) > 0) | ||
|
||
@property | ||
def model_format(self): | ||
if self._data.get('tfliteModel'): | ||
return TFLiteFormat(**self._data.get('tfliteModel')) | ||
return None | ||
|
||
@model_format.setter | ||
def model_format(self, model_format): | ||
if not isinstance(model_format, TFLiteFormat): | ||
raise TypeError('Unsupported model format type.') | ||
self._data['tfliteModel'] = model_format.as_dict() | ||
return self | ||
|
||
def as_dict(self): | ||
return self._data | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class ModelFormat(object): | ||
"""Abstract base class representing a Model Format such as TFLite.""" | ||
def as_dict(self): | ||
raise NotImplementedError | ||
|
||
|
||
class TFLiteFormat(ModelFormat): | ||
"""Model format representing a TFLite model.""" | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__(self, model_source=None, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simplify here too by adding a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
self._data = kwargs | ||
if model_source is not None: | ||
# Check for correct base type | ||
if not isinstance(model_source, TFLiteModelSource): | ||
raise TypeError('Model source must be a ModelSource object.') | ||
# Set based on specific sub type | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(model_source, TFLiteGCSModelSource): | ||
self._data['gcsTfliteUri'] = model_source.as_dict() | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
raise TypeError('Unsupported model source type.') | ||
|
||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __eq__(self, other): | ||
if isinstance(other, self.__class__): | ||
return self._data == other._data # pylint: disable=protected-access | ||
else: | ||
return False | ||
|
||
def __ne__(self, other): | ||
return not self.__eq__(other) | ||
|
||
@property | ||
def model_source(self): | ||
if self._data.get('gcsTfliteUri'): | ||
return TFLiteGCSModelSource(self._data.get('gcsTfliteUri')) | ||
return None | ||
|
||
@model_source.setter | ||
def model_source(self, model_source): | ||
if model_source is not None: | ||
if isinstance(model_source, TFLiteGCSModelSource): | ||
self._data['gcsTfliteUri'] = model_source.as_dict() | ||
else: | ||
raise TypeError('Unsupported model source type.') | ||
|
||
@property | ||
def size_bytes(self): | ||
return self._data.get('sizeBytes') | ||
|
||
def as_dict(self): | ||
return self._data | ||
|
||
|
||
class TFLiteModelSource(object): | ||
"""Abstract base class representing a model source for TFLite format models.""" | ||
def as_dict(self): | ||
raise NotImplementedError | ||
|
||
|
||
class TFLiteGCSModelSource(TFLiteModelSource): | ||
"""TFLite model source representing a tflite model file stored in GCS.""" | ||
def __init__(self, gcs_tflite_uri): | ||
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) | ||
|
||
def __eq__(self, other): | ||
if isinstance(other, self.__class__): | ||
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access | ||
else: | ||
return False | ||
|
||
def __ne__(self, other): | ||
return not self.__eq__(other) | ||
|
||
@property | ||
def gcs_tflite_uri(self): | ||
return self._gcs_tflite_uri | ||
|
||
@gcs_tflite_uri.setter | ||
def gcs_tflite_uri(self, gcs_tflite_uri): | ||
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) | ||
|
||
def as_dict(self): | ||
return self._gcs_tflite_uri | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
#TODO(ifielker): implement from_saved_model etc. | ||
|
||
|
||
class ListModelsPage(object): | ||
|
@@ -105,7 +292,7 @@ def __init__(self, list_models_func, list_filter, page_size, page_token): | |
@property | ||
def models(self): | ||
"""A list of Models from this page.""" | ||
return [Model(model) for model in self._list_response.get('models', [])] | ||
return [Model(**model) for model in self._list_response.get('models', [])] | ||
|
||
@property | ||
def list_filter(self): | ||
|
@@ -179,13 +366,49 @@ def __iter__(self): | |
return self | ||
|
||
|
||
def _validate_and_parse_name(name): | ||
# The resource name is added automatically from API call responses. | ||
# The only way it could be invalid is if someone tries to | ||
# create a model from a dictionary manually and does it incorrectly. | ||
matcher = _RESOURCE_NAME_PATTERN.match(name) | ||
if not matcher: | ||
raise ValueError('Model resource name format is invalid.') | ||
return matcher.group('project_id'), matcher.group('model_id') | ||
|
||
|
||
def _validate_model_id(model_id): | ||
if not isinstance(model_id, six.string_types): | ||
raise TypeError('Model ID must be a string.') | ||
if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): | ||
if not _MODEL_ID_PATTERN.match(model_id): | ||
raise ValueError('Model ID format is invalid.') | ||
|
||
|
||
def _validate_display_name(display_name): | ||
if not _DISPLAY_NAME_PATTERN.match(display_name): | ||
raise ValueError('Display name format is invalid.') | ||
return display_name | ||
|
||
|
||
def _validate_tags(tags): | ||
if not isinstance(tags, list) or not \ | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
all(isinstance(tag, six.string_types) for tag in tags): | ||
raise TypeError('Tags must be a list of strings.') | ||
if not all(_TAG_PATTERN.match(tag) for tag in tags): | ||
raise ValueError('Tag format is invalid.') | ||
return tags | ||
|
||
|
||
def _validate_gcs_tflite_uri(uri): | ||
# GCS Bucket naming rules are complex. The regex is not comprehensive. | ||
# See https://cloud.google.com/storage/docs/naming for full details. | ||
if not _GCS_TFLITE_URI_PATTERN.match(uri): | ||
raise ValueError('GCS TFLite URI format is invalid.') | ||
return uri | ||
|
||
def _validate_model_format(model_format): | ||
if model_format: | ||
if not isinstance(model_format, ModelFormat): | ||
raise TypeError('Model format must be a ModelFormat object.') | ||
return model_format | ||
|
||
def _validate_list_filter(list_filter): | ||
if list_filter is not None: | ||
if not isinstance(list_filter, six.string_types): | ||
|
Uh oh!
There was an error while loading. Please reload this page.