|
22 | 22 | from sagemaker.debugger.debugger import DebuggerHookConfig, RuleBase, TensorBoardOutputConfig
|
23 | 23 | from sagemaker.debugger.profiler_config import ProfilerConfig
|
24 | 24 |
|
25 |
| -from sagemaker.estimator import _TrainingJob, Estimator |
| 25 | +from sagemaker.estimator import Estimator |
26 | 26 | from sagemaker.explainer.explainer_config import ExplainerConfig
|
27 | 27 | from sagemaker.inputs import FileSystemInput, TrainingInput
|
28 | 28 | from sagemaker.instance_group import InstanceGroup
|
@@ -664,10 +664,10 @@ def attach(
|
664 | 664 | model_version: str = "*",
|
665 | 665 | sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
666 | 666 | model_channel_name: str = "model",
|
667 |
| - ): |
| 667 | + ) -> "JumpStartEstimator": |
668 | 668 | """Attach to an existing training job.
|
669 | 669 |
|
670 |
| - Create an Estimator bound to an existing training job. |
| 670 | + Create a JumpStartEstimator bound to an existing training job. |
671 | 671 | After attaching, if the training job has a Complete status,
|
672 | 672 | it can be ``deploy()`` ed to create a SageMaker Endpoint and return
|
673 | 673 | a ``Predictor``.
|
@@ -705,23 +705,12 @@ def attach(
|
705 | 705 | training job.
|
706 | 706 | """
|
707 | 707 |
|
708 |
| - job_details = sagemaker_session.sagemaker_client.describe_training_job( |
709 |
| - TrainingJobName=training_job_name |
710 |
| - ) |
711 |
| - init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name) |
712 |
| - tags = sagemaker_session.sagemaker_client.list_tags( |
713 |
| - ResourceArn=job_details["TrainingJobArn"] |
714 |
| - )["Tags"] |
715 |
| - init_params.update(tags=tags) |
716 |
| - init_params.update(model_id=model_id, model_version=model_version) |
717 |
| - |
718 |
| - estimator = cls(sagemaker_session=sagemaker_session, **init_params) |
719 |
| - estimator.latest_training_job = _TrainingJob( |
720 |
| - sagemaker_session=sagemaker_session, job_name=training_job_name |
| 708 | + return cls._attach( |
| 709 | + training_job_name=training_job_name, |
| 710 | + sagemaker_session=sagemaker_session, |
| 711 | + model_channel_name=model_channel_name, |
| 712 | + additional_kwargs={"model_id": model_id, "model_version": model_version}, |
721 | 713 | )
|
722 |
| - estimator._current_job_name = estimator.latest_training_job.name |
723 |
| - estimator.latest_training_job.wait(logs="None") |
724 |
| - return estimator |
725 | 714 |
|
726 | 715 | def deploy(
|
727 | 716 | self,
|
|
0 commit comments