-
Notifications
You must be signed in to change notification settings - Fork 339
Firebase ML Kit Get Model API implementation #326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
723ef2d
976bf1e
b47e70b
95e7a7a
aa263ac
757cbc8
b688250
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,8 +18,91 @@ | |
deleting, publishing and unpublishing Firebase ML Kit models. | ||
""" | ||
|
||
import re | ||
import requests | ||
import six | ||
|
||
import firebase_admin | ||
from firebase_admin import _http_client | ||
from firebase_admin import _utils | ||
from firebase_admin import exceptions | ||
|
||
_MLKIT_ATTRIBUTE = '_mlkit' | ||
|
||
def _get_mlkit_service(app): | ||
""" Returns an _MLKitService instance for an App. | ||
|
||
Args: | ||
app: A Firebase App instance (or None to use the default App). | ||
|
||
Returns: | ||
_MLKitService: An _MLKitService for the specified App instance. | ||
|
||
Raises: | ||
ValueError: If the app argument is invalid. | ||
""" | ||
return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService) | ||
|
||
def get_model(model_id, app=None): | ||
mlkit_service = _get_mlkit_service(app) | ||
return Model(mlkit_service.get_model(model_id)) | ||
|
||
class Model(object): | ||
"""A Firebase ML Kit Model object.""" | ||
def __init__(self, data): | ||
"""Created from a data dictionary.""" | ||
self._data = data | ||
|
||
def __eq__(self, other): | ||
if isinstance(other, self.__class__): | ||
return self._data == other._data | ||
else: | ||
return False | ||
|
||
def __ne__(self, other): | ||
return not self.__eq__(other) | ||
|
||
#TODO(ifielker): define the Model properties etc | ||
|
||
class _MLKitService(object): | ||
"""Firebase MLKit service.""" | ||
|
||
BASE_URL = 'https://mlkit.googleapis.com' | ||
PROJECT_URL = 'https://mlkit.googleapis.com/projects/{0}/' | ||
BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't seem to be used in this module. Perhaps move to the tests module, where it's being referenced? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' | ||
|
||
def __init__(self, app): | ||
project_id = app.project_id | ||
if not project_id: | ||
raise ValueError( | ||
'Project ID is required to access MLKit service. Either set the ' | ||
'projectId option, or use service account credentials.') | ||
self._project_url = _MLKitService.PROJECT_URL.format(project_id) | ||
self._client = _http_client.JsonHttpClient(credential=app.credential.get_credential()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also pass There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
def _request(self, method, urlpath, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to be just appending the path to the base/project URL. I'd suggest passing the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
"""Makes an HTTP call using the Python requests library. | ||
|
||
Args: | ||
method: HTTP method name as a string (e.g. get, post, patch, delete). | ||
urlpath: URL path to the endpoint. This will be appended to the | ||
server's base project URL. | ||
kwargs: An additional set of keyword arguments to be passed into requests | ||
API (e.g. json, params) | ||
|
||
Returns: | ||
dict: The parsed JSON response. | ||
""" | ||
return self._client.body(method, url=self._project_url + urlpath, **kwargs) | ||
|
||
def get_model(self, model_id): | ||
if not model_id: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be removed. If model_id is None, your type check will catch it. If it's empty your regex check will catch it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
raise ValueError('Model Id is required for GetModel.') | ||
if not isinstance(model_id, six.string_types): | ||
raise TypeError('Model Id must be a string.') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/Id/ID There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): | ||
raise ValueError('Model Id format is invalid.') | ||
try: | ||
return self._request('get', 'models/{0}'.format(model_id)) | ||
except requests.exceptions.RequestException as error: | ||
raise _utils.handle_requests_error(error) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is a OnePlatform API, it's better to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# Copyright 2019 Google Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Test cases for the firebase_admin.mlkit module.""" | ||
|
||
import json | ||
import pytest | ||
import six | ||
|
||
import firebase_admin | ||
from firebase_admin import exceptions | ||
from firebase_admin import mlkit | ||
from tests import testutils | ||
|
||
PROJECT_ID = 'myProject1' | ||
MODEL_ID_1 = 'modelId1' | ||
MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) | ||
DISPLAY_NAME_1 = 'displayName1' | ||
MODEL_JSON_1 = { | ||
'name': MODEL_NAME_1, | ||
'displayName': DISPLAY_NAME_1 | ||
} | ||
MODEL_1 = mlkit.Model(MODEL_JSON_1) | ||
_DEFAULT_RESPONSE = json.dumps(MODEL_JSON_1) | ||
|
||
ERROR_CODE = 404 | ||
ERROR_MSG = 'The resource was not found' | ||
ERROR_STATUS = 'NOT_FOUND' | ||
ERROR_JSON = { | ||
'error': { | ||
'code': ERROR_CODE, | ||
'message': ERROR_MSG, | ||
'status': ERROR_STATUS | ||
} | ||
} | ||
_ERROR_RESPONSE = json.dumps(ERROR_JSON) | ||
|
||
|
||
class TestGetModel(object): | ||
"""Tests mlkit.get_model.""" | ||
@classmethod | ||
def setup_class(cls): | ||
cred = testutils.MockCredential() | ||
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) | ||
|
||
@classmethod | ||
def teardown_class(cls): | ||
testutils.cleanup_apps() | ||
|
||
@staticmethod | ||
def check_error(err, errType, msg): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/errType/err_type There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
assert isinstance(err, errType) | ||
assert str(err) == msg | ||
|
||
@staticmethod | ||
def check_firebase_error(err, code, status): | ||
assert isinstance(err, exceptions.FirebaseError) | ||
assert err.code == code | ||
assert err.http_response is not None | ||
assert err.http_response.status_code == status | ||
|
||
def _get_url(self, project_id, model_id): | ||
return mlkit._MLKitService.BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) | ||
|
||
def _instrument_mlkit_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): | ||
if not app: | ||
app = firebase_admin.get_app() | ||
mlkit_service = mlkit._get_mlkit_service(app) | ||
recorder = [] | ||
mlkit_service._client.session.mount( | ||
'https://mlkit.googleapis.com', | ||
testutils.MockAdapter(payload, status, recorder) | ||
) | ||
return mlkit_service, recorder | ||
|
||
def test_get_model(self): | ||
_, recorder = self._instrument_mlkit_service() | ||
model = mlkit.get_model(MODEL_ID_1) | ||
assert len(recorder) == 1 | ||
assert recorder[0].method == 'GET' | ||
assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1) | ||
assert model == MODEL_1 | ||
assert model._data['name'] == MODEL_NAME_1 | ||
assert model._data['displayName'] == DISPLAY_NAME_1 | ||
|
||
def test_get_model_validation_errors(self): | ||
_, recorder = self._instrument_mlkit_service() | ||
#Empty model-id | ||
with pytest.raises(ValueError) as err: | ||
mlkit.get_model('') | ||
self.check_error(err.value, ValueError, 'Model Id is required for GetModel.') | ||
|
||
#Wrong type | ||
with pytest.raises(TypeError) as err: | ||
mlkit.get_model(12345) | ||
self.check_error(err.value, TypeError, 'Model Id must be a string.') | ||
|
||
#Invalid characters | ||
with pytest.raises(ValueError) as err: | ||
mlkit.get_model('&_*#@:/?') | ||
self.check_error(err.value, ValueError, 'Model Id format is invalid.') | ||
|
||
def test_get_model_error(self): | ||
_, recorder = self._instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) | ||
with pytest.raises(exceptions.NotFoundError) as err: | ||
mlkit.get_model(MODEL_ID_1) | ||
self.check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also assert the error message. If the OP error got parsed correctly, error message gets set to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
assert len(recorder) == 1 | ||
assert recorder[0].method == 'GET' | ||
assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1) | ||
|
||
def test_no_project_id(self): | ||
def evaluate(): | ||
app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') | ||
with pytest.raises(ValueError): | ||
mlkit.get_model(MODEL_ID_1, app) | ||
testutils.run_without_project_id(evaluate) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove redundant empty lines. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.