Skip to content

Commit b006d73

Browse files
authored
fix: support content_type in FileSystemInput (#1073)
1 parent a1b63b4 commit b006d73

File tree

4 files changed

+45
-3
lines changed

4 files changed

+45
-3
lines changed

src/sagemaker/inputs.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,12 @@ class FileSystemInput(object):
106106
"""
107107

108108
def __init__(
109-
self, file_system_id, file_system_type, directory_path, file_system_access_mode="ro"
109+
self,
110+
file_system_id,
111+
file_system_type,
112+
directory_path,
113+
file_system_access_mode="ro",
114+
content_type=None,
110115
):
111116
"""Create a new file system input used by an SageMaker training job.
112117
@@ -144,3 +149,6 @@ def __init__(
144149
}
145150
}
146151
}
152+
153+
if content_type:
154+
self.config["ContentType"] = content_type

tests/integ/file_system_input_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,10 @@ def _ami_id_for_region(sagemaker_session):
107107
def _connect_ec2_instance(ec2_instance):
108108
public_ip_address = ec2_instance.public_ip_address
109109
connected_instance = Connection(
110-
host=public_ip_address, port=22, user="ec2-user", connect_kwargs={"key_filename": KEY_PATH}
110+
host=public_ip_address,
111+
port=22,
112+
user="ec2-user",
113+
connect_kwargs={"key_filename": [KEY_PATH]},
111114
)
112115
return connected_instance
113116

tests/integ/test_tf_efs_fsx.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,12 @@ def test_mnist_efs(efs_fsx_setup, sagemaker_session, cpu_instance_type):
7373
)
7474

7575
file_system_efs_id = efs_fsx_setup["file_system_efs_id"]
76+
content_type = "application/json"
7677
file_system_input = FileSystemInput(
77-
file_system_id=file_system_efs_id, file_system_type="EFS", directory_path=EFS_DIR_PATH
78+
file_system_id=file_system_efs_id,
79+
file_system_type="EFS",
80+
directory_path=EFS_DIR_PATH,
81+
content_type=content_type,
7882
)
7983
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
8084
estimator.fit(inputs=file_system_input, job_name=unique_name_from_base("test-mnist-efs"))

tests/unit/test_inputs.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,33 @@ def test_file_system_input_all_arguments():
113113
assert actual.config == expected
114114

115115

116+
def test_file_system_input_content_type():
117+
file_system_id = "fs-0a48d2a1"
118+
file_system_type = "FSxLustre"
119+
directory_path = "tensorflow"
120+
file_system_access_mode = "rw"
121+
content_type = "application/json"
122+
actual = FileSystemInput(
123+
file_system_id=file_system_id,
124+
file_system_type=file_system_type,
125+
directory_path=directory_path,
126+
file_system_access_mode=file_system_access_mode,
127+
content_type=content_type,
128+
)
129+
expected = {
130+
"DataSource": {
131+
"FileSystemDataSource": {
132+
"FileSystemId": file_system_id,
133+
"FileSystemType": file_system_type,
134+
"DirectoryPath": directory_path,
135+
"FileSystemAccessMode": "rw",
136+
}
137+
},
138+
"ContentType": content_type,
139+
}
140+
assert actual.config == expected
141+
142+
116143
def test_file_system_input_type_invalid():
117144
with pytest.raises(ValueError) as excinfo:
118145
file_system_id = "fs-0a48d2a1"

0 commit comments

Comments
 (0)