Skip to content

fix: add fixes for tarfile extractall functionality PEP-721 #4441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 23, 2024
4 changes: 3 additions & 1 deletion src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import sagemaker.local.data
import sagemaker.local.utils
import sagemaker.utils
from sagemaker.utils import check_tarfile_data_filter_attribute

CONTAINER_PREFIX = "algo"
STUDIO_HOST_NAME = "sagemaker-local"
Expand Down Expand Up @@ -686,7 +687,8 @@ def _prepare_serving_volumes(self, model_location):
for filename in model_data_source.get_file_list():
if tarfile.is_tarfile(filename):
with tarfile.open(filename) as tar:
tar.extractall(path=model_data_source.get_root_dir())
check_tarfile_data_filter_attribute()
tar.extractall(path=model_data_source.get_root_dir(), filter="data")

volumes.append(_Volume(model_data_source.get_root_dir(), "/opt/ml/model"))

Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/serve/model_server/djl_serving/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import List
from pathlib import Path

from sagemaker.utils import _tmpdir
from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
from sagemaker.s3 import S3Downloader
from sagemaker.djl_inference import DJLModel
from sagemaker.djl_inference.model import _read_existing_serving_properties
Expand Down Expand Up @@ -53,7 +53,8 @@ def _extract_js_resource(js_model_dir: str, js_id: str):
"""Uncompress the jumpstart resource"""
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
with tarfile.open(str(tmp_sourcedir)) as resources:
resources.extractall(path=js_model_dir)
check_tarfile_data_filter_attribute()
resources.extractall(path=js_model_dir, filter="data")


def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path):
Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/serve/model_server/tgi/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pathlib import Path

from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
from sagemaker.utils import _tmpdir
from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
from sagemaker.s3 import S3Downloader

logger = logging.getLogger(__name__)
Expand All @@ -29,7 +29,8 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
"""Uncompress the jumpstart resource"""
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
with tarfile.open(str(tmp_sourcedir)) as resources:
resources.extractall(path=code_dir)
check_tarfile_data_filter_attribute()
resources.extractall(path=code_dir, filter="data")


def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:
Expand Down
29 changes: 27 additions & 2 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import random
import re
import shutil
import sys
import tarfile
import tempfile
import time
Expand Down Expand Up @@ -591,7 +592,8 @@ def _create_or_update_code_dir(
download_file_from_url(source_directory, local_code_path, sagemaker_session)

with tarfile.open(name=local_code_path, mode="r:gz") as t:
t.extractall(path=code_dir)
check_tarfile_data_filter_attribute()
t.extractall(path=code_dir, filter="data")

elif source_directory:
if os.path.exists(code_dir):
Expand Down Expand Up @@ -628,7 +630,8 @@ def _extract_model(model_uri, sagemaker_session, tmp):
else:
local_model_path = model_uri.replace("file://", "")
with tarfile.open(name=local_model_path, mode="r:gz") as t:
t.extractall(path=tmp_model_dir)
check_tarfile_data_filter_attribute()
t.extractall(path=tmp_model_dir, filter="data")
return tmp_model_dir


Expand Down Expand Up @@ -1489,3 +1492,25 @@ def format_tags(tags: Tags) -> List[TagsDict]:
return [{"Key": str(k), "Value": str(v)} for k, v in tags.items()]

return tags


class PythonVersionError(Exception):
"""Raise when a secure [/patched] version of Python is not used."""


def check_tarfile_data_filter_attribute():
"""Check if tarfile has data_filter utility.

Tarfile-data_filter utility has guardrails against untrusted de-serialisation.

Raises:
PythonVersionError: if `tarfile.data_filter` is not available.
"""
# The function and it's usages can be deprecated post support of python >= 3.12
if not hasattr(tarfile, "data_filter"):
raise PythonVersionError(
f"Since tarfile extraction is unsafe the operation is prohibited "
f"per PEP-721. Please update your Python [{sys.version}] "
f"to latest patch [refer to https://www.python.org/downloads/] "
f"to consume the security patch"
)
10 changes: 8 additions & 2 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
Step,
ConfigurableRetryStep,
)
from sagemaker.utils import _save_model, download_file_from_url, format_tags
from sagemaker.utils import (
_save_model,
download_file_from_url,
format_tags,
check_tarfile_data_filter_attribute,
)
from sagemaker.workflow.retry import RetryPolicy
from sagemaker.workflow.utilities import trim_request_dict

Expand Down Expand Up @@ -257,7 +262,8 @@ def _inject_repack_script_and_launcher(self):
download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session)

with tarfile.open(name=old_targz_path, mode="r:gz") as t:
t.extractall(path=targz_contents_dir)
check_tarfile_data_filter_attribute()
t.extractall(path=targz_contents_dir, filter="data")

shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT))
with open(
Expand Down
5 changes: 4 additions & 1 deletion tests/integ/s3_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import boto3
from six.moves.urllib.parse import urlparse

from sagemaker.utils import check_tarfile_data_filter_attribute


def assert_s3_files_exist(sagemaker_session, s3_url, files):
parsed_url = urlparse(s3_url)
Expand Down Expand Up @@ -55,4 +57,5 @@ def extract_files_from_s3(s3_url, tmpdir, sagemaker_session):
s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model)

with tarfile.open(model, "r") as tar_file:
tar_file.extractall(tmpdir)
check_tarfile_data_filter_attribute()
tar_file.extractall(tmpdir, filter="data")
Original file line number Diff line number Diff line change
Expand Up @@ -272,4 +272,4 @@ def test_extract_js_resources_success(self, mock_tarfile, mock_path):

mock_path.assert_called_once_with(js_model_dir)
mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz")
mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir)
mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir, filter="data")
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,4 @@ def test_extract_js_resources_success(self, mock_tarfile, mock_path):

mock_path.assert_called_once_with(js_model_dir)
mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz")
mock_resource_obj.extractall.assert_called_once_with(path=code_dir)
mock_resource_obj.extractall.assert_called_once_with(path=code_dir, filter="data")
5 changes: 3 additions & 2 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mock import Mock, patch

from sagemaker import fw_utils
from sagemaker.utils import name_from_image
from sagemaker.utils import name_from_image, check_tarfile_data_filter_attribute
from sagemaker.session_settings import SessionSettings
from sagemaker.instance_group import InstanceGroup

Expand Down Expand Up @@ -424,7 +424,8 @@ def list_tar_files(folder, tar_ball, tmpdir):
startpath = str(tmpdir.ensure(folder, dir=True))

with tarfile.open(name=tar_ball, mode="r:gz") as t:
t.extractall(path=startpath)
check_tarfile_data_filter_attribute()
t.extractall(path=startpath, filter="data")

def walk():
for root, dirs, files in os.walk(startpath):
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
resolve_nested_dict_value_from_config,
update_list_of_dicts_with_values_from_config,
volume_size_supported,
PythonVersionError,
check_tarfile_data_filter_attribute,
)
from tests.unit.sagemaker.workflow.helpers import CustomStep
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
Expand Down Expand Up @@ -1748,3 +1750,15 @@ def test_instance_family_from_full_instance_type(self):

for instance_type, family in instance_type_to_family_test_dict.items():
self.assertEqual(family, get_instance_type_family(instance_type))


class TestCheckTarfileDataFilterAttribute(TestCase):
def test_check_tarfile_data_filter_attribute_unhappy_case(self):
with pytest.raises(PythonVersionError):
with patch("tarfile.data_filter", None):
delattr(tarfile, "data_filter")
check_tarfile_data_filter_attribute()

def test_check_tarfile_data_filter_attribute_happy_case(self):
with patch("tarfile.data_filter", "some_value"):
check_tarfile_data_filter_attribute()