Skip to content

Commit 77a5a5e

Browse files
author
Brent Millare
committed
Add integ tests for xgboost net iso
1 parent b038943 commit 77a5a5e

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

tests/data/xgboost_abalone/abalone.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import argparse
2+
import os
3+
4+
from sagemaker_xgboost_container.data_utils import get_dmatrix
5+
6+
import xgboost as xgb
7+
8+
model_filename = 'xgboost-model'
9+
10+
if __name__ == '__main__':
11+
parser = argparse.ArgumentParser()
12+
13+
# Sagemaker specific arguments. Defaults are set in the environment variables.
14+
parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR','/opt/ml/model'))
15+
parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN','/opt/ml/input/data/abalone'))
16+
17+
args, _ = parser.parse_known_args()
18+
19+
dtrain = get_dmatrix(args.train, 'libsvm')
20+
21+
params = {
22+
'max_depth': 5,
23+
'eta': 0.2,
24+
'gamma': 4,
25+
'min_child_weight': 6,
26+
'subsample': 0.7,
27+
'verbosity': 2,
28+
'objective': 'reg:squarederror',
29+
'tree_method': 'auto',
30+
'predictor': 'auto',
31+
}
32+
33+
booster = xgb.train(params=params,
34+
dtrain=dtrain,
35+
num_boost_round=50)
36+
booster.save_model(args.model_dir + '/' + model_filename)
37+
38+
39+
def model_fn(model_dir):
40+
"""Deserialize and return fitted model.
41+
42+
Note that this should have the same name as the serialized model in the _xgb_train method
43+
"""
44+
booster = xgb.Booster()
45+
booster.load_model(os.path.join(model_dir, model_filename))
46+
return booster

tests/integ/test_xgboost.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import os
1616
import pytest
17+
from sagemaker.xgboost import XGBoost
1718
from sagemaker.xgboost.processing import XGBoostProcessor
1819
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
1920
from tests.integ.timeout import timeout
@@ -48,3 +49,37 @@ def test_framework_processing_job_with_deps(
4849
inputs=[],
4950
wait=True,
5051
)
52+
53+
54+
def test_training_with_network_isolation(
55+
sagemaker_session,
56+
xgboost_latest_version,
57+
xgboost_latest_py_version,
58+
cpu_instance_type,
59+
):
60+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
61+
base_job_name = "test-network-isolation-xgboost"
62+
63+
xgboost = XGBoost(
64+
entry_point=os.path.join(DATA_DIR, "xgboost_abalone", "abalone.py"),
65+
role=ROLE,
66+
instance_type=cpu_instance_type,
67+
instance_count=1,
68+
framework_version=xgboost_latest_version,
69+
py_version=xgboost_latest_py_version,
70+
base_job_name=base_job_name,
71+
sagemaker_session=sagemaker_session,
72+
enable_network_isolation=True,
73+
)
74+
75+
train_input = xgboost.sagemaker_session.upload_data(
76+
path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"),
77+
key_prefix="integ-test-data/xgboost_abalone/abalone"
78+
)
79+
xgboost.fit(
80+
inputs={"train": train_input},
81+
job_name=unique_name_from_base(base_job_name)
82+
)
83+
assert sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=job_name)[
84+
"EnableNetworkIsolation"
85+
]

0 commit comments

Comments
 (0)