Skip to content

change: Enable load_run without name args in Transform env #3585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 12 additions & 20 deletions src/sagemaker/experiments/_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
import logging
import os

from sagemaker import Session
from sagemaker.experiments import trial_component
from sagemaker.utils import retry_with_backoff

TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN"
PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json"
TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH"
TRANSFORM_JOB_ARN_ENV = "TRANSFORM_JOB_ARN"
MAX_RETRY_ATTEMPTS = 7

logger = logging.getLogger(__name__)
Expand All @@ -40,7 +41,7 @@ class _EnvironmentType(enum.Enum):
class _RunEnvironment(object):
"""Retrieves job specific data from the environment."""

def __init__(self, environment_type, source_arn):
def __init__(self, environment_type: _EnvironmentType, source_arn: str):
"""Init for _RunEnvironment.
Args:
Expand All @@ -53,9 +54,9 @@ def __init__(self, environment_type, source_arn):
@classmethod
def load(
cls,
training_job_arn_env=TRAINING_JOB_ARN_ENV,
processing_job_config_path=PROCESSING_JOB_CONFIG_PATH,
transform_job_batch_var=TRANSFORM_JOB_ENV_BATCH_VAR,
training_job_arn_env: str = TRAINING_JOB_ARN_ENV,
processing_job_config_path: str = PROCESSING_JOB_CONFIG_PATH,
transform_job_arn_env: str = TRANSFORM_JOB_ARN_ENV,
):
"""Loads source arn of current job from environment.
Expand All @@ -64,8 +65,8 @@ def load(
(default: `TRAINING_JOB_ARN`).
processing_job_config_path (str): The processing job config path
(default: `/opt/ml/config/processingjobconfig.json`).
transform_job_batch_var (str): The environment variable indicating if
it is a transform job (default: `SAGEMAKER_BATCH`).
transform_job_arn_env (str): The environment key for transform job ARN
(default: `TRANSFORM_JOB_ARN`).
Returns:
_RunEnvironment: Job data loaded from the environment. None if config does not exist.
Expand All @@ -78,16 +79,15 @@ def load(
environment_type = _EnvironmentType.SageMakerProcessingJob
source_arn = json.loads(open(processing_job_config_path).read())["ProcessingJobArn"]
return _RunEnvironment(environment_type, source_arn)
if transform_job_batch_var in os.environ and os.environ[transform_job_batch_var] == "true":
if transform_job_arn_env in os.environ:
environment_type = _EnvironmentType.SageMakerTransformJob
# TODO: need to figure out how to get source_arn from job env
# with Transform team's help.
source_arn = ""
# TODO: need to update to get source_arn from config file once Transform side ready
source_arn = os.environ.get(transform_job_arn_env)
return _RunEnvironment(environment_type, source_arn)

return None

def get_trial_component(self, sagemaker_session):
def get_trial_component(self, sagemaker_session: Session):
"""Retrieves the trial component from the job in the environment.
Args:
Expand All @@ -99,14 +99,6 @@ def get_trial_component(self, sagemaker_session):
Returns:
_TrialComponent: The trial component created from the job. None if not found.
"""
# TODO: Remove this condition check once we have a way to retrieve source ARN
# from transform job env
if self.environment_type == _EnvironmentType.SageMakerTransformJob:
logger.error(
"Currently getting the job trial component from the transform job environment "
"is not supported. Returning None."
)
return None

def _get_trial_component():
summaries = list(
Expand Down
80 changes: 0 additions & 80 deletions src/sagemaker/experiments/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from __future__ import absolute_import

import datetime
import json
import logging
import os
import time
Expand All @@ -35,85 +34,6 @@
logger = logging.getLogger(__name__)


# TODO: remove this _SageMakerFileMetricsWriter class
# when _MetricsManager is fully ready
class _SageMakerFileMetricsWriter(object):
"""Write metric data to file."""

def __init__(self, metrics_file_path=None):
"""Construct a `_SageMakerFileMetricsWriter` object"""
self._metrics_file_path = metrics_file_path
self._file = None
self._closed = False

def log_metric(self, metric_name, value, timestamp=None, step=None):
"""Write a metric to file.
Args:
metric_name (str): The name of the metric.
value (float): The value of the metric.
timestamp (datetime.datetime): Timestamp of the metric.
If not specified, the current UTC time will be used.
step (int): Iteration number of the metric (default: None).
Raises:
SageMakerMetricsWriterException: If the metrics file is closed.
AttributeError: If file has been initialized and the writer hasn't been closed.
"""
raw_metric_data = _RawMetricData(
metric_name=metric_name, value=value, timestamp=timestamp, step=step
)
try:
logger.debug("Writing metric: %s", raw_metric_data)
self._file.write(json.dumps(raw_metric_data.to_record()))
self._file.write("\n")
except AttributeError as attr_err:
if self._closed:
raise SageMakerMetricsWriterException("log_metric called on a closed writer")
if not self._file:
self._file = open(self._get_metrics_file_path(), "a", buffering=1)
self._file.write(json.dumps(raw_metric_data.to_record()))
self._file.write("\n")
else:
raise attr_err

def close(self):
"""Closes the metric file."""
if not self._closed and self._file:
self._file.close()
self._file = None # invalidate reference, causing subsequent log_metric to fail.
self._closed = True

def __enter__(self):
"""Return self"""
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Execute self.close()"""
self.close()

def __del__(self):
"""Execute self.close()"""
self.close()

def _get_metrics_file_path(self):
"""Get file path to store metrics"""
pid_filename = "{}.json".format(str(os.getpid()))
metrics_file_path = self._metrics_file_path or os.path.join(METRICS_DIR, pid_filename)
logger.debug("metrics_file_path = %s", metrics_file_path)
return metrics_file_path


class SageMakerMetricsWriterException(Exception):
    """Raised when a metric cannot be written (e.g. the writer is already closed)."""

    def __init__(self, message, errors=None):
        """Initialize with a human-readable message and optional error details."""
        super().__init__(message)
        # Only attach the attribute when details were actually provided,
        # so callers can probe with hasattr().
        if errors:
            self.errors = errors


class _RawMetricData(object):
"""A Raw Metric Data Object"""

Expand Down
8 changes: 3 additions & 5 deletions src/sagemaker/experiments/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,9 @@ def get_tc_and_exp_config_from_job_env(
num_attempts=4,
)
else: # environment.environment_type == _EnvironmentType.SageMakerTransformJob
raise RuntimeError(
"Failed to load the Run as loading experiment config "
"from transform job environment is not currently supported. "
"As a workaround, please explicitly pass in "
"the experiment_name and run_name in load_run."
job_response = retry_with_backoff(
callable_func=lambda: sagemaker_session.describe_transform_job(job_name),
num_attempts=4,
)

job_exp_config = job_response.get("ExperimentConfig", dict())
Expand Down
11 changes: 5 additions & 6 deletions src/sagemaker/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,18 @@ def __init__(
estimator.fit(job_name="my-job") # Create a training job

In order to reuse an existing run to log extra data, ``load_run`` is recommended.
For example, it is recommended to use ``load_run`` instead of the ``Run`` constructor
in a job script to load the existing run created before the job launch.
Otherwise, a new run may be created each time you launch a job.

The code snippet below displays how to load the run initialized above
in a custom training job script, where no ``run_name`` or ``experiment_name``
is presented as they are automatically retrieved from the experiment config
in the job environment.

Note:
It is recommended to use ``load_run`` instead of the ``Run`` constructor
in a job script to load the existing run created before the job launch.
Otherwise, a new run may be created each time you launch a job.

.. code:: python

with load_run() as run:
with load_run(sagemaker_session=sagemaker_session) as run:
run.log_metric(...)
...

Expand Down
3 changes: 3 additions & 0 deletions tests/data/experiment/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def model_fn(model_dir):
run.log_parameters({"p3": 3.0, "p4": 4.0})
run.log_metric("test-job-load-log-metric", 0.1)

with load_run(sagemaker_session=sagemaker_session) as run:
run.log_parameters({"p5": 5.0, "p6": 6})

model_file = "xgboost-model"
booster = pkl.load(open(os.path.join(model_dir, model_file), "rb"))
return booster
Expand Down
3 changes: 1 addition & 2 deletions tests/integ/sagemaker/experiments/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def verify_metrics():
sagemaker_session=sagemaker_session,
)
metrics = updated_tc.metrics
# TODO: revert to len(metrics) == 2 once backend fix reaches prod
assert len(metrics) > 0
assert len(metrics) == 2
assert list(filter(lambda x: x.metric_name == "test-x-step", metrics))
assert list(filter(lambda x: x.metric_name == "test-x-timestamp", metrics))

Expand Down
17 changes: 7 additions & 10 deletions tests/integ/sagemaker/experiments/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,9 @@ def test_run_from_processing_job_and_override_default_exp_config(
def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, xgboost_latest_version):
# Notes:
# 1. The 1st Run (run) created locally
# 2. In the inference script running in a transform job, load the 1st Run
# via explicitly passing the experiment_name and run_name of the 1st Run
# TODO: once we're able to retrieve exp config from the transform job env,
# we should expand this test and add the load_run() without explicitly supplying the names
# 2. In the inference script running in a transform job, load the 1st Run twice and log data
# 1) via explicitly passing the experiment_name and run_name of the 1st Run
# 2) via load_run() without explicitly supplying the names
# 3. All data are logged in the Run either locally or in the transform job
exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
xgb_model_data_s3 = sagemaker_session.upload_data(
Expand Down Expand Up @@ -494,6 +493,7 @@ def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, xgboost_latest_v
content_type="text/libsvm",
split_type="Line",
wait=True,
logs=False,
job_name=f"transform-job-{name()}",
)

Expand All @@ -506,7 +506,7 @@ def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, xgboost_latest_v
experiment_name=run.experiment_name, run_name=run.run_name
)
_check_run_from_job_result(
tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False
tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False, has_extra_load=True
)


Expand Down Expand Up @@ -636,8 +636,7 @@ def _check_run_from_local_end_result(sagemaker_session, tc, is_complete_log=True
assert "s3://Input" == tc.input_artifacts[artifact_name].value
assert not tc.input_artifacts[artifact_name].media_type

# TODO: revert to len(tc.metrics) == 1 once backend fix reaches prod
assert len(tc.metrics) > 0
assert len(tc.metrics) == 1
metric_summary = tc.metrics[0]
assert metric_summary.metric_name == metric_name
assert metric_summary.max == 9.0
Expand All @@ -651,9 +650,7 @@ def validate_tc_updated_in_init():
assert tc.status.primary_status == _TrialComponentStatusType.Completed.value
assert tc.parameters["p1"] == 1.0
assert tc.parameters["p2"] == 2.0
# TODO: revert to assert len(tc.metrics) == 5 once
# backend fix hits prod
assert len(tc.metrics) > 0
assert len(tc.metrics) == 5
for metric_summary in tc.metrics:
# metrics deletion is not supported at this point
# so its count would accumulate
Expand Down
25 changes: 10 additions & 15 deletions tests/unit/sagemaker/experiments/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

from sagemaker.experiments import _environment
from sagemaker.experiments._environment import TRANSFORM_JOB_ARN_ENV, TRAINING_JOB_ARN_ENV
from sagemaker.utils import retry_with_backoff


Expand All @@ -33,22 +34,22 @@ def tempdir():

@pytest.fixture
def training_job_env():
old_value = os.environ.get("TRAINING_JOB_ARN")
os.environ["TRAINING_JOB_ARN"] = "arn:1234aBcDe"
old_value = os.environ.get(TRAINING_JOB_ARN_ENV)
os.environ[TRAINING_JOB_ARN_ENV] = "arn:1234aBcDe"
yield os.environ
del os.environ["TRAINING_JOB_ARN"]
del os.environ[TRAINING_JOB_ARN_ENV]
if old_value:
os.environ["TRAINING_JOB_ARN"] = old_value
os.environ[TRAINING_JOB_ARN_ENV] = old_value


@pytest.fixture
def transform_job_env():
old_value = os.environ.get("SAGEMAKER_BATCH")
os.environ["SAGEMAKER_BATCH"] = "true"
old_value = os.environ.get(TRANSFORM_JOB_ARN_ENV)
os.environ[TRANSFORM_JOB_ARN_ENV] = "arn:1234aBcDe"
yield os.environ
del os.environ["SAGEMAKER_BATCH"]
del os.environ[TRANSFORM_JOB_ARN_ENV]
if old_value:
os.environ["SAGEMAKER_BATCH"] = old_value
os.environ[TRANSFORM_JOB_ARN_ENV] = old_value


def test_processing_job_environment(tempdir):
Expand All @@ -70,8 +71,7 @@ def test_training_job_environment(training_job_env):
def test_transform_job_environment(transform_job_env):
environment = _environment._RunEnvironment.load()
assert _environment._EnvironmentType.SageMakerTransformJob == environment.environment_type
# TODO: update if we figure out how to get source_arn from the transform job
assert not environment.source_arn
assert "arn:1234aBcDe" == environment.source_arn


def test_no_environment():
Expand Down Expand Up @@ -100,8 +100,3 @@ def test_resolve_trial_component_fails(mock_retry, sagemaker_session, training_j
client.list_trial_components.side_effect = Exception("Failed test")
environment = _environment._RunEnvironment.load()
assert environment.get_trial_component(sagemaker_session) is None


def test_resolve_transform_job_trial_component_fail(transform_job_env, sagemaker_session):
environment = _environment._RunEnvironment.load()
assert environment.get_trial_component(sagemaker_session) is None
Loading