Skip to content

Commit d49be8a

Browse files
NihalHarish authored and
ddavydenko committed
Path String Sanitization (aws#158)
Add path string sanitization to create_trial so that when a customer loads a trial from a local path that has leading/trailing spaces by mistake, the spaces are trimmed. (Whitespace interspersed within the path is not checked.)
1 parent b3a21d1 commit d49be8a

File tree

5 files changed

+56
-36
lines changed

5 files changed

+56
-36
lines changed

smdebug/trials/trial.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -145,9 +145,10 @@ def _fetch():
145145
num_times_before_warning -= 1
146146
if num_times_before_warning < 0:
147147
self.logger.warning(
148-
"Waiting to read collections files generated by the training job."
149-
"If this has been a while, you might want to check that the "
150-
"trial is pointed at the right path."
148+
f"Waiting to read collections files generated by the training job,"
149+
f"from {self.path}. "
150+
f"If this has been a while, you might want to check that the "
151+
f"trial is pointed at the right path."
151152
)
152153
else:
153154
self.logger.debug(

smdebug/trials/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111

1212
def create_trial(path, name=None, **kwargs):
13+
path = path.strip() # Remove any accidental leading/trailing whitespace input by the user
1314
if name is None:
1415
name = os.path.basename(path)
1516
s3, bucket_name, prefix_name = is_s3(path)

tests/analysis/trials/test_has_passed_step_scenarios.py

Lines changed: 2 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1,51 +1,20 @@
11
# Standard Library
2-
import json
32
import os
43
import shutil
54
import uuid
6-
from pathlib import Path
75

86
# Third Party
97
import pytest
108

119
# First Party
12-
from smdebug.core.collection_manager import CollectionManager
1310
from smdebug.core.config_constants import INCOMPLETE_STEP_WAIT_WINDOW_KEY
14-
from smdebug.core.locations import IndexFileLocationUtils
1511
from smdebug.core.modes import ModeKeys
1612
from smdebug.core.tensor import StepState
1713
from smdebug.exceptions import NoMoreData, StepUnavailable
1814
from smdebug.trials import create_trial
1915

20-
21-
def dummy_trial_creator(trial_dir, num_workers, job_ended):
22-
Path(trial_dir).mkdir(parents=True, exist_ok=True)
23-
cm = CollectionManager()
24-
for i in range(num_workers):
25-
collection_file_name = f"worker_{i}_collections.json"
26-
cm.export(trial_dir, collection_file_name)
27-
if job_ended:
28-
Path(os.path.join(trial_dir, "training_job_end.ts")).touch()
29-
30-
31-
def dummy_step_creator(trial_dir, global_step, mode, mode_step, worker_name):
32-
static_step_data = (
33-
'{"meta": {"mode": "TRAIN", "mode_step": 0, "event_file_name": ""}, '
34-
'"tensor_payload": ['
35-
'{"tensorname": "gradients/dummy:0", "start_idx": 0, "length": 1}'
36-
"]}"
37-
)
38-
39-
step = json.loads(static_step_data)
40-
step["meta"]["mode"] = mode
41-
step["meta"]["mode_step"] = mode_step
42-
43-
index_file_location = IndexFileLocationUtils.get_index_key_for_step(
44-
trial_dir, global_step, worker_name
45-
)
46-
Path(os.path.dirname(index_file_location)).mkdir(parents=True, exist_ok=True)
47-
with open(index_file_location, "w") as f:
48-
json.dump(step, f)
16+
# Local
17+
from ..utils import dummy_step_creator, dummy_trial_creator
4918

5019

5120
@pytest.mark.slow

tests/analysis/utils.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
# Standard Library
2+
import json
23
import os
4+
from pathlib import Path
35

46
# Third Party
57
import numpy as np
@@ -8,6 +10,7 @@
810
from smdebug.core.access_layer.s3handler import DeleteRequest, S3Handler
911
from smdebug.core.collection_manager import CollectionManager
1012
from smdebug.core.config_constants import DEFAULT_COLLECTIONS_FILE_NAME
13+
from smdebug.core.locations import IndexFileLocationUtils
1114
from smdebug.core.writer import FileWriter
1215

1316

@@ -51,3 +54,33 @@ def check_trial(trial_obj, num_steps, num_tensors):
5154

5255
def delete_s3_prefix(bucket, prefix):
5356
S3Handler.delete_prefix(delete_request=DeleteRequest(Bucket=bucket, Prefix=prefix))
57+
58+
59+
def dummy_trial_creator(trial_dir, num_workers, job_ended):
60+
Path(trial_dir).mkdir(parents=True, exist_ok=True)
61+
cm = CollectionManager()
62+
for i in range(num_workers):
63+
collection_file_name = f"worker_{i}_collections.json"
64+
cm.export(trial_dir, collection_file_name)
65+
if job_ended:
66+
Path(os.path.join(trial_dir, "training_job_end.ts")).touch()
67+
68+
69+
def dummy_step_creator(trial_dir, global_step, mode, mode_step, worker_name):
70+
static_step_data = (
71+
'{"meta": {"mode": "TRAIN", "mode_step": 0, "event_file_name": ""}, '
72+
'"tensor_payload": ['
73+
'{"tensorname": "gradients/dummy:0", "start_idx": 0, "length": 1}'
74+
"]}"
75+
)
76+
77+
step = json.loads(static_step_data)
78+
step["meta"]["mode"] = mode
79+
step["meta"]["mode_step"] = mode_step
80+
81+
index_file_location = IndexFileLocationUtils.get_index_key_for_step(
82+
trial_dir, global_step, worker_name
83+
)
84+
Path(os.path.dirname(index_file_location)).mkdir(parents=True, exist_ok=True)
85+
with open(index_file_location, "w") as f:
86+
json.dump(step, f)

tests/core/test_paths.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,22 @@
1313
from smdebug.core.access_layer.utils import training_has_ended
1414
from smdebug.core.hook_utils import verify_and_get_out_dir
1515
from smdebug.core.utils import SagemakerSimulator, ScriptSimulator
16+
from smdebug.trials import create_trial
17+
18+
# Local
19+
from ..analysis.utils import dummy_trial_creator
20+
21+
22+
def test_whitespace_handling_in_path_str():
23+
_id = str(uuid.uuid4())
24+
path = os.path.join("ts_output/train/", _id)
25+
dummy_trial_creator(trial_dir=path, num_workers=1, job_ended=True)
26+
27+
# Test Leading Whitespace Handling
28+
create_trial(" " + path)
29+
30+
# Test Trailing Whitespace Handling
31+
create_trial(path + " ")
1632

1733

1834
def test_outdir_non_sagemaker():

0 commit comments

Comments
 (0)