Skip to content

Firebase ML Kit Create Model API implementation #337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Sep 11, 2019
Merged
21 changes: 21 additions & 0 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,27 @@ def handle_platform_error_from_requests(error, handle_func=None):
return exc if exc else _handle_func_requests(error, message, error_dict)


def handle_operation_error(error):
"""Constructs a ``FirebaseError`` from the given operation error.

Args:
error: An error returned by a long running operation.

Returns:
FirebaseError: A ``FirebaseError`` that can be raised to the user code.
"""
if not isinstance(error, dict):
return exceptions.UnknownError(
message='Unknown error while making a remote service call: {0}'.format(error),
cause=error)

status_code = error.get('code')
message = error.get('message')
error_code = _http_status_to_error_code(status_code)
err_type = _error_code_to_exception_type(error_code)
return err_type(message=message)


def _handle_func_requests(error, message, error_dict):
"""Constructs a ``FirebaseError`` from the given GCP error.

Expand Down
119 changes: 119 additions & 0 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
import datetime
import numbers
import re
import time
import requests
import six


from firebase_admin import _http_client
from firebase_admin import _utils
from firebase_admin import exceptions


_MLKIT_ATTRIBUTE = '_mlkit'
Expand All @@ -36,6 +39,8 @@
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
r'^operations/project/[^/]+/model/[A-Za-z0-9_-]{1,60}/operation/[^/]+$')


def _get_mlkit_service(app):
Expand All @@ -53,18 +58,60 @@ def _get_mlkit_service(app):
return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService)


def create_model(model, app=None):
"""Creates a model in Firebase ML Kit.

Args:
model: An mlkit.Model to create.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The model that was created in Firebase ML Kit.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.create_model(model))


def get_model(model_id, app=None):
"""Gets a model from Firebase ML Kit.

Args:
model_id: The id of the model to get.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The requested model.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.get_model(model_id))


def list_models(list_filter=None, page_size=None, page_token=None, app=None):
"""Lists models from Firebase ML Kit.

Args:
list_filter: a list filter string such as "tags:'tag_1'". None will return all models.
page_size: A number between 1 and 100 inclusive that specifies the maximum
number of models to return per page. None for default.
page_token: A next page token returned from a previous page of results. None
for first page of results.
app: A Firebase app instance (or None to use the default app).

Returns:
ListModelsPage: A (filtered) list of models.
"""
mlkit_service = _get_mlkit_service(app)
return ListModelsPage(
mlkit_service.list_models, list_filter, page_size, page_token)


def delete_model(model_id, app=None):
"""Deletes a model from Firebase ML Kit.

Args:
model_id: The id of the model you wish to delete.
app: A Firebase app instance (or None to use the default app).
"""
mlkit_service = _get_mlkit_service(app)
mlkit_service.delete_model(model_id)

Expand Down Expand Up @@ -390,11 +437,23 @@ def _validate_and_parse_name(name):
return matcher.group('project_id'), matcher.group('model_id')


def _validate_model(model):
if not isinstance(model, Model):
raise TypeError('Model must be an mlkit.Model.')
if not model.display_name:
raise ValueError('Model must have a display name.')


def _validate_model_id(model_id):
if not _MODEL_ID_PATTERN.match(model_id):
raise ValueError('Model ID format is invalid.')


def _validate_operation_name(op_name):
if not _OPERATION_NAME_PATTERN.match(op_name):
raise ValueError('Operation name format is invalid.')


def _validate_display_name(display_name):
if not _DISPLAY_NAME_PATTERN.match(display_name):
raise ValueError('Display name format is invalid.')
Expand Down Expand Up @@ -448,6 +507,11 @@ class _MLKitService(object):
"""Firebase MLKit service."""

PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'
OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/'
OPERATION_POLL_DELAY_SECONDS = 30
MAX_POLLING_ATTEMPTS = 10
POLL_EXPONENTIAL_BACKOFF_FACTOR = 2
POLL_BASE_WAIT_TIME_SECONDS = 1

def __init__(self, app):
project_id = app.project_id
Expand All @@ -459,6 +523,61 @@ def __init__(self, app):
self._client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=self._project_url)
self._operation_client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=_MLKitService.OPERATION_URL)

def get_operation(self, op_name):
_validate_operation_name(op_name)
try:
return self._operation_client.body('get', url=op_name)
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def handle_operation(self, operation):
"""Handles long running operations.

Args:
operation: The operation to handle.

Returns:
dict: A dictionary of the returned model properties.

Raises:
TypeError: if the operation is not a dictionary.
ValueError: If the operation is malformed.
"""
if not isinstance(operation, dict):
raise TypeError('Operation must be a dictionary.')
op_name = operation.get('name')
_validate_operation_name(op_name)

for current_attempt in range(_MLKitService.MAX_POLLING_ATTEMPTS):
if operation.get('done'):
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))
else:
# A 'done' operation must have either a response or an error.
raise ValueError('Operation is malformed.')
else:
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
delay_factor = pow(
_MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt)
wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS
time.sleep(wait_time_seconds)
operation = self.get_operation(op_name)
raise exceptions.DeadlineExceededError('Polling deadline exceeded.')

def create_model(self, model):
_validate_model(model)
try:
return self.handle_operation(
self._client.body('post', url='models', json=model.as_dict()))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def get_model(self, model_id):
_validate_model_id(model_id)
Expand Down
Loading