Add samples for the Cloud ML Engine #824

Merged · 16 commits · Feb 28, 2017
1 change: 1 addition & 0 deletions ml_engine/online_prediction/README.md
@@ -0,0 +1 @@
https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
198 changes: 198 additions & 0 deletions ml_engine/online_prediction/predict.py
@@ -0,0 +1,198 @@
#!/usr/bin/env python
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Examples of using the Cloud ML Engine's online prediction service."""
import argparse
import base64
import json

# [START import_libraries]
import googleapiclient.discovery
# [END import_libraries]
import six


# [START predict_json]
def predict_json(project, model, instances, version=None):
    """Send json data to a deployed model for prediction.

    Args:
        project (str): project where the Cloud ML Engine Model is deployed.
        model (str): model name.
        instances ([Mapping[str: Any]]): Keys should be the names of Tensors
            your deployed model expects as inputs. Values should be datatypes
            convertible to Tensors, or (potentially nested) lists of datatypes
            convertible to tensors.
        version (str): version of the model to target.
    Returns:
        Mapping[str: any]: dictionary of prediction results defined by the
            model.
    """
    # Create the ML Engine service object.
    # To authenticate set the environment variable
    # GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
Contributor: This is covered by the client libraries page and should also be covered in the quickstart and a separate auth page. Is it necessary to repeat it here?

Contributor Author: Yeahhhh, it was specifically requested, but I can remove it.
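For reference, a minimal sketch of building the service object with explicit application-default credentials rather than relying on the environment variable; oauth2client's GoogleCredentials is one option googleapiclient accepts (illustrative only, not part of the sample):

    from oauth2client.client import GoogleCredentials
    import googleapiclient.discovery

    # Loads the same application-default credentials the environment
    # variable points to, but passes them to the client explicitly.
    credentials = GoogleCredentials.get_application_default()
    service = googleapiclient.discovery.build(
        'ml', 'v1beta1', credentials=credentials)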

    service = googleapiclient.discovery.build('ml', 'v1beta1')
    name = 'projects/{}/models/{}'.format(project, model)

    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': instances}
    ).execute()

    if 'error' in response:
        raise RuntimeError(response['error'])

    return response['predictions']
# [END predict_json]
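For illustration, a hypothetical call to predict_json; 'my-project', 'census', and 'v1' are placeholders for your own deployment, and the instance keys must match the input Tensors your model expects:

    instances = [{'age': 25, 'workclass': ' Private'}]
    predictions = predict_json(
        'my-project', 'census', instances, version='v1')
    print(predictions)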


# [START predict_tf_records]
def predict_tf_records(project,
                       model,
                       example_bytes_list,
                       version=None):
    """Send protocol buffer data to a deployed model for prediction.

    Args:
        project (str): project where the Cloud ML Engine Model is deployed.
        model (str): model name.
        example_bytes_list ([str]): A list of bytestrings representing
            serialized tf.train.Example protocol buffers. The contents of
            these protocol buffers will depend on the signature of your
            deployed model.
        version (str): version of the model to target.
    Returns:
        Mapping[str: any]: dictionary of prediction results defined by the
            model.
    """
    service = googleapiclient.discovery.build('ml', 'v1beta1')
    name = 'projects/{}/models/{}'.format(project, model)

    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': [
            # base64-encode the serialized proto and decode the result to
            # text so the request body is JSON-serializable on Python 3.
            {'b64': base64.b64encode(example_bytes).decode('utf-8')}
            for example_bytes in example_bytes_list
        ]}
    ).execute()

    if 'error' in response:
        raise RuntimeError(response['error'])

    return response['predictions']
# [END predict_tf_records]
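A sketch of the TFRecord path under the same hypothetical deployment, assuming the target version parses serialized tf.train.Example protos; census_to_example_bytes (defined below) handles the serialization:

    example_bytes_list = [
        census_to_example_bytes(i) for i in instances]
    predictions = predict_tf_records(
        'my-project', 'census', example_bytes_list, version='v1tfrecord')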


# [START census_to_example_bytes]
def census_to_example_bytes(json_instance):

Contributor: I was expecting the file path to be in the census example (in cloudml-samples). Is this file part of the census sample or a more generic 'calling online prediction' sample? If the latter, we need better warnings that this will not work with every model, and we need to describe what the model is expecting.

If this is not part of the census sample, a s/census/json/g is needed.

Sorry if this is a bad question, I'm not familiar with python-docs-samples.

Contributor Author (@elibixby, Feb 27, 2017): I was thinking we have sort of a hard separation between "things run as part of training" and "code run in your own client to send requests to the prediction service". The former being in cloudml-samples (and in the future tf/garden) and the latter being in python-docs-samples.

I will definitely add some better prose around this in the form of a docstring, and we'll also make it clear in docs.
"""Serialize a JSON example to the bytes of a tf.train.Example.
This method is specific to the signature of the Census example.
See: https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
for details.

Args:
json_instance (Mapping[str: Any]): Keys should be the names of Tensors
your deployed model expects to parse using it's tf.FeatureSpec.
Values should be datatypes convertible to Tensors, or (potentially
nested) lists of datatypes convertible to tensors.
Returns:
str: A string as a container for the serialized bytes of
tf.train.Example protocol buffer.
"""
import tensorflow as tf
feature_dict = {}
for key, data in json_instance.iteritems():
Contributor: I don't understand why this is necessary?

Contributor Author: Need to construct a dict of protos because it's expected by yet another proto. This is just protos being ugly.

        if isinstance(data, six.string_types):
            feature_dict[key] = tf.train.Feature(
                # BytesList needs bytes; tf.compat.as_bytes handles both
                # Python 2 and Python 3 strings.
                bytes_list=tf.train.BytesList(
                    value=[tf.compat.as_bytes(data)]))
        elif isinstance(data, float):
            feature_dict[key] = tf.train.Feature(
                float_list=tf.train.FloatList(value=[data]))
        elif isinstance(data, int):
            feature_dict[key] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[data]))
    return tf.train.Example(
        features=tf.train.Features(
            feature=feature_dict
        )
    ).SerializeToString()
# [END census_to_example_bytes]
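As a sanity check, the returned bytes can be parsed back into a tf.train.Example; a minimal sketch, assuming TensorFlow is installed:

    import tensorflow as tf

    serialized = census_to_example_bytes(
        {'age': 25, 'workclass': ' Private'})
    # Round-trip: parse the serialized proto and read a feature back out.
    example = tf.train.Example()
    example.ParseFromString(serialized)
    print(example.features.feature['age'].int64_list.value)  # [25]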


def main(project, model, version=None, force_tfrecord=False):
Contributor: I'm not a huge fan of main functions. Will this be included in the docs or not?

Contributor Author: Not at all in docs.

"""Send user input to the prediction service."""
while True:
try:
user_input = json.loads(raw_input("Valid JSON >>>"))
except KeyboardInterrupt:
return

if not isinstance(user_input, list):
user_input = [user_input]
try:
if force_tfrecord:
example_bytes_list = [
census_to_example_bytes(e)
for e in user_input
]
result = predict_tf_records(
project, model, example_bytes_list, version=version)
else:
result = predict_json(
project, model, user_input, version=version)
except RuntimeError as err:
print(str(err))
else:
print(result)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--project',
        help='Project in which the model is deployed',
        type=str,
        required=True
    )
    parser.add_argument(
        '--model',
        help='Model name',
        type=str,
        required=True
    )
    parser.add_argument(
        '--version',
        help='Name of the version.',
        type=str
    )
    parser.add_argument(
        '--force-tfrecord',
        help='Send predictions as TFRecords rather than raw JSON',
        action='store_true',
        default=False
    )
    args = parser.parse_args()
    main(
        args.project,
        args.model,
        version=args.version,
        force_tfrecord=args.force_tfrecord
    )
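Run from a shell, the script prompts for JSON instances on stdin; a hypothetical session against a deployed census model (project, model, and version names are placeholders):

    python predict.py --project my-project --model census --version v1
    Valid JSON >>>{"age": 25, "workclass": " Private"}

Passing --force-tfrecord serializes each instance with census_to_example_bytes and sends it through predict_tf_records instead.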
74 changes: 74 additions & 0 deletions ml_engine/online_prediction/predict_test.py
@@ -0,0 +1,74 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for predict.py ."""

import base64

import pytest

import predict


MODEL = 'census'
VERSION = 'v1'
TF_RECORDS_VERSION = 'v1tfrecord'
PROJECT = 'python-docs-samples-tests'
JSON = {
    'age': 25,
    'workclass': ' Private',
    'education': ' 11th',
    'education_num': 7,
    'marital_status': ' Never-married',
    'occupation': ' Machine-op-inspct',
    'relationship': ' Own-child',
    'race': ' Black',
    'gender': ' Male',
    'capital_gain': 0,
    'capital_loss': 0,
    'hours_per_week': 40,
    'native_country': ' United-States'
}
EXPECTED_OUTPUT = {
    u'probabilities': [0.9942260384559631, 0.005774002522230148],
    u'logits': [-5.148599147796631],
    u'classes': 0,
    u'logistic': [0.005774001590907574]
}


def test_predict_json():
    result = predict.predict_json(
        PROJECT, MODEL, [JSON, JSON], version=VERSION)
    assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result


def test_predict_json_error():
    with pytest.raises(RuntimeError):
        predict.predict_json(
            PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)


@pytest.mark.slow
def test_census_example_to_bytes():
    b = predict.census_to_example_bytes(JSON)
    assert base64.b64encode(b) is not None


@pytest.mark.slow
@pytest.mark.xfail(
    reason='Single placeholder inputs broken in service b/35778449')
def test_predict_tfrecords():
    b = predict.census_to_example_bytes(JSON)
    result = predict.predict_tf_records(
        PROJECT, MODEL, [b, b], version=TF_RECORDS_VERSION)
    assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result
1 change: 1 addition & 0 deletions ml_engine/online_prediction/requirements.txt
@@ -0,0 +1 @@
tensorflow==1.0.0