Skip to content

Commit b1fdf5e

Browse files
committed
test: Adding tests targeting trcomp support for PT 1.12
1 parent c4971a3 commit b1fdf5e

File tree

9 files changed

+1257
-30
lines changed

9 files changed

+1257
-30
lines changed

src/sagemaker/pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,5 @@
1616
from sagemaker.pytorch.estimator import PyTorch # noqa: F401
1717
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor # noqa: F401
1818
from sagemaker.pytorch.processing import PyTorchProcessor # noqa: F401
19+
20+
from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig # noqa: F401

src/sagemaker/training_compiler/config.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,12 @@ class TrainingCompilerConfig(object):
2323
"""The SageMaker Training Compiler configuration class."""
2424

2525
DEBUG_PATH = "/opt/ml/output/data/compiler/"
26-
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]
26+
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
2727

2828
HP_ENABLE_COMPILER = "sagemaker_training_compiler_enabled"
2929
HP_ENABLE_DEBUG = "sagemaker_training_compiler_debug_mode"
3030

31-
def __init__(
32-
self,
33-
enabled=True,
34-
debug=False,
35-
):
31+
def __init__(self, enabled=True, debug=False):
3632
"""This class initializes a ``TrainingCompilerConfig`` instance.
3733
3834
`Amazon SageMaker Training Compiler
@@ -118,10 +114,7 @@ def _to_hyperparameter_dict(self):
118114
return compiler_config_hyperparameters
119115

120116
@classmethod
121-
def validate(
122-
cls,
123-
estimator,
124-
):
117+
def validate(cls, estimator):
125118
"""Checks if SageMaker Training Compiler is configured correctly.
126119
127120
Args:
@@ -138,19 +131,20 @@ def validate(
138131
warn_msg = (
139132
"Estimator instance_type is a PipelineVariable (%s), "
140133
"which has to be interpreted as one of the "
141-
"[p3, g4dn, p4d, g5] classes in execution time."
134+
"%s classes in execution time."
135+
)
136+
logger.warning(
137+
warn_msg,
138+
type(estimator.instance_type),
139+
str(cls.SUPPORTED_INSTANCE_CLASS_PREFIXES).replace(",", ""),
142140
)
143-
logger.warning(warn_msg, type(estimator.instance_type))
144141
elif estimator.instance_type:
145142
if "local" not in estimator.instance_type:
146143
requested_instance_class = estimator.instance_type.split(".")[
147144
1
148145
] # Expecting ml.class.size
149146
if not any(
150-
[
151-
requested_instance_class.startswith(i)
152-
for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
153-
]
147+
[requested_instance_class == i for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES]
154148
):
155149
error_helper_string = (
156150
"Unsupported Instance class {}."

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
"neo_pytorch",
7474
"neo_tensorflow",
7575
"pytorch",
76+
"pytorch_training_compiler",
7677
"ray_pytorch",
7778
"ray_tensorflow",
7879
"sklearn",
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
transformers

0 commit comments

Comments
 (0)