Skip to content

Commit 741f687

Browse files
author
Chuyang Deng
committed
update warning message
1 parent 69d06ad commit 741f687

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

src/sagemaker/fw_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@
4949
"Please set the argument \"py_version='py3'\" to use the Python 3 {framework} image."
5050
)
5151
PARAMETER_SERVER_MULTI_GPU_WARNING = (
52-
"You have selected a multi-GPU training instance type. "
53-
"You have also enabled parameter server for distributed training. "
52+
"If you have selected a multi-GPU training instance type, "
53+
"and have also enabled parameter server for distributed training. "
5454
"Distributed training with the default parameter server configuration will not "
5555
"fully leverage all GPU cores; the parameter server will be configured to run "
5656
"only one worker per host regardless of the number of GPUs."
@@ -617,9 +617,9 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
617617
return
618618

619619
is_multi_gpu_instance = (
620-
training_instance_type.split(".")[1].startswith("p")
621-
and training_instance_type not in SINGLE_GPU_INSTANCE_TYPES
622-
)
620+
training_instance_type == "local_gpu"
621+
or training_instance_type.split(".")[1].startswith("p")
622+
) and training_instance_type not in SINGLE_GPU_INSTANCE_TYPES
623623

624624
ps_enabled = "parameter_server" in distributions and distributions["parameter_server"].get(
625625
"enabled", False

tests/unit/test_fw_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,3 +1272,13 @@ def test_warn_if_parameter_server_with_multi_gpu(caplog):
12721272
training_instance_type=train_instance_type, distributions=distributions
12731273
)
12741274
assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING in caplog.text
1275+
1276+
1277+
def test_war_if_parameter_server_with_multi_gpu(caplog):
1278+
train_instance_type = "local_gpu"
1279+
distributions = {"parameter_server": {"enabled": True}}
1280+
1281+
fw_utils.warn_if_parameter_server_with_multi_gpu(
1282+
training_instance_type=train_instance_type, distributions=distributions
1283+
)
1284+
assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING in caplog.text

0 commit comments

Comments
 (0)