Skip to content

breaking: rename session parameter to sagemaker_session in S3 utilities #1663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def _prepare_rules(self):
s3_uri = S3Uploader.upload(
local_path=rule.rule_parameters["source_s3_uri"],
desired_s3_uri=desired_s3_uri,
session=self.sagemaker_session,
sagemaker_session=self.sagemaker_session,
)
rule.rule_parameters["source_s3_uri"] = s3_uri
# Save the request dictionary for the rule.
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/model_monitor/model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ def _normalize_baseline_inputs(self, baseline_inputs=None):
S3Uploader.upload(
local_path=file_input.source,
desired_s3_uri=s3_uri,
session=self.sagemaker_session,
sagemaker_session=self.sagemaker_session,
)
file_input.source = s3_uri
normalized_inputs.append(file_input)
Expand Down Expand Up @@ -944,7 +944,7 @@ def _s3_uri_from_local_path(self, path):
str(uuid.uuid4()),
)
S3Uploader.upload(
local_path=path, desired_s3_uri=s3_uri, session=self.sagemaker_session
local_path=path, desired_s3_uri=s3_uri, sagemaker_session=self.sagemaker_session
)
path = os.path.join(s3_uri, os.path.basename(path))
return path
Expand Down Expand Up @@ -1771,7 +1771,7 @@ def _upload_and_convert_to_processing_input(self, source, destination, name):
name,
)
S3Uploader.upload(
local_path=source, desired_s3_uri=s3_uri, session=self.sagemaker_session
local_path=source, desired_s3_uri=s3_uri, sagemaker_session=self.sagemaker_session
)
source = s3_uri

Expand Down
18 changes: 11 additions & 7 deletions src/sagemaker/model_monitor/monitoring_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def save(self, new_save_location_s3_uri=None):
body=json.dumps(self.body_dict),
desired_s3_uri=self.file_s3_uri,
kms_key=self.kms_key,
session=self.session,
sagemaker_session=self.session,
)


Expand Down Expand Up @@ -119,7 +119,9 @@ def from_s3_uri(cls, statistics_file_s3_uri, kms_key=None, sagemaker_session=Non
"""
try:
body_dict = json.loads(
S3Downloader.read_file(s3_uri=statistics_file_s3_uri, session=sagemaker_session)
S3Downloader.read_file(
s3_uri=statistics_file_s3_uri, sagemaker_session=sagemaker_session
)
)
except ClientError as error:
print(
Expand Down Expand Up @@ -163,7 +165,7 @@ def from_string(
body=statistics_file_string,
desired_s3_uri=desired_s3_uri,
kms_key=kms_key,
session=sagemaker_session,
sagemaker_session=sagemaker_session,
)

return Statistics.from_s3_uri(
Expand Down Expand Up @@ -243,7 +245,9 @@ def from_s3_uri(cls, constraints_file_s3_uri, kms_key=None, sagemaker_session=No
"""
try:
body_dict = json.loads(
S3Downloader.read_file(s3_uri=constraints_file_s3_uri, session=sagemaker_session)
S3Downloader.read_file(
s3_uri=constraints_file_s3_uri, sagemaker_session=sagemaker_session
)
)
except ClientError as error:
print(
Expand Down Expand Up @@ -290,7 +294,7 @@ def from_string(
body=constraints_file_string,
desired_s3_uri=desired_s3_uri,
kms_key=kms_key,
session=sagemaker_session,
sagemaker_session=sagemaker_session,
)

return Constraints.from_s3_uri(
Expand Down Expand Up @@ -396,7 +400,7 @@ def from_s3_uri(cls, constraint_violations_file_s3_uri, kms_key=None, sagemaker_
try:
body_dict = json.loads(
S3Downloader.read_file(
s3_uri=constraint_violations_file_s3_uri, session=sagemaker_session
s3_uri=constraint_violations_file_s3_uri, sagemaker_session=sagemaker_session
)
)
except ClientError as error:
Expand Down Expand Up @@ -445,7 +449,7 @@ def from_string(
body=constraint_violations_file_string,
desired_s3_uri=desired_s3_uri,
kms_key=kms_key,
session=sagemaker_session,
sagemaker_session=sagemaker_session,
)

return ConstraintViolations.from_s3_uri(
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _normalize_inputs(self, inputs=None):
s3_uri = S3Uploader.upload(
local_path=file_input.source,
desired_s3_uri=desired_s3_uri,
session=self.sagemaker_session,
sagemaker_session=self.sagemaker_session,
)
file_input.source = s3_uri
normalized_inputs.append(file_input)
Expand Down Expand Up @@ -480,7 +480,7 @@ def _upload_code(self, code):
self._CODE_CONTAINER_INPUT_NAME,
)
return S3Uploader.upload(
local_path=code, desired_s3_uri=desired_s3_uri, session=self.sagemaker_session
local_path=code, desired_s3_uri=desired_s3_uri, sagemaker_session=self.sagemaker_session
)

def _convert_code_and_add_to_inputs(self, inputs, s3_uri):
Expand Down
73 changes: 24 additions & 49 deletions src/sagemaker/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,6 @@

logger = logging.getLogger("sagemaker")

SESSION_V2_RENAME_MESSAGE = (
"Parameter 'session' will be renamed to 'sagemaker_session' in SageMaker Python SDK v2."
)


def _session_v2_rename_warning(session):
"""
Args:
session (sagemaker.session.Session):
"""
if session is not None:
logger.warning(SESSION_V2_RENAME_MESSAGE)


def parse_s3_url(url):
"""Returns an (s3 bucket, key name/prefix) tuple from a url with an s3
Expand All @@ -53,15 +40,15 @@ class S3Uploader(object):
"""Contains static methods for uploading directories or files to S3."""

@staticmethod
def upload(local_path, desired_s3_uri, kms_key=None, session=None):
def upload(local_path, desired_s3_uri, kms_key=None, sagemaker_session=None):
"""Static method that uploads a given file or directory to S3.

Args:
local_path (str): Path (absolute or relative) of local file or directory to upload.
desired_s3_uri (str): The desired S3 location to upload to. It is the prefix to
which the local filename will be added.
kms_key (str): The KMS key to use to encrypt the files.
session (sagemaker.session.Session): Session object which
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
Expand All @@ -70,10 +57,7 @@ def upload(local_path, desired_s3_uri, kms_key=None, session=None):
The S3 uri of the uploaded file(s).

"""
if session is not None:
_session_v2_rename_warning(session)

sagemaker_session = session or Session()
sagemaker_session = sagemaker_session or Session()
bucket, key_prefix = parse_s3_url(url=desired_s3_uri)
if kms_key is not None:
extra_args = {"SSEKMSKeyId": kms_key}
Expand All @@ -85,24 +69,23 @@ def upload(local_path, desired_s3_uri, kms_key=None, session=None):
)

@staticmethod
def upload_string_as_file_body(body, desired_s3_uri=None, kms_key=None, session=None):
def upload_string_as_file_body(body, desired_s3_uri=None, kms_key=None, sagemaker_session=None):
"""Static method that uploads a given file or directory to S3.

Args:
body (str): String representing the body of the file.
desired_s3_uri (str): The desired S3 uri to upload to.
kms_key (str): The KMS key to use to encrypt the files.
session (sagemaker.session.Session): AWS session to use. Automatically
generates one if not provided.
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.

Returns:
str: The S3 uri of the uploaded file(s).

"""
if session is not None:
_session_v2_rename_warning(session)

sagemaker_session = session or Session()
sagemaker_session = sagemaker_session or Session()
bucket, key = parse_s3_url(desired_s3_uri)

sagemaker_session.upload_string_as_file_body(
Expand All @@ -116,23 +99,19 @@ class S3Downloader(object):
"""Contains static methods for downloading directories or files from S3."""

@staticmethod
def download(s3_uri, local_path, kms_key=None, session=None):
def download(s3_uri, local_path, kms_key=None, sagemaker_session=None):
"""Static method that downloads a given S3 uri to the local machine.

Args:
s3_uri (str): An S3 uri to download from.
local_path (str): A local path to download the file(s) to.
kms_key (str): The KMS key to use to decrypt the files.
session (sagemaker.session.Session): Session object which
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.

"""
if session is not None:
_session_v2_rename_warning(session)

sagemaker_session = session or Session()
sagemaker_session = sagemaker_session or Session()
bucket, key_prefix = parse_s3_url(url=s3_uri)
if kms_key is not None:
extra_args = {"SSECustomerKey": kms_key}
Expand All @@ -144,43 +123,39 @@ def download(s3_uri, local_path, kms_key=None, session=None):
)

@staticmethod
def read_file(s3_uri, session=None):
def read_file(s3_uri, sagemaker_session=None):
"""Static method that returns the contents of an s3 uri file body as a string.

Args:
s3_uri (str): An S3 uri that refers to a single file.
session (sagemaker.session.Session): AWS session to use. Automatically
generates one if not provided.
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.

Returns:
str: The body of the file.

"""
if session is not None:
_session_v2_rename_warning(session)

sagemaker_session = session or Session()
sagemaker_session = sagemaker_session or Session()
bucket, key_prefix = parse_s3_url(url=s3_uri)

return sagemaker_session.read_s3_file(bucket=bucket, key_prefix=key_prefix)

@staticmethod
def list(s3_uri, session=None):
def list(s3_uri, sagemaker_session=None):
"""Static method that lists the contents of an S3 uri.

Args:
s3_uri (str): The S3 base uri to list objects in.
session (sagemaker.session.Session): AWS session to use. Automatically
generates one if not provided.
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.

Returns:
[str]: The list of S3 URIs in the given S3 base uri.

"""
if session is not None:
_session_v2_rename_warning(session)

sagemaker_session = session or Session()
sagemaker_session = sagemaker_session or Session()
bucket, key_prefix = parse_s3_url(url=s3_uri)

file_keys = sagemaker_session.list_s3_files(bucket=bucket, key_prefix=key_prefix)
Expand Down
12 changes: 6 additions & 6 deletions tests/integ/test_model_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,7 @@ def test_default_monitor_monitoring_execution_interactions(
desired_s3_uri = os.path.join(executions[-1].output.destination, file_name)

S3Uploader.upload_string_as_file_body(
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
body=file_body, desired_s3_uri=desired_s3_uri, sagemaker_session=sagemaker_session
)

statistics = my_attached_monitor.latest_monitoring_statistics()
Expand All @@ -1522,7 +1522,7 @@ def test_default_monitor_monitoring_execution_interactions(
desired_s3_uri = os.path.join(executions[-1].output.destination, file_name)

S3Uploader.upload_string_as_file_body(
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
body=file_body, desired_s3_uri=desired_s3_uri, sagemaker_session=sagemaker_session
)

constraint_violations = my_attached_monitor.latest_monitoring_constraint_violations()
Expand Down Expand Up @@ -2473,7 +2473,7 @@ def test_byoc_monitor_monitoring_execution_interactions(
desired_s3_uri = os.path.join(executions[-1].output.destination, file_name)

S3Uploader.upload_string_as_file_body(
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
body=file_body, desired_s3_uri=desired_s3_uri, sagemaker_session=sagemaker_session
)

statistics = my_attached_monitor.latest_monitoring_statistics()
Expand All @@ -2486,7 +2486,7 @@ def test_byoc_monitor_monitoring_execution_interactions(
desired_s3_uri = os.path.join(executions[-1].output.destination, file_name)

S3Uploader.upload_string_as_file_body(
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
body=file_body, desired_s3_uri=desired_s3_uri, sagemaker_session=sagemaker_session
)

constraint_violations = my_attached_monitor.latest_monitoring_constraint_violations()
Expand Down Expand Up @@ -2557,10 +2557,10 @@ def _upload_captured_data_to_endpoint(sagemaker_session, predictor):
S3Uploader.upload(
local_path=os.path.join(DATA_DIR, "monitor/captured-data.jsonl"),
desired_s3_uri=s3_uri_previous_hour,
session=sagemaker_session,
sagemaker_session=sagemaker_session,
)
S3Uploader.upload(
local_path=os.path.join(DATA_DIR, "monitor/captured-data.jsonl"),
desired_s3_uri=s3_uri_current_hour,
session=sagemaker_session,
sagemaker_session=sagemaker_session,
)
12 changes: 6 additions & 6 deletions tests/integ/test_monitoring_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_statistics_object_creation_from_s3_uri_with_customizations(
body=file_body,
desired_s3_uri=desired_s3_uri,
kms_key=monitoring_files_kms_key,
session=sagemaker_session,
sagemaker_session=sagemaker_session,
)

statistics = Statistics.from_s3_uri(
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_statistics_object_creation_from_s3_uri_without_customizations(sagemaker
)

s3_uri = S3Uploader.upload_string_as_file_body(
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
body=file_body, desired_s3_uri=desired_s3_uri, sagemaker_session=sagemaker_session
)

statistics = Statistics.from_s3_uri(
Expand Down Expand Up @@ -259,7 +259,7 @@ def test_constraints_object_creation_from_s3_uri_with_customizations(
body=file_body,
desired_s3_uri=desired_s3_uri,
kms_key=monitoring_files_kms_key,
session=sagemaker_session,
sagemaker_session=sagemaker_session,
)

constraints = Constraints.from_s3_uri(
Expand Down Expand Up @@ -288,7 +288,7 @@ def test_constraints_object_creation_from_s3_uri_without_customizations(sagemake
)

s3_uri = S3Uploader.upload_string_as_file_body(
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
body=file_body, desired_s3_uri=desired_s3_uri, sagemaker_session=sagemaker_session
)

constraints = Constraints.from_s3_uri(
Expand Down Expand Up @@ -388,7 +388,7 @@ def test_constraint_violations_object_creation_from_s3_uri_with_customizations(
body=file_body,
desired_s3_uri=desired_s3_uri,
kms_key=monitoring_files_kms_key,
session=sagemaker_session,
sagemaker_session=sagemaker_session,
)

constraint_violations = ConstraintViolations.from_s3_uri(
Expand Down Expand Up @@ -419,7 +419,7 @@ def test_constraint_violations_object_creation_from_s3_uri_without_customization
)

s3_uri = S3Uploader.upload_string_as_file_body(
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
body=file_body, desired_s3_uri=desired_s3_uri, sagemaker_session=sagemaker_session
)

constraint_violations = ConstraintViolations.from_s3_uri(
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_multi_variant_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def multi_variant_endpoint(sagemaker_session):
prefix = "sagemaker/DEMO-VariantTargeting"
model_url = S3Uploader.upload(
local_path=XG_BOOST_MODEL_LOCAL_PATH,
desired_s3_uri="s3://" + bucket + "/" + prefix,
session=sagemaker_session,
desired_s3_uri="s3://{}/{}".format(bucket, prefix),
sagemaker_session=sagemaker_session,
)

image_uri = get_image_uri(sagemaker_session.boto_session.region_name, "xgboost", "0.90-1")
Expand Down
Loading