Skip to content

Commit f72c632

Browse files
committed
add HF test
1 parent b612952 commit f72c632

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import pytest
17+
import sagemaker.utils
18+
from sagemaker.pytorch import HuggingFace
19+
from tests.integ import timeout
20+
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
21+
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
22+
23+
24+
@pytest.mark.skip(
    reason=(
        "Disabling until the launch of SM Trainium containers. "
        "This test should be re-enabled later."
    )
)
def test_torch_distributed_trn1_pt_mnist(
    sagemaker_session,
    huggingface_training_latest_version,
    huggingface_training_pytorch_latest_version,
    huggingface_pytorch_latest_training_py_version,
):
    """Run a HuggingFace training job with ``torch_distributed`` enabled.

    Builds a single-instance ``HuggingFace`` estimator around the
    ``run_glue.py`` script from the ``huggingface`` test-data directory and
    calls ``fit()`` on it inside the default training timeout.

    Args:
        sagemaker_session: Session fixture passed to the estimator.
        huggingface_training_latest_version: Fixture used as
            ``transformers_version``.
        huggingface_training_pytorch_latest_version: Fixture used as
            ``pytorch_version``.
        huggingface_pytorch_latest_training_py_version: Fixture used as
            ``py_version``.
    """
    # NOTE(review): the test name (trn1 / mnist) and the skip reason
    # (Trainium) do not match the body, which runs run_glue.py on a g5
    # instance -- confirm the intended target before re-enabling.

    # The HuggingFace estimator is published under ``sagemaker.huggingface``;
    # importing it here keeps this test bound to the right class.
    # NOTE(review): the module-level import takes it from ``sagemaker.pytorch``
    # -- confirm and correct that import as well.
    from sagemaker.huggingface import HuggingFace

    data_path = os.path.join(DATA_DIR, "huggingface")
    hyperparameters = {
        "model_name_or_path": "distilbert-base-cased",
        "task_name": "wnli",
        "do_train": True,
        "do_eval": True,
        "max_seq_length": 128,
        "fp16": True,
        "per_device_train_batch_size": 32,
        "output_dir": "/opt/ml/model",
    }

    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        estimator = HuggingFace(
            py_version=huggingface_pytorch_latest_training_py_version,
            entry_point=os.path.join(data_path, "run_glue.py"),
            role="SageMakerRole",
            transformers_version=huggingface_training_latest_version,
            pytorch_version=huggingface_training_pytorch_latest_version,
            instance_count=1,
            # SageMaker training instance types carry the "ml." prefix.
            instance_type="ml.g5.12xlarge",
            hyperparameters=hyperparameters,
            distribution={"torch_distributed": {"enabled": True}},
            sagemaker_session=sagemaker_session,
            disable_profiler=True,
        )
        estimator.fit()

0 commit comments

Comments
 (0)