Add samples for the Cloud ML Engine #824
@@ -0,0 +1,173 @@
# Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache

Review comment: Needs a shebang.

# License, Version 2.0 (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Examples of using the Cloud ML Engine's online prediction service."""

Review comment: Add comments on authentication: when should this work, or what needs to be true for the script to work?
Reply: Done.
Review comment: In get_ml_engine_service, can you add a link to the doc page describing how to download a service account file?
Review comment: Nit: blank line between the license header and the docstring.

from __future__ import print_function

Review comment: This isn't necessary.

# [START import_libraries]
import googleapiclient.discovery
# [END import_libraries]


# [START authenticating]

Review comment: We generally show constructing the service in each snippet instead of centralizing it. Every indirection adds cognitive load for users.

def get_ml_engine_service():
    """Create the ML Engine service object.
    To authenticate set the environment variable
    GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
    """
    return googleapiclient.discovery.build('ml', 'v1beta1')
# [END authenticating]
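For reference, the inlined form the reviewer describes might look roughly like the sketch below. The project and model names are placeholders, the instance payload depends entirely on the deployed model, and this is not part of the PR diff itself.

# Sketch only: build the client directly inside the snippet instead of going
# through get_ml_engine_service(). Assumes GOOGLE_APPLICATION_CREDENTIALS
# points at a valid service account file.
import googleapiclient.discovery

service = googleapiclient.discovery.build('ml', 'v1beta1')
name = 'projects/{}/models/{}'.format('example-project', 'example-model')
response = service.projects().predict(
    name=name,
    body={'instances': [{'x': 1.0}]}  # placeholder instance
).execute()
print(response)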


# [START predict_json]
def predict_json(project, model, instances, version=None):
    """Send data instances to a deployed model for prediction

Review comment: If it's a sentence then you should put a full stop on it. :)

    Args:

Review comment: Blank newline above here.

        project: str, project where the Cloud ML Engine Model is deployed.
        model: str, model name.
        instances: [dict], dictionaries from string keys defined by the model
            to data.
        version: [optional] str, version of the model to target.
    Returns:
        A dictionary of prediction results defined by the model.

Review comment: We generally encourage snippets to be simple enough not to require this, but I understand if that's not reasonable here. If you're going to go full docstring, follow Napoleon style.

    """
    service = get_ml_engine_service()
    name = 'projects/{}/models/{}'.format(project, model)
    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': instances}
    ).execute()

    if 'error' in response:
        raise RuntimeError(response['error'])

    return response['predictions']
# [END predict_json]
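A hypothetical call to the snippet above, reusing an abbreviated version of the census record from the tests further down. The project name is a placeholder, and the instance fields must match whatever the deployed model's signature actually expects.

# Sketch only: project, model, and version values are illustrative.
instances = [{'age': 25, 'workclass': ' Private', 'education': ' 11th'}]
predictions = predict_json('example-project', 'census', instances, version='v1')
print(predictions)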


# [START predict_tf_records]
def predict_tf_records(project,
                       model,
                       example_bytes_list,
                       key='tfrecord',
                       version=None):
    """Send data instances to a deployed model for prediction
    Args:
        project: str, project where the Cloud ML Engine Model is deployed
        model: str, model name.
        example_bytes_list: [str], Serialized tf.train.Example protos.
        version: str, version of the model to target.
    Returns:
        A dictionary of prediction results defined by the model.
    """
    import base64

Review comment: Don't import here, import at the top.
Reply: How do you highlight that this import is only necessary for this snippet? Is that not important?

    service = get_ml_engine_service()
    name = 'projects/{}/models/{}'.format(project, model)
    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': [
            {key: {'b64': base64.b64encode(example_bytes)}}
            for example_bytes in example_bytes_list
        ]}
    ).execute()
    if 'error' in response:

Review comment: Blank new line to separate control statements.

        raise RuntimeError(response['error'])

    return response['predictions']
# [END predict_tf_records]
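For clarity, this is roughly the request body the snippet above builds for a single serialized example: binary payloads are carried in the JSON request inside a 'b64' wrapper. The .decode call below is an addition for Python 3, where base64.b64encode returns bytes; the sample as written targets Python 2.

# Sketch of the wire format; example_bytes stands in for a serialized
# tf.train.Example proto.
import base64

example_bytes = b'...'
body = {
    'instances': [
        {'tfrecord': {'b64': base64.b64encode(example_bytes).decode('utf-8')}}
    ]
}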


# [START census_to_example_bytes]
def census_to_example_bytes(json_instance):

Review comment: I was expecting the file path to be in the census example (in cloudml-samples). Is this file part of the census sample or a more generic "calling online prediction" sample? If the latter, we need better warnings that this will not work with every model, and we need to describe what the model is expecting. If this is not part of the census sample, a s/census/json/g is needed. Sorry if this is a bad question; I'm not familiar with python-docs-samples.
Reply: I was thinking we have sort of a hard separation between "things run as part of training" and "code run in your own client to send requests to the prediction service", the former being in cloudml-samples (and in the future tf/garden) and the latter in python-docs-samples. I will definitely add some better prose around this in the form of a docstring, and we'll also make it clear in the docs.

    """Serialize a JSON example to the bytes of a tf.train.Example.
    This method is specific to the signature of the Census example.
    Args:
        json_instance: dict, representing data to be serialized.
    Returns:
        A string (as a container for bytes).
    """
    import tensorflow as tf
    feature_dict = {}
    for key, data in json_instance.iteritems():

Review comment: I don't understand why this is necessary?
Reply: Need to construct a dict of protos because it's expected by yet another proto. This is just protos being ugly.

        if isinstance(data, str) or isinstance(data, unicode):
            feature_dict[key] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[str(data)]))
        elif isinstance(data, int) or isinstance(data, float):

Review comment (author): Actually I just realized this is coded wrong >_<

            feature_dict[key] = tf.train.Feature(
                float_list=tf.train.FloatList(value=[data]))
    return tf.train.Example(
        features=tf.train.Features(
            feature=feature_dict
        )
    ).SerializeToString()
# [END census_to_example_bytes]
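Roughly how the two snippets fit together from a caller's point of view. The instance below is abbreviated; a real call would need every field the deployed census model's signature expects, and the project and version values are placeholders.

# Sketch only: serialize one record and send it as a TFRecord-style instance.
instance = {'age': 25, 'workclass': ' Private', 'education': ' 11th'}
example_bytes = census_to_example_bytes(instance)
predictions = predict_tf_records(
    'example-project', 'census', [example_bytes], version='v1')
print(predictions)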


def main(project, model, version=None, force_tfrecord=False):

Review comment: I'm not a huge fan of main functions. Will this be included in the docs or not?
Reply: Not at all in docs.

    """Send user input to the prediction service."""
    import json

Review comment: Don't import here.

    while True:
        try:
            user_input = json.loads(raw_input("Valid JSON >>>"))

Review comment: I'm not a fan of raw_input (you cannot re-run this quickly, and typing valid JSON is a pain). But it allows interactive input and "python predict.py < my_data.json". Maybe add file-level comments on these two ways of using this script?
Reply: The main reason I wanted to do it this way is that we already have a solution for batch prediction (via the API), and a 100-request limit seems really bad if the use case we are highlighting is predicting from files.
Review comment: Where do users find out what kind of JSON to send here?
Reply: It depends on their model. This snippet will be part of a docs page that is attempting to explain just that. This will be at the end: "now that you know what the prediction service does, here's how you call it".

        except KeyboardInterrupt:
            return

        if not isinstance(user_input, list):
            user_input = [user_input]
        try:
            if force_tfrecord:
                example_bytes_list = [
                    census_to_example_bytes(e)
                    for e in user_input
                ]
                result = predict_tf_records(
                    project, model, example_bytes_list, version=version)
            else:
                result = predict_json(
                    project, model, user_input, version=version)
        except RuntimeError as err:
            print(str(err))
        else:
            print(result)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--project',
        help='Project in which the model is deployed',
        type=str,
        required=True
    )
    parser.add_argument(
        '--model',
        help='Model name',
        type=str,
        required=True
    )
    parser.add_argument(
        '--version',
        help='Name of the version.',
        type=str
    )
    parser.add_argument(
        '--force-tfrecord',
        help='Send predictions as TFRecords rather than raw JSON',
        action='store_true',
        default=False
    )
    args = parser.parse_args()
    main(**args.__dict__)

Review comment: Please explicitly pass args instead of using **args.__dict__.
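What the reviewer is asking for would presumably look something like the following, replacing the last line of the block with explicit keyword arguments for each parsed flag.

    # Sketch only: pass the parsed arguments explicitly.
    main(
        args.project,
        args.model,
        version=args.version,
        force_tfrecord=args.force_tfrecord)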
@@ -0,0 +1,64 @@
# Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache

Review comment: 2017. Also, these headers still seem different from the ones in the rest of the repo.

# License, Version 2.0 (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for predict.py ."""

Review comment: Blank newline both above and below this.

import base64

import pytest

from predict import census_to_example_bytes, predict_json

Review comment: Just import predict.


MODEL = 'census'

Review comment: I'm starting to think you want these files in GoogleCloudPlatform/cloudml-samples/census/something.

VERSION = 'v1'
PROJECT = 'python-docs-samples-tests'
JSON = {
    'age': 25,
    'workclass': ' Private',
    'education': ' 11th',
    'education_num': 7,
    'marital_status': ' Never-married',
    'occupation': ' Machine-op-inspct',
    'relationship': ' Own-child',
    'race': ' Black',
    'gender': ' Male',
    'capital_gain': 0,
    'capital_loss': 0,
    'hours_per_week': 40,
    'native_country': ' United-States'
}
EXPECTED_OUTPUT = {
    u'probabilities': [0.9942260384559631, 0.005774002522230148],
    u'logits': [-5.148599147796631],
    u'classes': 0,
    u'logistic': [0.005774001590907574]
}


def test_predict_json():
    result = predict_json(PROJECT, MODEL, [JSON, JSON], version=VERSION)
    assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result


def test_predict_json_error():
    with pytest.raises(RuntimeError):
        predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)


def test_census_example_to_bytes():
    b = census_to_example_bytes(JSON)
    assert base64.b64encode(b) is not None


def test_predict_tfrecord():

Review comment: Why not write a real test and mark it with ...

    # Using the same model for TFRecords and
    # JSON is currently broken.
    # TODO(elibixby) when b/35742966 is fixed add
    pass
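One possible shape for the real test the reviewer alludes to, skipped until the referenced bug is fixed. The skip marker, the test name, and the reuse of the module's constants are assumptions for illustration, not part of this PR.

# Sketch only: a round-trip test, skipped while b/35742966 is open.
import pytest

from predict import census_to_example_bytes, predict_tf_records


@pytest.mark.skip(reason='b/35742966: TFRecord and JSON prediction conflict')
def test_predict_tfrecord_roundtrip():
    example_bytes = census_to_example_bytes(JSON)
    result = predict_tf_records(PROJECT, MODEL, [example_bytes], version=VERSION)
    # Expected values depend on the deployed model.
    assert result is not None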
@@ -0,0 +1 @@
tensorflow>=1.0.0

Review comment: Don't use ranges; pin the version and dpebot will handle updating it.

Review comment: This license header looks weird, copy it from elsewhere?

Review comment: s/2016/2017/g