Skip to content

Commit 0510d98

Browse files
committed
fix tests
1 parent 152a50c commit 0510d98

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525

2626
from sagemaker_core.main import resources
2727
from sagemaker_core.resources import TrainingJob
28-
from sagemaker_core.shapes import AlgorithmSpecification
28+
from sagemaker_core.shapes import (
29+
AlgorithmSpecification,
30+
OutputDataConfig,
31+
CheckpointConfig,
32+
TensorBoardOutputConfig
33+
)
2934

3035
from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call
3136

@@ -52,7 +57,6 @@
5257
Compute,
5358
StoppingCondition,
5459
RetryStrategy,
55-
OutputDataConfig,
5660
SourceCode,
5761
TrainingImageConfig,
5862
Channel,
@@ -64,8 +68,6 @@
6468
InfraCheckConfig,
6569
RemoteDebugConfig,
6670
SessionChainingConfig,
67-
TensorBoardOutputConfig,
68-
CheckpointConfig,
6971
InputData,
7072
)
7173

@@ -737,7 +739,7 @@ def train(
737739
sagemaker_session=self.sagemaker_session,
738740
container_entrypoint=algorithm_specification.container_entrypoint,
739741
container_arguments=algorithm_specification.container_arguments,
740-
input_data_config=input_data_config,
742+
input_data_config=self.input_data_config,
741743
hyper_parameters=string_hyper_parameters,
742744
environment=self.environment,
743745
)

0 commit comments

Comments
 (0)