Skip to content

Commit da64c38

Browse files
committed
chore: ensure parity between estimator and framework hyperparams in unit test
1 parent 6c9b085 commit da64c38

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

tests/unit/test_estimator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
from copy import deepcopy
1415

1516
import logging
1617
import json
@@ -3825,6 +3826,12 @@ def test_script_mode_estimator_same_calls_as_framework(
38253826

38263827
model_uri = "s3://someprefix2/models/model.tar.gz"
38273828
training_data_uri = "s3://bucket/mydata"
3829+
hyperparameters = {
3830+
"int_hyperparam": 1,
3831+
"string_hyperparam": "hello",
3832+
"stringified_numeric_hyperparam": "44",
3833+
"float_hyperparam": 1.234,
3834+
}
38283835

38293836
generic_estimator = Estimator(
38303837
entry_point=SCRIPT_PATH,
@@ -3838,6 +3845,7 @@ def test_script_mode_estimator_same_calls_as_framework(
38383845
model_uri=model_uri,
38393846
dependencies=[],
38403847
debugger_hook_config={},
3848+
hyperparameters=deepcopy(hyperparameters),
38413849
)
38423850
generic_estimator.fit(training_data_uri)
38433851

@@ -3858,6 +3866,7 @@ def test_script_mode_estimator_same_calls_as_framework(
38583866
model_uri=model_uri,
38593867
dependencies=[],
38603868
debugger_hook_config={},
3869+
hyperparameters=deepcopy(hyperparameters),
38613870
)
38623871
framework_estimator.fit(training_data_uri)
38633872

0 commit comments

Comments
 (0)