16
16
import os
17
17
import re
18
18
import shutil
19
+ import sys
19
20
from typing import Dict , List , Tuple
20
21
21
22
from sagemaker .config import config_schema
@@ -120,8 +121,12 @@ def __init__(
120
121
environment_variables , config_schema .ENVIRONMENT_VARIABLES
121
122
)
122
123
123
- # TODO: provide default image uri if not set
124
- self .image_uri = self ._get_from_config (image_uri , config_schema .IMAGE_URI , required = True )
124
+ _image_uri = self ._get_from_config (image_uri , config_schema .IMAGE_URI )
125
+ if _image_uri :
126
+ self .image_uri = _image_uri
127
+ else :
128
+ self .image_uri = self ._get_default_image (self .sagemaker_session )
129
+
125
130
self .dependencies = self ._get_from_config (dependencies , config_schema .DEPENDENCIES )
126
131
127
132
self .instance_type = self ._get_from_config (
@@ -171,7 +176,7 @@ def _get_from_config(
171
176
default = None ,
172
177
required = False ,
173
178
):
174
- """Get default value from sagemaker config."""
179
+ """Get value from sagemaker config."""
175
180
if override_value :
176
181
return override_value
177
182
config_value = self .sagemaker_config .get_config_value (
@@ -189,6 +194,33 @@ def _get_from_config(
189
194
raise ValueError (f"{ sagemaker_config_key } is a required parameter!" )
190
195
return default
191
196
197
+ @staticmethod
198
+ def _get_default_image (session ):
199
+ """Return Studio notebook image, if in Studio env. Else, base python"""
200
+
201
+ if (
202
+ "SAGEMAKER_INTERNAL_IMAGE_URI" in os .environ
203
+ and os .environ ["SAGEMAKER_INTERNAL_IMAGE_URI" ]
204
+ ):
205
+ return os .environ ["SAGEMAKER_INTERNAL_IMAGE_URI" ]
206
+
207
+ py_major_version = sys .version_info [0 ]
208
+ py_minor_version = sys .version_info [1 ]
209
+
210
+ # TODO:Add Support for 3.8
211
+ if py_major_version != 3 or py_minor_version != 10 :
212
+ raise ValueError ("Use supported Python version or provide compatible ImageUri." )
213
+
214
+ # TODO: Support only supported by Studio
215
+ region = session .boto_region_name
216
+
217
+ # TODO: Remove beta image and use public base python
218
+ beta_image = (
219
+ f"581474259216.dkr.ecr.{ region } .amazonaws.com/"
220
+ f"sagemaker-pathways-beta:basepy_3_10_latest"
221
+ )
222
+ return beta_image
223
+
192
224
193
225
class _Job :
194
226
"""Helper class that interacts with the SageMaker training service."""
0 commit comments