Add samples for the Cloud ML Engine #824


Merged: 16 commits, Feb 28, 2017
ml_engine/online_prediction/predict.py (173 additions, 0 deletions)
@@ -0,0 +1,173 @@
# Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache
Contributor:
This license header looks weird, copy it from elsewhere?

s/2016/2017/g

Contributor:
Needs a shebang

# License, Version 2.0 (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Examples of using the Cloud ML Engine's online prediction service."""

Add comments on authentication: when should this work, or what needs to be true for the script to work?

Contributor Author:
Done

In get_ml_engine_service, can you add a link to the doc page describing how to download a service account file?

Contributor:
Nit: blank line between license and docstring.

from __future__ import print_function
Contributor:
This isn't necessary.

# [START import_libraries]
import googleapiclient.discovery
# [END import_libraries]


# [START authenticating]
Contributor:
We generally show constructing the service in each snippet instead of centralizing it. Every indirection adds cognitive load to the users.

def get_ml_engine_service():
    """Create the ML Engine service object.
    To authenticate set the environment variable
    GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
    """
    return googleapiclient.discovery.build('ml', 'v1beta1')
# [END authenticating]
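A minimal sketch of the authentication setup the docstring describes (the key path is hypothetical; you would download a JSON key for a service account from the Cloud Console):

```python
import os

# Hypothetical path to a service-account key file you downloaded.
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/service_account.json'

# With the variable set, googleapiclient.discovery.build('ml', 'v1beta1')
# picks up the credentials automatically via Application Default Credentials.
print(os.environ['GOOGLE_APPLICATION_CREDENTIALS'])
```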


# [START predict_json]
def predict_json(project, model, instances, version=None):
    """Send data instances to a deployed model for prediction
Contributor:
If it's a sentence then you should put a full stop on it. :)

    Args:
Contributor:
blank newline above here.

        project: str, project where the Cloud ML Engine Model is deployed.
        model: str, model name.
        instances: [dict], dictionaries from string keys defined by the model
            to data.
        version: [optional] str, version of the model to target.
    Returns:
        A dictionary of prediction results defined by the model.
Contributor:
We generally encourage snippets to be simple enough not to require this, but I understand if that's not reasonable here. If you're going to go full docstring, follow Napoleon style:

Args:
    project (str): ...
    model (str): ...
    instances (Mapping[ str, dict ]): ...
    version (str): optional ...

Returns:
   Mapping [str, ...] : ...

"""
service = get_ml_engine_service()
name = 'projects/{}/models/{}'.format(project, model)
if version is not None:
name += '/versions/{}'.format(version)

response = service.projects().predict(
name=name,
body={'instances': instances}
).execute()

if 'error' in response:
raise RuntimeError(response['error'])

return response['predictions']
# [END predict_json]
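The resource path built above is plain string formatting; a standalone sketch of how the name resolves (project and model names here are placeholders):

```python
def model_resource_name(project, model, version=None):
    """Build the ML Engine resource path passed to projects().predict()."""
    name = 'projects/{}/models/{}'.format(project, model)
    if version is not None:
        name += '/versions/{}'.format(version)
    return name

print(model_resource_name('my-project', 'census'))
# projects/my-project/models/census
print(model_resource_name('my-project', 'census', version='v1'))
# projects/my-project/models/census/versions/v1
```

When no version is given, the service routes the request to the model's default version.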


# [START predict_tf_records]
def predict_tf_records(project,
                       model,
                       example_bytes_list,
                       key='tfrecord',
                       version=None):
    """Send data instances to a deployed model for prediction
    Args:
        project: str, project where the Cloud ML Engine Model is deployed
        model: str, model name.
        example_bytes_list: [str], Serialized tf.train.Example protos.
        version: str, version of the model to target.
    Returns:
        A dictionary of prediction results defined by the model.
    """
    import base64
Contributor:
Don't import here, import at the top.

Contributor Author:

How do you highlight that this import is only necessary for this snippet?

Is that not important?

    service = get_ml_engine_service()
    name = 'projects/{}/models/{}'.format(project, model)
    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': [
            {key: {'b64': base64.b64encode(example_bytes)}}
            for example_bytes in example_bytes_list
        ]}
    ).execute()
    if 'error' in response:
Contributor:
Blank new line to separate control statements.

        raise RuntimeError(response['error'])

    return response['predictions']
# [END predict_tf_records]
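The body construction above base64-wraps each serialized proto under a `{'b64': ...}` key, which is how binary payloads are marked in the JSON request. A standalone sketch of just that step (the helper name is hypothetical; note that on Python 3 the base64 output must be decoded to text before JSON serialization, whereas the PR targets Python 2):

```python
import base64


def tfrecord_instances(example_bytes_list, key='tfrecord'):
    """Wrap serialized tf.train.Example bytes for the JSON request body."""
    return [
        {key: {'b64': base64.b64encode(example_bytes).decode('ascii')}}
        for example_bytes in example_bytes_list
    ]

print(tfrecord_instances([b'abc']))
# [{'tfrecord': {'b64': 'YWJj'}}]
```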


# [START census_to_example_bytes]
def census_to_example_bytes(json_instance):


I was expecting the file path to be in the census example (in cloudml-samples). Is this file part of the census sample or a more generic 'calling online prediction' sample? If the latter, we need better warnings that this will not work with every model, and we need to describe what the model is expecting.

If this is not part of the census sample, a s/census/json/g is needed.

Sorry if this is a bad question; I'm not familiar with python-docs-samples.

Contributor Author (@elibixby, Feb 27, 2017):

I was thinking we'd have a fairly hard separation between "things run as part of training" and "code run in your own client to send requests to the prediction service": the former lives in cloudml-samples (and in the future tf/garden), and the latter in python-docs-samples.

I will definitely add some better prose around this in the form of a docstring, and we'll also make it clear in docs.

"""Serialize a JSON example to the bytes of a tf.train.Example.
This method is specific to the signature of the Census example.
Args:
json_instance: dict, representing data to be serialized.
Returns:
A string (as a container for bytes).
"""
import tensorflow as tf
feature_dict = {}
for key, data in json_instance.iteritems():
Contributor:
I don't understand why this is necessary?

Contributor Author:
We need to construct a dict of protos because it's expected by yet another proto; this is just protos being ugly.

        if isinstance(data, str) or isinstance(data, unicode):
Contributor:
Use isinstance(data, (str, unicode)), but this will fail on Python 3 because unicode doesn't exist. You should consider using six.string_types or making this the else clause.

Contributor Author:
Is six installed by default in the testing env, or should I add it to requirements.txt?
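A portable variant of the check under discussion that avoids both the `unicode` name on Python 3 and the six dependency (a sketch only, not necessarily what the sample should use):

```python
try:
    string_types = (str, unicode)  # Python 2: byte strings and unicode strings
except NameError:
    string_types = (str,)  # Python 3: unicode is gone; str covers all text


def is_text(data):
    """Return True for values that should become a BytesList feature."""
    return isinstance(data, string_types)

print(is_text('hello'), is_text(42))
```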

            feature_dict[key] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[str(data)]))
        elif isinstance(data, int) or isinstance(data, float):
Contributor:
elif isinstance(data, (int, float)).

Contributor Author:
cool!

Contributor Author:
Actually I just realized this is coded wrong >_<

            feature_dict[key] = tf.train.Feature(
                float_list=tf.train.FloatList(value=[data]))
    return tf.train.Example(
        features=tf.train.Features(
            feature=feature_dict
        )
    ).SerializeToString()
# [END census_to_example_bytes]


def main(project, model, version=None, force_tfrecord=False):
Contributor:
I'm not a huge fan of main functions. Will this be included in the docs or not?

Contributor Author:
Not at all in docs

"""Send user input to the prediction service."""
import json
Contributor:
Don't import here.

    while True:
        try:
            user_input = json.loads(raw_input("Valid JSON >>>"))

I'm not a fan of raw_input (you cannot re-run this quickly, and typing valid JSON is a pain). But it does allow both interactive input and "python predict.py < my_data.json". Maybe add file-level comments on these two ways of using this script?

Contributor Author:
So the main reason I wanted to do it this way is that we already have a solution for batch prediction (via the API), and a 100-request limit seems really bad if the use case we are highlighting is predicting from files.
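A sketch of the dual interactive/piped pattern being discussed, with one JSON document per line accepted from any stream, so both prompt-driven use and `python predict.py < my_data.json` are covered (the function name and normalization behavior are hypothetical, not part of the PR):

```python
import io
import json


def read_instances(stream):
    """Yield lists of instances, one JSON document per non-empty line."""
    for line in stream:
        line = line.strip()
        if not line:
            continue
        parsed = json.loads(line)
        # Normalize single objects to a one-element batch.
        yield parsed if isinstance(parsed, list) else [parsed]

# In the script the stream would be sys.stdin; io.StringIO stands in here.
batches = list(read_instances(io.StringIO('{"age": 25}\n[{"a": 1}, {"a": 2}]\n')))
print(batches)
```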

Contributor:
Where do the users find out what kind of json to send here?

Contributor Author:
It depends on their model. This snippet will be part of a docs page that attempts to explain just that; it will come at the end: "now that you know what the prediction service does, here's how you call it".

        except KeyboardInterrupt:
            return

        if not isinstance(user_input, list):
            user_input = [user_input]
        try:
            if force_tfrecord:
                example_bytes_list = [
                    census_to_example_bytes(e)
                    for e in user_input
                ]
                result = predict_tf_records(
                    project, model, example_bytes_list, version=version)
            else:
                result = predict_json(
                    project, model, user_input, version=version)
        except RuntimeError as err:
            print(str(err))
        else:
            print(result)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--project',
        help='Project in which the model is deployed',
        type=str,
        required=True
    )
    parser.add_argument(
        '--model',
        help='Model name',
        type=str,
        required=True
    )
    parser.add_argument(
        '--version',
        help='Name of the version.',
        type=str
    )
    parser.add_argument(
        '--force-tfrecord',
        help='Send predictions as TFRecords rather than raw JSON',
        action='store_true',
        default=False
    )
    args = parser.parse_args()
    main(**args.__dict__)
Contributor:
Please explicitly pass args instead of using **args.__dict__.
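What the explicit form might look like (the parsed command line below is a placeholder; in the script you would call parser.parse_args() with no arguments):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--project', required=True)
parser.add_argument('--model', required=True)
parser.add_argument('--version')
parser.add_argument('--force-tfrecord', action='store_true')

# Placeholder command line; argparse maps --force-tfrecord to force_tfrecord.
args = parser.parse_args(['--project', 'my-project', '--model', 'census'])

# Explicit keyword passing instead of main(**args.__dict__):
#     main(args.project, args.model, version=args.version,
#          force_tfrecord=args.force_tfrecord)
print(args.project, args.model, args.version, args.force_tfrecord)
```

Explicit passing keeps the call readable and breaks loudly if a flag is renamed.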

ml_engine/online_prediction/predict_test.py (64 additions, 0 deletions)
@@ -0,0 +1,64 @@
# Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache
Contributor:
2017, also, these headers still seem different from the ones in the rest of the repo.

# License, Version 2.0 (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for predict.py ."""
Contributor:
blank newline both above and below this.

import base64

import pytest

from predict import census_to_example_bytes, predict_json
Contributor:
just import predict, please don't import individual members.



MODEL = 'census'


I'm starting to think you want these files in GoogleCloudPlatform/cloudml-samples/census/something

VERSION = 'v1'
PROJECT = 'python-docs-samples-tests'
JSON = {
    'age': 25,
    'workclass': ' Private',
    'education': ' 11th',
    'education_num': 7,
    'marital_status': ' Never-married',
    'occupation': ' Machine-op-inspct',
    'relationship': ' Own-child',
    'race': ' Black',
    'gender': ' Male',
    'capital_gain': 0,
    'capital_loss': 0,
    'hours_per_week': 40,
    'native_country': ' United-States'
}
EXPECTED_OUTPUT = {
    u'probabilities': [0.9942260384559631, 0.005774002522230148],
    u'logits': [-5.148599147796631],
    u'classes': 0,
    u'logistic': [0.005774001590907574]
}


def test_predict_json():
    result = predict_json(PROJECT, MODEL, [JSON, JSON], version=VERSION)
    assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result


def test_predict_json_error():
    with pytest.raises(RuntimeError):
        predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)


def test_census_example_to_bytes():
    b = census_to_example_bytes(JSON)
    assert base64.b64encode(b) is not None


def test_predict_tfrecord():
Contributor:
Why not write a real test and mark it with pytest.mark.xfail(reason='...')?
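What the suggestion above might look like (the reason text is illustrative; the real assertion would exercise predict_tf_records once the bug is fixed):

```python
import pytest


@pytest.mark.xfail(reason='b/35742966: serving TFRecords and JSON from the '
                          'same model is currently broken')
def test_predict_tfrecord():
    # A real end-to-end assertion would go here once the bug is fixed;
    # until then, pytest reports this as an expected failure.
    assert False
```

Unlike a `pass` body, an xfail-marked test starts passing visibly (XPASS) as soon as the bug is fixed.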

    # Using the same model for TFRecords and
    # JSON is currently broken.
    # TODO(elibixby) when b/35742966 is fixed add
    pass
ml_engine/online_prediction/requirements.txt (1 addition, 0 deletions)
@@ -0,0 +1 @@
tensorflow>=1.0.0
Contributor:
Don't use ranges, pin the version and dpebot will handle updating it.
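With that comment applied, the file would pin an exact release, e.g. (pinning the version the range starts at):

```
tensorflow==1.0.0
```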