Skip to content

Commit bc00b8f

Browse files
committed
chore: cleanup code
1 parent 5f1b773 commit bc00b8f

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

src/sagemaker/estimator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,6 +1429,19 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
14291429
Instance of the calling ``Estimator`` Class with the attached
14301430
training job.
14311431
"""
1432+
return cls._attach(
1433+
training_job_name=training_job_name,
1434+
sagemaker_session=sagemaker_session,
1435+
model_channel_name=model_channel_name,
1436+
)
1437+
1438+
def _attach(
1439+
cls,
1440+
training_job_name: str,
1441+
sagemaker_session: Optional[str] = None,
1442+
model_channel_name: str = "model",
1443+
additional_kwargs: Optional[Dict[str, Any]] = None,
1444+
) -> "EstimatorBase":
14321445
sagemaker_session = sagemaker_session or Session()
14331446

14341447
job_details = sagemaker_session.sagemaker_client.describe_training_job(
@@ -1440,6 +1453,9 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
14401453
)["Tags"]
14411454
init_params.update(tags=tags)
14421455

1456+
if additional_kwargs:
1457+
init_params.update(additional_kwargs)
1458+
14431459
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
14441460
estimator.latest_training_job = _TrainingJob(
14451461
sagemaker_session=sagemaker_session, job_name=training_job_name

src/sagemaker/jumpstart/estimator.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sagemaker.debugger.debugger import DebuggerHookConfig, RuleBase, TensorBoardOutputConfig
2323
from sagemaker.debugger.profiler_config import ProfilerConfig
2424

25-
from sagemaker.estimator import _TrainingJob, Estimator
25+
from sagemaker.estimator import Estimator
2626
from sagemaker.explainer.explainer_config import ExplainerConfig
2727
from sagemaker.inputs import FileSystemInput, TrainingInput
2828
from sagemaker.instance_group import InstanceGroup
@@ -664,10 +664,10 @@ def attach(
664664
model_version: str = "*",
665665
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
666666
model_channel_name: str = "model",
667-
):
667+
) -> "JumpStartEstimator":
668668
"""Attach to an existing training job.
669669
670-
Create an Estimator bound to an existing training job.
670+
Create a JumpStartEstimator bound to an existing training job.
671671
After attaching, if the training job has a Complete status,
672672
it can be ``deploy()`` ed to create a SageMaker Endpoint and return
673673
a ``Predictor``.
@@ -705,23 +705,12 @@ def attach(
705705
training job.
706706
"""
707707

708-
job_details = sagemaker_session.sagemaker_client.describe_training_job(
709-
TrainingJobName=training_job_name
710-
)
711-
init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name)
712-
tags = sagemaker_session.sagemaker_client.list_tags(
713-
ResourceArn=job_details["TrainingJobArn"]
714-
)["Tags"]
715-
init_params.update(tags=tags)
716-
init_params.update(model_id=model_id, model_version=model_version)
717-
718-
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
719-
estimator.latest_training_job = _TrainingJob(
720-
sagemaker_session=sagemaker_session, job_name=training_job_name
708+
return cls._attach(
709+
training_job_name=training_job_name,
710+
sagemaker_session=sagemaker_session,
711+
model_channel_name=model_channel_name,
712+
additional_kwargs={"model_id": model_id, "model_version": model_version},
721713
)
722-
estimator._current_job_name = estimator.latest_training_job.name
723-
estimator.latest_training_job.wait(logs="None")
724-
return estimator
725714

726715
def deploy(
727716
self,

0 commit comments

Comments
 (0)