Skip to content

Commit 235fc61

Browse files
committed
Add detail profiler V2 options and tests
1 parent 76b0938 commit 235fc61

File tree

1 file changed

+50
-9
lines changed

1 file changed

+50
-9
lines changed

src/sagemaker/estimator.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,6 +1439,24 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
14391439
Instance of the calling ``Estimator`` Class with the attached
14401440
training job.
14411441
"""
1442+
return cls._attach(
1443+
training_job_name=training_job_name,
1444+
sagemaker_session=sagemaker_session,
1445+
model_channel_name=model_channel_name,
1446+
)
1447+
1448+
@classmethod
1449+
def _attach(
1450+
cls,
1451+
training_job_name: str,
1452+
sagemaker_session: Optional[str] = None,
1453+
model_channel_name: str = "model",
1454+
additional_kwargs: Optional[Dict[str, Any]] = None,
1455+
) -> "EstimatorBase":
1456+
"""Creates an Estimator bound to an existing training job.
1457+
1458+
Additional kwargs are allowed for instantiating Estimator.
1459+
"""
14421460
sagemaker_session = sagemaker_session or Session()
14431461

14441462
job_details = sagemaker_session.sagemaker_client.describe_training_job(
@@ -1450,6 +1468,9 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
14501468
)["Tags"]
14511469
init_params.update(tags=tags)
14521470

1471+
if additional_kwargs:
1472+
init_params.update(additional_kwargs)
1473+
14531474
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
14541475
estimator.latest_training_job = _TrainingJob(
14551476
sagemaker_session=sagemaker_session, job_name=training_job_name
@@ -1751,21 +1772,41 @@ def register(
17511772

17521773
@property
17531774
def model_data(self):
1754-
"""str: The model location in S3. Only set if Estimator has been ``fit()``."""
1775+
"""str or dict: The model location in S3 (a URI string for GZIP output, an ``S3DataSource`` dict for uncompressed output). Only set if Estimator has been ``fit()``."""
17551776
if self.latest_training_job is not None and not isinstance(
17561777
self.sagemaker_session, PipelineSession
17571778
):
1758-
model_uri = self.sagemaker_session.sagemaker_client.describe_training_job(
1779+
job_details = self.sagemaker_session.sagemaker_client.describe_training_job(
17591780
TrainingJobName=self.latest_training_job.name
1760-
)["ModelArtifacts"]["S3ModelArtifacts"]
1761-
else:
1762-
logger.warning(
1763-
"No finished training job found associated with this estimator. Please make sure "
1764-
"this estimator is only used for building workflow config"
17651781
)
1766-
model_uri = os.path.join(
1767-
self.output_path, self._current_job_name, "output", "model.tar.gz"
1782+
model_uri = job_details["ModelArtifacts"]["S3ModelArtifacts"]
1783+
compression_type = job_details.get("OutputDataConfig", {}).get(
1784+
"CompressionType", "GZIP"
17681785
)
1786+
if compression_type == "GZIP":
1787+
return model_uri
1788+
# fail fast if we don't recognize training output compression type
1789+
if compression_type not in {"GZIP", "NONE"}:
1790+
raise ValueError(
1791+
f'Unrecognized training job output data compression type "{compression_type}"'
1792+
)
1793+
# Model data is in uncompressed form. NOTE: SageMaker Hosting mandates presence of a
1794+
# trailing forward slash in S3 model data URI, so append one if necessary.
1795+
if not model_uri.endswith("/"):
1796+
model_uri += "/"
1797+
return {
1798+
"S3DataSource": {
1799+
"S3Uri": model_uri,
1800+
"S3DataType": "S3Prefix",
1801+
"CompressionType": "None",
1802+
}
1803+
}
1804+
1805+
logger.warning(
1806+
"No finished training job found associated with this estimator. Please make sure "
1807+
"this estimator is only used for building workflow config"
1808+
)
1809+
model_uri = os.path.join(self.output_path, self._current_job_name, "output", "model.tar.gz")
17691810
return model_uri
17701811

17711812
@abstractmethod

0 commit comments

Comments
 (0)