fix: security update -> use sha256 instead of md5 for file hashing #4965

Merged: 5 commits, Dec 20, 2024
54 changes: 27 additions & 27 deletions src/sagemaker/workflow/utilities.py
@@ -268,29 +268,29 @@ def get_config_hash(step: Entity):


def hash_object(obj) -> str:
"""Get the MD5 hash of an object.
"""Get the SHA256 hash of an object.
Args:
obj (dict): The object
Returns:
- str: The MD5 hash of the object
+ str: The SHA256 hash of the object
"""
- return hashlib.md5(str(obj).encode()).hexdigest()
+ return hashlib.sha256(str(obj).encode()).hexdigest()


def hash_file(path: str) -> str:
"""Get the MD5 hash of a file.
"""Get the SHA256 hash of a file.
Args:
path (str): The local path for the file.
Returns:
- str: The MD5 hash of the file.
+ str: The SHA256 hash of the file.
"""
- return _hash_file(path, hashlib.md5()).hexdigest()
+ return _hash_file(path, hashlib.sha256()).hexdigest()


def hash_files_or_dirs(paths: List[str]) -> str:
"""Get the MD5 hash of the contents of a list of files or directories.
"""Get the SHA256 hash of the contents of a list of files or directories.
Hash is changed if:
* input list is changed
@@ -301,58 +301,58 @@ def hash_files_or_dirs(paths: List[str]) -> str:
Args:
paths: List of file or directory paths
Returns:
- str: The MD5 hash of the list of files or directories.
+ str: The SHA256 hash of the list of files or directories.
"""
- md5 = hashlib.md5()
+ sha256 = hashlib.sha256()
for path in sorted(paths):
- md5 = _hash_file_or_dir(path, md5)
- return md5.hexdigest()
+ sha256 = _hash_file_or_dir(path, sha256)
+ return sha256.hexdigest()


- def _hash_file_or_dir(path: str, md5: Hash) -> Hash:
+ def _hash_file_or_dir(path: str, sha256: Hash) -> Hash:
"""Updates the inputted Hash with the contents of the current path.
Args:
path: path of file or directory
Returns:
- str: The MD5 hash of the file or directory
+ str: The SHA256 hash of the file or directory
"""
if isinstance(path, str) and path.lower().startswith("file://"):
path = unquote(urlparse(path).path)
- md5.update(path.encode())
+ sha256.update(path.encode())
if Path(path).is_dir():
- md5 = _hash_dir(path, md5)
+ sha256 = _hash_dir(path, sha256)
elif Path(path).is_file():
- md5 = _hash_file(path, md5)
- return md5
+ sha256 = _hash_file(path, sha256)
+ return sha256


- def _hash_dir(directory: Union[str, Path], md5: Hash) -> Hash:
+ def _hash_dir(directory: Union[str, Path], sha256: Hash) -> Hash:
"""Updates the inputted Hash with the contents of the current path.
Args:
directory: path of the directory
Returns:
- str: The MD5 hash of the directory
+ str: The SHA256 hash of the directory
"""
if not Path(directory).is_dir():
raise ValueError(str(directory) + " is not a valid directory")
for path in sorted(Path(directory).iterdir()):
- md5.update(path.name.encode())
+ sha256.update(path.name.encode())
if path.is_file():
- md5 = _hash_file(path, md5)
+ sha256 = _hash_file(path, sha256)
elif path.is_dir():
- md5 = _hash_dir(path, md5)
- return md5
+ sha256 = _hash_dir(path, sha256)
+ return sha256


- def _hash_file(file: Union[str, Path], md5: Hash) -> Hash:
+ def _hash_file(file: Union[str, Path], sha256: Hash) -> Hash:
"""Updates the inputted Hash with the contents of the current path.
Args:
file: path of the file
Returns:
- str: The MD5 hash of the file
+ str: The SHA256 hash of the file
"""
if isinstance(file, str) and file.lower().startswith("file://"):
file = unquote(urlparse(file).path)
@@ -363,8 +363,8 @@ def _hash_file(file: Union[str, Path], md5: Hash) -> Hash:
data = f.read(BUF_SIZE)
if not data:
break
- md5.update(data)
- return md5
+ sha256.update(data)
+ return sha256


def validate_step_args_input(
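For reference, a minimal standalone sketch of the updated helpers (an approximation, not the exact sagemaker module: the `BUF_SIZE` value, the `file://` URI handling, and error handling are simplified here) showing how the SHA-256 digests are built up for objects, files, and directories:

```python
import hashlib
from pathlib import Path
from typing import List

BUF_SIZE = 65536  # assumed chunk size; the real module defines its own constant


def hash_object(obj) -> str:
    """SHA-256 hex digest of an object's string representation."""
    return hashlib.sha256(str(obj).encode()).hexdigest()


def hash_file(path: str) -> str:
    """SHA-256 hex digest of a file's contents, read in chunks."""
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(BUF_SIZE):
            sha256.update(chunk)
    return sha256.hexdigest()


def hash_files_or_dirs(paths: List[str]) -> str:
    """Fold sorted paths, entry names, and file contents into one digest."""
    sha256 = hashlib.sha256()

    def _walk(path: Path) -> None:
        sha256.update(path.name.encode())
        if path.is_dir():
            for entry in sorted(path.iterdir()):
                _walk(entry)
        elif path.is_file():
            with open(path, "rb") as f:
                while chunk := f.read(BUF_SIZE):
                    sha256.update(chunk)

    for p in sorted(paths):
        _walk(Path(p))
    return sha256.hexdigest()


# SHA-256 hex digests are 64 characters long, vs. 32 for MD5, which is why the
# expected step and job names in the test diffs below grow accordingly.
print(len(hash_object({"a": 1})))  # 64
```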
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/workflow/test_steps.py
@@ -671,7 +671,7 @@ def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, sc
mock_normalize_args.return_value = [step.inputs, step.outputs]
step.to_request()
mock_normalize_args.assert_called_with(
job_name="MyProcessingStep-3e89f0c7e101c356cbedf27d9d27e9db",
job_name="MyProcessingStep-a22fc59b38f13da26f6a40b18687ba598cf669f74104b793cefd9c63eddf4ac7",
arguments=step.job_arguments,
inputs=step.inputs,
outputs=step.outputs,
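The only change to the expected `job_name` here is digest length: the hash suffix appended to the step name is now a 64-character SHA-256 hex digest instead of a 32-character MD5 digest. A quick self-contained check of just the lengths (the config string below is a placeholder, not the real serialized step):

```python
import hashlib

config = "placeholder step config"  # hypothetical input, not the real step config
assert len(hashlib.md5(config.encode()).hexdigest()) == 32    # old suffix length
assert len(hashlib.sha256(config.encode()).hexdigest()) == 64  # new suffix length
```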
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/workflow/test_utilities.py
@@ -31,14 +31,14 @@ def test_hash_file():
with tempfile.NamedTemporaryFile() as tmp:
tmp.write("hashme".encode())
hash = hash_file(tmp.name)
assert hash == "d41d8cd98f00b204e9800998ecf8427e"
assert hash == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"


def test_hash_file_uri():
with tempfile.NamedTemporaryFile() as tmp:
tmp.write("hashme".encode())
hash = hash_file(f"file:///{tmp.name}")
assert hash == "d41d8cd98f00b204e9800998ecf8427e"
assert hash == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"


def test_hash_files_or_dirs_with_file():
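Both the old and new expected values here are digests of empty input: the bytes written to the NamedTemporaryFile apparently never get flushed before hash_file reopens the path, so the hashed content is empty. A quick sanity check of both constants:

```python
import hashlib

# SHA-256 of empty input: the new expected value in these tests.
assert hashlib.sha256(b"").hexdigest() == (
    "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
)
# MD5 of empty input: the value the tests previously expected.
assert hashlib.md5(b"").hexdigest() == "d41d8cd98f00b204e9800998ecf8427e"
```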
3 changes: 2 additions & 1 deletion tests/unit/sagemaker/workflow/test_utils.py
@@ -82,7 +82,8 @@ def test_repack_model_step(estimator):
assert hyperparameters["sagemaker_program"] == f'"{REPACK_SCRIPT_LAUNCHER}"'
assert (
hyperparameters["sagemaker_submit_directory"]
== '"s3://my-bucket/MyRepackModelStep-b5ea77f701b47a8d075605497462ccc2/source/sourcedir.tar.gz"'
== '"s3://my-bucket/MyRepackModelStep-717d7bdd388168c27e9ad2938ff0314e35be50b3157cf2498688c7525ea27e1e\
/source/sourcedir.tar.gz"'
)

del request_dict["Arguments"]["HyperParameters"]
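The doubled digest length is presumably also why the new expected URI is split over two source lines: a trailing backslash inside a Python string literal continues it onto the next line without inserting a newline. A small sketch of that mechanic, using a placeholder prefix rather than the real expected value:

```python
# The backslash at the end of the first line continues the string literal;
# no newline is added, so the two pieces join into one URI.
expected = '"s3://my-bucket/ExamplePrefix-0123abcd\
/source/sourcedir.tar.gz"'
assert expected == '"s3://my-bucket/ExamplePrefix-0123abcd/source/sourcedir.tar.gz"'
assert "\n" not in expected
```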