Skip to content

Modify Operation Handling to not require a name for Done Operations #371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 10, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,28 +810,36 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
"""
if not isinstance(operation, dict):
raise TypeError('Operation must be a dictionary.')
op_name = operation.get('name')
_, model_id = _validate_and_parse_operation_name(op_name)

current_attempt = 0
start_time = datetime.datetime.now()
stop_time = (None if max_time_seconds is None else
start_time + datetime.timedelta(seconds=max_time_seconds))
while wait_for_operation and not operation.get('done'):
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
self._exponential_backoff(current_attempt, stop_time)
operation = self.get_operation(op_name)
current_attempt += 1

if operation.get('done'):
# Operations which are immediately done don't have an operation name
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))

# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()
raise exceptions.UnknownError(message='Internal Error: Malformed Operation.')
else:
op_name = operation.get('name')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Consider moving this else block logic into a helper function. _poll_until_complete(operation) or similar.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that might be misleading. We only do the polling if wait_for_operation is true. So we could either only call the _poll_until_complete when wait_for_operation is true and have it only be the while loop and then duplicate the done checking below the while loop in both handle_operation and _poll_until_complete or we could name it something else like _poll_until_complete_if_waiting or something? I actually prefer the way it is now, but I'm open to other suggestions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be _poll_until_complete(operation, wait_for_operation)?

I'm fine with how it is implemented now. Just thought the else block could be a bit smaller for clarity. Your call.

_, model_id = _validate_and_parse_operation_name(op_name)
current_attempt = 0
start_time = datetime.datetime.now()
stop_time = (None if max_time_seconds is None else
start_time + datetime.timedelta(seconds=max_time_seconds))
while wait_for_operation and not operation.get('done'):
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
self._exponential_backoff(current_attempt, stop_time)
operation = self.get_operation(op_name)
current_attempt += 1

if operation.get('done'):
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))

# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()


def create_model(self, model):
Expand Down
49 changes: 13 additions & 36 deletions tests/test_mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,23 +158,21 @@
}

OPERATION_DONE_MODEL_JSON_1 = {
'name': OPERATION_NAME_1,
'done': True,
'response': CREATED_UPDATED_MODEL_JSON_1
}
OPERATION_MALFORMED_JSON_1 = {
'name': OPERATION_NAME_1,
'done': True,
# if done is true then either response or error should be populated
}
OPERATION_MISSING_NAME = {
# Name is required if the operation is not done.
'done': False
}
OPERATION_ERROR_CODE = 400
OPERATION_ERROR_MSG = "Invalid argument"
OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT'
OPERATION_ERROR_JSON_1 = {
'name': OPERATION_NAME_1,
'done': True,
'error': {
'code': OPERATION_ERROR_CODE,
Expand Down Expand Up @@ -609,17 +607,10 @@ def test_operation_error(self):
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)

def test_malformed_operation(self):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = mlkit.create_model(MODEL_1)
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'POST'
assert recorder[0].url == TestCreateModel._url(PROJECT_ID)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1)
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
with pytest.raises(Exception) as excinfo:
mlkit.create_model(MODEL_1)
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')

def test_rpc_error_create(self):
create_recorder = instrument_mlkit_service(
Expand Down Expand Up @@ -708,17 +699,10 @@ def test_operation_error(self):
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)

def test_malformed_operation(self):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = mlkit.update_model(MODEL_1)
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')

def test_rpc_error(self):
create_recorder = instrument_mlkit_service(
Expand Down Expand Up @@ -824,17 +808,10 @@ def test_operation_error(self, publish_function):

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_malformed_operation(self, publish_function):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = publish_function(MODEL_ID_1)
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
with pytest.raises(Exception) as excinfo:
publish_function(MODEL_ID_1)
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_rpc_error(self, publish_function):
Expand Down