Skip to content

Commit 2b1e84a

Browse files
authored
Merge branch 'master' into master
2 parents e2713d3 + b58495e commit 2b1e84a

File tree

4 files changed

+5
-10
lines changed

4 files changed

+5
-10
lines changed

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(
113113
:class:`~sagemaker.estimator.Framework` and
114114
:class:`~sagemaker.estimator.EstimatorBase`.
115115
"""
116+
distribution = renamed_kwargs("distributions", "distribution", distribution, kwargs)
116117
instance_type = renamed_kwargs(
117118
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
118119
)

tests/data/tensorflow_mnist/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def serving_input_fn():
159159

160160
# Train the model
161161
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
162-
x={"x": train_data}, y=train_labels, batch_size=50, num_epochs=None, shuffle=True
162+
x={"x": train_data}, y=train_labels, batch_size=50, num_epochs=None, shuffle=False
163163
)
164164

165165
# Evaluate the model and print results

tests/integ/test_spark_processing.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -302,19 +302,13 @@ def test_integ_history_server(spark_py_processor, sagemaker_session):
302302
sagemaker_session=sagemaker_session,
303303
)
304304

305+
# sleep 3 seconds to avoid s3 eventual consistency issue
306+
time.sleep(3)
305307
spark_py_processor.start_history_server(spark_event_logs_s3_uri=spark_event_logs_s3_uri)
306308

307309
try:
308310
response = _request_with_retry(HISTORY_SERVER_ENDPOINT)
309311
assert response.status == 200
310-
311-
# spark has redirect behavior, this request verify that page navigation works with redirect
312-
response = _request_with_retry(f"{HISTORY_SERVER_ENDPOINT}{SPARK_APPLICATION_URL_SUFFIX}")
313-
assert response.status == 200
314-
315-
html_content = response.data.decode("utf-8")
316-
assert "Completed Jobs (4)" in html_content
317-
assert "collect at /opt/ml/processing/input/code/test_long_duration.py:32" in html_content
318312
finally:
319313
spark_py_processor.terminate_history_server()
320314

tests/integ/test_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def test_transform_tf_kms_network_isolation(
420420
with open(os.path.join(tmpdir, "tf-batch-output", "data.csv.out")) as f:
421421
result = json.load(f)
422422
assert len(result["predictions"][0]["probabilities"]) == 10
423-
assert result["predictions"][0]["classes"] == 1
423+
assert result["predictions"][0]["classes"] >= 1
424424

425425

426426
def _create_transformer_and_transform_job(

0 commit comments

Comments
 (0)