Add samples for the Cloud ML Engine #824

Merged: 16 commits, Feb 28, 2017
189 changes: 189 additions & 0 deletions ml_engine/online_prediction/predict.py
@@ -0,0 +1,189 @@
# Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache
Contributor:
This license header looks weird, copy it from elsewhere?

Contributor:
s/2016/2017/g

Contributor:
Needs a shebang

# License, Version 2.0 (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Examples of using the Cloud ML Engine's online prediction service."""

Contributor:
Add comments on authentication. When should this work, or what needs to be true for the script to work.

Contributor Author:
Done

Contributor:
in get_ml_engine_service, can you add a link to the doc page describing how I can download a service account file?
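For context, a minimal sketch of the authentication being discussed (the key path is a placeholder): googleapiclient.discovery.build() picks up Application Default Credentials, e.g. from a GOOGLE_APPLICATION_CREDENTIALS environment variable pointing at a downloaded service-account JSON key.

    import os
    import googleapiclient.discovery

    # Placeholder path: a service-account key downloaded from the Cloud Console.
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/key.json'
    # build() then authenticates via Application Default Credentials.
    service = googleapiclient.discovery.build('ml', 'v1beta1')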

Contributor:
Nit: blank line between license and docstring.


# [START import_libraries]
import googleapiclient.discovery
# [END import_libraries]


# [START authenticating]
Contributor:
We generally show constructing the service in each snippet instead of centralizing it. Every indirection adds cognitive load to the users.
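For illustration, a sketch of the inlined form being suggested (hypothetical, not the code under review): each snippet builds its own client instead of calling a shared helper.

    import googleapiclient.discovery

    def predict_json(project, model, instances, version=None):
        # Build the service inside the snippet so it stands alone in the docs.
        service = googleapiclient.discovery.build('ml', 'v1beta1')
        name = 'projects/{}/models/{}'.format(project, model)
        if version is not None:
            name += '/versions/{}'.format(version)
        response = service.projects().predict(
            name=name, body={'instances': instances}).execute()
        return response['predictions']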

def get_ml_engine_service():
    return googleapiclient.discovery.build('ml', 'v1beta1')
# [END authenticating]


# [START predict_json]
def predict_json(project, model, instances, version=None):
"""Send data instances to a deployed model for prediction
Contributor:
If it's a sentence then you should put a full stop on it. :)

    Args:
Contributor:
blank newline above here.

        project: str, project where the Cloud ML Engine Model is deployed.
        model: str, model name.
        instances: [dict], dictionaries from string keys defined by the model
            to data.
        version: [optional] str, version of the model to target.
    Returns:
        A dictionary of prediction results defined by the model.
Contributor:
We generally encourage snippets to be simple enough not to require this, but I understand if that's not reasonable here. If you're going to go full docstring, follow Napoleon style:

Args:
    project (str): ...
    model (str): ...
    instances (Mapping[ str, dict ]): ...
    version (str): optional ...

Returns:
   Mapping [str, ...] : ...

"""
service = get_ml_engine_service()
name = 'projects/{}/models/{}'.format(project, model)
if version is not None:
name += '/versions/{}'.format(version)

response = service.projects().predict(
name=name,
body={'instances': instances}
).execute()

if 'error' in response:
raise RuntimeError(response['error'])

return response['predictions']
# [END predict_json]
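A hypothetical usage sketch (project and model names are placeholders, and the instance fields must match the deployed model's input signature):

    predictions = predict_json(
        'my-project',                             # placeholder project id
        'census',                                 # placeholder model name
        [{'age': 25, 'workclass': ' Private'}])   # model-specific fields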


# [START predict_tf_records]
def predict_tf_records(project,
                       model,
                       example_bytes_list,
                       key='tfrecord',
                       version=None):
    """Send data instances to a deployed model for prediction
    Args:
        project: str, project where the Cloud ML Engine Model is deployed
        model: str, model name.
        example_bytes_list: [str], Serialized tf.train.Example protos.
        version: str, version of the model to target.
    Returns:
        A dictionary of prediction results defined by the model.
    """
    import base64
Contributor:
Don't import here, import at the top.

Contributor Author:
How do you highlight that this import is only necessary for this snippet?

Is that not important?

    service = get_ml_engine_service()
    name = 'projects/{}/models/{}'.format(project, model)
    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': [
            {key: {'b64': base64.b64encode(example_bytes)}}
            for example_bytes in example_bytes_list
        ]}
    ).execute()
    if 'error' in response:
Contributor:
Blank new line to separate control statements.

        raise RuntimeError(response['error'])

    return response['predictions']
# [END predict_tf_records]
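One caveat worth flagging here (an editorial observation, not part of the PR): on Python 3, base64.b64encode() returns bytes, which the JSON request body can't serialize, so the encoded value would need decoding first. A minimal sketch:

    import base64

    def encode_for_json(example_bytes, key='tfrecord'):
        # Python 3: b64encode returns bytes; decode to str for the JSON body.
        return {key: {'b64': base64.b64encode(example_bytes).decode('utf-8')}}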


# [START census_to_example_bytes]
def census_to_example_bytes(json_instance):

Contributor:
I was expecting the file path to be in the census example (in cloudml-samples). Is this file part of the census sample or a more generic 'calling online prediction' sample? If the latter, we need better warnings that this will not work with every model, and we need to describe what the model is expecting.

If this is not part of the census sample, a s/census/json/g is needed.

Sorry if this is a bad question; I'm not familiar with python-docs-samples.

Contributor Author (@elibixby, Feb 27, 2017):
I was thinking we have sort of a hard separation between "things run as part of training" and "code run in your own client to send requests to the prediction service". The former belongs in cloudml-samples (and in the future tf/garden), and the latter in python-docs-samples.

I will definitely add some better prose around this in the form of a docstring, and we'll also make it clear in docs.

"""Serialize a JSON example to the bytes of a tf.train.Example.
This method is specific to the signature of the Census example.
Args:
json_instance: dict, representing data to be serialized.
Returns:
A string (as a container for bytes).
"""
import tensorflow as tf
feature_dict = {}
for key, data in json_instance.iteritems():
Contributor:
I don't understand why this is necessary?

Contributor Author:
Need to construct a dict of Protos because it's expected by yet another protos. This is just protos being ugly.

        if isinstance(data, str) or isinstance(data, unicode):
Contributor:
isinstance(data, (str, unicode)), but this will fail on Python 3 because unicode doesn't exist. You should consider using six.string_types or making this the else clause.

Contributor Author:
Is six installed by default in the testing env, or should I add it to requirements.txt?
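For reference, a minimal sketch of the suggested six-based check (assuming six gets added to requirements.txt if the test env lacks it; the helper name is hypothetical):

    import six

    def is_stringlike(data):
        # six.string_types is (str, unicode) on Python 2 and (str,) on Python 3.
        return isinstance(data, six.string_types)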

            feature_dict[key] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[str(data)]))
        elif isinstance(data, int) or isinstance(data, float):
Contributor:
elif isinstance(data, (int, float)).

Contributor Author:
cool!

Contributor Author:
Actually I just realized this is coded wrong >_<

            feature_dict[key] = tf.train.Feature(
                float_list=tf.train.FloatList(value=[data]))
    return tf.train.Example(
        features=tf.train.Features(
            feature=feature_dict
        )
    ).SerializeToString()
# [END census_to_example_bytes]


# [START predict_from_files]
def predict_from_files(project,
                       model,
                       files,
                       version=None,
                       force_tfrecord=False):
    import json
Contributor:
Don't import here.

    import itertools

Contributor:
why the local imports? This is not a DF job.

Contributor Author:
This way the necessary imports will show up in the snippets in the docs. The [START foo] and [END foo] blocks indicate a displayable chunk for the docs.

    instances = (json.loads(line)
                 for f in files
                 for line in f.readlines())

    # Requests to online prediction
    # can have at most 100 instances
    args = [instances] * 100

Contributor:
Why are we making 100 copies of this tuple? This looks wrong.

Contributor Author:
100 copies of the generator. This is how you batch generators in python (it's weird). But I'm deleting this code anyway in favor of a webapp.
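For readers unfamiliar with the idiom: zipping n references to the same iterator pulls n consecutive items per step, which is how a flat generator gets batched. A standalone sketch (Python 2 izip to match the code here; note it drops a final partial batch, itertools.izip_longest would keep it):

    import itertools

    def batch(iterable, n):
        # n references to the SAME iterator, so each zip step consumes
        # n consecutive items from the underlying generator.
        args = [iter(iterable)] * n
        return itertools.izip(*args)

    list(batch(range(5), 2))  # [(0, 1), (2, 3)] -- the trailing 4 is dropped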

    instance_batches = itertools.izip(*args)

    results = []

Contributor:
so the input data could need batching (be large), but the results don't need batching? I'm ok with this script not doing batching. Depends on what others say.

    for batch in instance_batches:
        if force_tfrecord:
            example_bytes_list = [
                census_to_example_bytes(instance)
                for instance in batch
            ]
            results.append(predict_tf_records(
                project,
                model,
                example_bytes_list,
                version=version
            ))
        else:
            results.append(predict_json(
                project,
                model,
                batch,
                version=version
            ))
    return results

Contributor:
Where are the results saved or printed?

# [END predict_from_files]


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'files',
        help='File paths with examples to predict',
        nargs='+',
        # Open the files here so predict_from_files can iterate their lines;
        # the positional name must match the predict_from_files parameter.
        type=argparse.FileType('r')
    )
    parser.add_argument(
        '--project',
        help='Project in which the model is deployed',
        type=str,
        required=True
    )
    parser.add_argument(
        '--model',
        help='Model name',
        type=str,
        required=True
    )
    parser.add_argument(
        '--version',
        help='Name of the version.',
        type=str
    )
    parser.add_argument(
        '--force-tfrecord',
        help='Send predictions as TFRecords rather than raw JSON',
        action='store_true',
        default=False
    )
    args = parser.parse_args()
    predict_from_files(**args.__dict__)
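One way the "where are the results saved or printed?" question could be addressed, as a hypothetical sketch, printing each prediction as a JSON line:

    import json

    results = predict_from_files(**args.__dict__)
    for batch_results in results:
        for prediction in batch_results:
            print(json.dumps(prediction))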
1 change: 1 addition & 0 deletions ml_engine/online_prediction/resources/test.json
@@ -0,0 +1 @@
{"age": 25, "workclass": " Private", "education": " 11th", "education_num": 7, "marital_status": " Never-married", "occupation": " Machine-op-inspct", "relationship": " Own-child", "race": " Black", "gender": " Male", "capital_gain": 0, "capital_loss": 0, "hours_per_week": 40, "native_country": " United-States"}

Contributor:
Does your batching code really work? Hard to tell with just one prediction row.