Skip to content

Implementation of Model, ModelFormat, TFLiteModelSource and subclasses #335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 29, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 232 additions & 6 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
deleting, publishing and unpublishing Firebase ML Kit models.
"""

import datetime
import numbers
import re
import requests
import six
Expand Down Expand Up @@ -63,9 +65,25 @@ def delete_model(model_id, app=None):

class Model(object):
"""A Firebase ML Kit Model object."""
def __init__(self, data):
def __init__(self, data=None, display_name=None, tags=None, model_format=None):
"""Created from a data dictionary."""
self._data = data
if data is not None and isinstance(data, dict):
self._data = data
else:
self._data = {}
if display_name is not None:
_validate_display_name(display_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can save a few lines if you get these validators to return the validated value.

self._data['displayName'] = _validate_display_name(display_name)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

self._data['displayName'] = display_name
if tags is not None:
_validate_tags(tags)
self._data['tags'] = tags
if model_format is not None:
_validate_model_format(model_format)
if isinstance(model_format, TFLiteFormat):
self._data['tfliteModel'] = model_format.get_json()
else:
raise TypeError('Unsupported model format type.')


def __eq__(self, other):
if isinstance(other, self.__class__):
Expand All @@ -77,15 +95,181 @@ def __ne__(self, other):
return not self.__eq__(other)

@property
def name(self):
return self._data['name']
def model_id(self):
if not self._data.get('name'):
return None
_, model_id = _validate_and_parse_name(self._data.get('name'))
return model_id

@property
def display_name(self):
return self._data['displayName']
return self._data.get('displayName')

@display_name.setter
def display_name(self, display_name):
_validate_display_name(display_name)
self._data['displayName'] = display_name
return self

@property
def create_time(self):
if self._data.get('createTime') and \
self._data.get('createTime').get('seconds') and \
isinstance(self._data.get('createTime').get('seconds'), numbers.Number):
return datetime.datetime.fromtimestamp(
float(self._data.get('createTime').get('seconds')))
return None

@property
def update_time(self):
if self._data.get('updateTime') and \
self._data.get('updateTime').get('seconds') and \
isinstance(self._data.get('updateTime').get('seconds'), numbers.Number):
return datetime.datetime.fromtimestamp(
float(self._data.get('updateTime').get('seconds')))
return None

@property
def validation_error(self):
return self._data.get('state') and \
self._data.get('state').get('validationError') and \
self._data.get('state').get('validationError').get('message')

@property
def published(self):
return bool(self._data.get('state') and
self._data.get('state').get('published'))

@property
def etag(self):
return self._data.get('etag')

@property
def model_hash(self):
return self._data.get('modelHash')

@property
def tags(self):
return self._data.get('tags')

@tags.setter
def tags(self, tags):
_validate_tags(tags)
self._data['tags'] = tags
return self

@property
def locked(self):
return bool(self._data.get('activeOperations') and
len(self._data.get('activeOperations')) > 0)

@property
def model_format(self):
if self._data.get('tfliteModel'):
return TFLiteFormat(self._data.get('tfliteModel'))
return None

@model_format.setter
def model_format(self, model_format):
if not isinstance(model_format, TFLiteFormat):
raise TypeError('Unsupported model format type.')
self._data['tfliteModel'] = model_format.get_json()
return self

def get_json(self):
return self._data


class ModelFormat(object):
"""Abstract base class representing a Model Format such as TFLite."""
def get_json(self):
raise NotImplementedError


class TFLiteFormat(ModelFormat):
"""Model format representing a TFLite model."""
def __init__(self, data=None, model_source=None):
if (data is not None) and isinstance(data, dict):
self._data = data
else:
self._data = {}
if model_source is not None:
# Check for correct base type
if not isinstance(model_source, TFLiteModelSource):
raise TypeError('Model source must be a ModelSource object.')
# Set based on specific sub type
if isinstance(model_source, TFLiteGCSModelSource):
self._data['gcsTfliteUri'] = model_source.get_json()
else:
raise TypeError('Unsupported model source type.')


def __eq__(self, other):
if isinstance(other, self.__class__):
return self._data == other._data # pylint: disable=protected-access
else:
return False

def __ne__(self, other):
return not self.__eq__(other)

@property
def model_source(self):
if self._data.get('gcsTfliteUri'):
return TFLiteGCSModelSource(self._data.get('gcsTfliteUri'))
return None

@model_source.setter
def model_source(self, model_source):
if model_source is not None:
if isinstance(model_source, TFLiteGCSModelSource):
self._data['gcsTfliteUri'] = model_source.get_json()
else:
raise TypeError('Unsupported model source type.')


@property
def size_bytes(self):
return self._data.get('sizeBytes')

def get_json(self):
return self._data


class TFLiteModelSource(object):
"""Abstract base class representing a model source for TFLite format models."""
def get_json(self):
raise NotImplementedError


class TFLiteGCSModelSource(TFLiteModelSource):
"""TFLite model source representing a tflite model file stored in GCS."""
def __init__(self, gcs_tflite_uri):
_validate_gcs_tflite_uri(gcs_tflite_uri)
self._gcs_tflite_uri = gcs_tflite_uri

#TODO(ifielker): define the rest of the Model properties etc
def __eq__(self, other):
if isinstance(other, self.__class__):
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
else:
return False

def __ne__(self, other):
return not self.__eq__(other)

@property
def gcs_tflite_uri(self):
return self._gcs_tflite_uri

@gcs_tflite_uri.setter
def gcs_tflite_uri(self, gcs_tflite_uri):
_validate_gcs_tflite_uri(gcs_tflite_uri)
self._gcs_tflite_uri = gcs_tflite_uri

def get_json(self):
return self._gcs_tflite_uri

#TODO(ifielker): implement from_saved_model etc.

class ListModelsPage(object):
"""Represents a page of models in a firebase project.
Expand Down Expand Up @@ -179,13 +363,55 @@ def __iter__(self):
return self


def _validate_and_parse_name(name):
# The resource name is added automatically from API call responses.
# The only way it could be invalid is if someone tries to
# create a model from a dictionary manually and does it incorrectly.
if not isinstance(name, six.string_types):
raise TypeError('Model resource name must be a string.')
matcher = re.match(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$',
name)
if not matcher:
raise ValueError('Model resource name format is invalid.')
return matcher.group('project_id'), matcher.group('model_id')


def _validate_model_id(model_id):
if not isinstance(model_id, six.string_types):
raise TypeError('Model ID must be a string.')
if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id):
raise ValueError('Model ID format is invalid.')


def _validate_display_name(display_name):
if not isinstance(display_name, six.string_types):
raise TypeError('Display name must be a string.')
if not re.match(r'^[A-Za-z0-9_-]{1,60}$', display_name):
raise ValueError('Display name format is invalid.')


def _validate_tags(tags):
if not isinstance(tags, list) or not \
all(isinstance(tag, six.string_types) for tag in tags):
raise TypeError('Tags must be a list of strings.')
if not all(re.match(r'^[A-Za-z0-9_-]{1,60}$', tag) for tag in tags):
raise ValueError('Tag format is invalid.')


def _validate_gcs_tflite_uri(uri):
if not isinstance(uri, six.string_types):
raise TypeError('Gcs TFLite URI must be a string.')
# GCS Bucket naming rules are complex. The regex is not comprehensive.
# See https://cloud.google.com/storage/docs/naming for full details.
if not re.match(r'^gs://[a-z0-9_.-]{3,63}/.+', uri):
raise ValueError('GCS TFLite URI format is invalid.')

def _validate_model_format(model_format):
if model_format is not None:
if not isinstance(model_format, ModelFormat):
raise TypeError('Model format must be a ModelFormat object.')

def _validate_list_filter(list_filter):
if list_filter is not None:
if not isinstance(list_filter, six.string_types):
Expand Down
Loading