@@ -1439,6 +1439,24 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
1439
1439
Instance of the calling ``Estimator`` Class with the attached
1440
1440
training job.
1441
1441
"""
1442
+ return cls ._attach (
1443
+ training_job_name = training_job_name ,
1444
+ sagemaker_session = sagemaker_session ,
1445
+ model_channel_name = model_channel_name ,
1446
+ )
1447
+
1448
+ @classmethod
1449
+ def _attach (
1450
+ cls ,
1451
+ training_job_name : str ,
1452
+ sagemaker_session : Optional [str ] = None ,
1453
+ model_channel_name : str = "model" ,
1454
+ additional_kwargs : Optional [Dict [str , Any ]] = None ,
1455
+ ) -> "EstimatorBase" :
1456
+ """Creates an Estimator bound to an existing training job.
1457
+
1458
+ Additional kwargs are allowed for instantiating Estimator.
1459
+ """
1442
1460
sagemaker_session = sagemaker_session or Session ()
1443
1461
1444
1462
job_details = sagemaker_session .sagemaker_client .describe_training_job (
@@ -1450,6 +1468,9 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
1450
1468
)["Tags" ]
1451
1469
init_params .update (tags = tags )
1452
1470
1471
+ if additional_kwargs :
1472
+ init_params .update (additional_kwargs )
1473
+
1453
1474
estimator = cls (sagemaker_session = sagemaker_session , ** init_params )
1454
1475
estimator .latest_training_job = _TrainingJob (
1455
1476
sagemaker_session = sagemaker_session , job_name = training_job_name
@@ -1751,21 +1772,41 @@ def register(
1751
1772
1752
1773
@property
1753
1774
def model_data (self ):
1754
- """str : The model location in S3. Only set if Estimator has been ``fit()``."""
1775
+ """Str or dict : The model location in S3. Only set if Estimator has been ``fit()``."""
1755
1776
if self .latest_training_job is not None and not isinstance (
1756
1777
self .sagemaker_session , PipelineSession
1757
1778
):
1758
- model_uri = self .sagemaker_session .sagemaker_client .describe_training_job (
1779
+ job_details = self .sagemaker_session .sagemaker_client .describe_training_job (
1759
1780
TrainingJobName = self .latest_training_job .name
1760
- )["ModelArtifacts" ]["S3ModelArtifacts" ]
1761
- else :
1762
- logger .warning (
1763
- "No finished training job found associated with this estimator. Please make sure "
1764
- "this estimator is only used for building workflow config"
1765
1781
)
1766
- model_uri = os .path .join (
1767
- self .output_path , self ._current_job_name , "output" , "model.tar.gz"
1782
+ model_uri = job_details ["ModelArtifacts" ]["S3ModelArtifacts" ]
1783
+ compression_type = job_details .get ("OutputDataConfig" , {}).get (
1784
+ "CompressionType" , "GZIP"
1768
1785
)
1786
+ if compression_type == "GZIP" :
1787
+ return model_uri
1788
+ # fail fast if we don't recognize training output compression type
1789
+ if compression_type not in {"GZIP" , "NONE" }:
1790
+ raise ValueError (
1791
+ f'Unrecognized training job output data compression type "{ compression_type } "'
1792
+ )
1793
+ # model data is in uncompressed form NOTE SageMaker Hosting mandates presence of
1794
+ # trailing forward slash in S3 model data URI, so append one if necessary.
1795
+ if not model_uri .endswith ("/" ):
1796
+ model_uri += "/"
1797
+ return {
1798
+ "S3DataSource" : {
1799
+ "S3Uri" : model_uri ,
1800
+ "S3DataType" : "S3Prefix" ,
1801
+ "CompressionType" : "None" ,
1802
+ }
1803
+ }
1804
+
1805
+ logger .warning (
1806
+ "No finished training job found associated with this estimator. Please make sure "
1807
+ "this estimator is only used for building workflow config"
1808
+ )
1809
+ model_uri = os .path .join (self .output_path , self ._current_job_name , "output" , "model.tar.gz" )
1769
1810
return model_uri
1770
1811
1771
1812
@abstractmethod
0 commit comments