@@ -256,6 +256,11 @@ def _build_for_tgi_jumpstart(self):
256
256
257
257
self .pysdk_model .env .update (env )
258
258
259
def _logging_debug(self, message):
    """Emit ``message`` at DEBUG level, framed by two banner lines.

    Three separate DEBUG records are produced, in order: a banner of
    asterisks, ``message`` itself, and the same banner again.

    Args:
        message: The object to log; it is handed to the logger unchanged.
    """
    # Use a named logger instead of the module-level ``logging.debug``
    # helper: that helper logs on the root logger and implicitly calls
    # ``logging.basicConfig()`` when the root logger has no handlers,
    # which would alter the host application's logging configuration.
    debug_logger = logging.getLogger(__name__)
    banner = "**************************************"
    for line in (banner, message, banner):
        debug_logger.debug(line)
263
+
259
264
def _tune_for_js (
260
265
self ,
261
266
num_shard_env_var : str = "SM_NUM_GPUS" ,
@@ -286,9 +291,15 @@ def _tune_for_js(
286
291
initial_env_vars = copy .deepcopy (self .pysdk_model .env )
287
292
admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees (self .js_model_config )
288
293
294
+ self ._logging_debug (
295
+ f"initial_env_vars: { initial_env_vars } ,"
296
+ f" admissible_tensor_parallel_degrees: { admissible_tensor_parallel_degrees } " )
297
+
289
298
available_gpus = None
290
299
if multiple_model_copies_enabled :
291
300
available_gpus = _get_available_gpus ()
301
+ self ._logging_debug (
302
+ f"multiple_model_copies_enabled: { multiple_model_copies_enabled } , available_gpus: { available_gpus } " )
292
303
293
304
benchmark_results = {}
294
305
best_tuned_combination = None
@@ -297,18 +308,30 @@ def _tune_for_js(
297
308
if datetime .now () > timeout :
298
309
logger .info ("Max tuning duration reached. Tuning stopped." )
299
310
break
311
+ try :
312
+ self .pysdk_model .env .update ({
313
+ num_shard_env_var : str (tensor_parallel_degree )
314
+ })
315
+ self ._logging_debug (
316
+ f"num_shard_env_var: { num_shard_env_var } , tensor_parallel_degree: { tensor_parallel_degree } " )
317
+ except Exception as e :
318
+ self ._logging_debug (str (e ))
300
319
301
- self .pysdk_model .env .update ({
302
- num_shard_env_var : str (tensor_parallel_degree )
303
- })
304
320
logging_msg = f"{ num_shard_env_var } : { tensor_parallel_degree } ."
305
321
306
322
sagemaker_model_server_workers = None
307
323
if multiple_model_copies_enabled :
308
324
sagemaker_model_server_workers = int (available_gpus / tensor_parallel_degree )
309
- self .pysdk_model .env .update ({
310
- num_model_copies_env_var : str (sagemaker_model_server_workers )
311
- })
325
+ self ._logging_debug (f"sagemaker_model_server_workers: { sagemaker_model_server_workers } " )
326
+ try :
327
+ self .pysdk_model .env .update ({
328
+ num_model_copies_env_var : str (sagemaker_model_server_workers )
329
+ })
330
+ self ._logging_debug (
331
+ f"num_model_copies_env_var: { num_model_copies_env_var } , "
332
+ f"sagemaker_model_server_workers: { sagemaker_model_server_workers } " )
333
+ except Exception as e :
334
+ self ._logging_debug (str (e ))
312
335
logging_msg = f"{ logging_msg } { num_model_copies_env_var } : { sagemaker_model_server_workers } ."
313
336
314
337
try :
0 commit comments