Skip to content

Commit debd07b

Browse files
committed
Fixing logic errors in attach
1 parent defb20c commit debd07b

File tree

4 files changed

+4
-3
lines changed

4 files changed

+4
-3
lines changed

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def framework_name_from_image(image_uri):
493493
# We must support both the legacy and current image name format.
494494
name_pattern = re.compile(
495495
r"""^(?:sagemaker(?:-rl)?-)?
496-
(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost
496+
(tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost
497497
|huggingface-tensorflow|huggingface-pytorch
498498
|huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)?
499499
(scriptmode|training)?

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
423423
)
424424
image_uri = init_params.pop("image_uri")
425425
framework, py_version, tag, _ = framework_name_from_image(image_uri)
426+
framework = framework.split("-")[0]
426427

427428
if tag is None:
428429
framework_version = None

src/sagemaker/pytorch/training_compiler/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
class TrainingCompilerConfig(BaseConfig):
2727
"""The SageMaker Training Compiler configuration class."""
2828

29-
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"]
29+
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
3030
SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
3131
"ml.g4dn.8xlarge",
3232
"ml.g4dn.12xlarge",

tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES
2929

3030

31-
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
31+
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "..", "data")
3232
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
3333
SERVING_SCRIPT_FILE = "another_dummy_script.py"
3434
MODEL_DATA = "s3://some/data.tar.gz"

0 commit comments

Comments
 (0)