Skip to content

Commit d533fa7

Browse files
author
Brent Millare
committed
fix black-check
1 parent 3c2fc2b commit d533fa7

File tree

2 files changed

+24
-23
lines changed

2 files changed

+24
-23
lines changed

tests/data/xgboost_abalone/abalone.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,35 +5,39 @@
55

66
import xgboost as xgb
77

8-
model_filename = 'xgboost-model'
8+
model_filename = "xgboost-model"
99

10-
if __name__ == '__main__':
10+
if __name__ == "__main__":
1111
parser = argparse.ArgumentParser()
1212

1313
# Sagemaker specific arguments. Defaults are set in the environment variables.
14-
parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR','/opt/ml/model'))
15-
parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN','/opt/ml/input/data/abalone'))
14+
parser.add_argument(
15+
"--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
16+
)
17+
parser.add_argument(
18+
"--train",
19+
type=str,
20+
default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/abalone"),
21+
)
1622

1723
args, _ = parser.parse_known_args()
1824

19-
dtrain = get_dmatrix(args.train, 'libsvm')
25+
dtrain = get_dmatrix(args.train, "libsvm")
2026

2127
params = {
22-
'max_depth': 5,
23-
'eta': 0.2,
24-
'gamma': 4,
25-
'min_child_weight': 6,
26-
'subsample': 0.7,
27-
'verbosity': 2,
28-
'objective': 'reg:squarederror',
29-
'tree_method': 'auto',
30-
'predictor': 'auto',
28+
"max_depth": 5,
29+
"eta": 0.2,
30+
"gamma": 4,
31+
"min_child_weight": 6,
32+
"subsample": 0.7,
33+
"verbosity": 2,
34+
"objective": "reg:squarederror",
35+
"tree_method": "auto",
36+
"predictor": "auto",
3137
}
3238

33-
booster = xgb.train(params=params,
34-
dtrain=dtrain,
35-
num_boost_round=50)
36-
booster.save_model(args.model_dir + '/' + model_filename)
39+
booster = xgb.train(params=params, dtrain=dtrain, num_boost_round=50)
40+
booster.save_model(args.model_dir + "/" + model_filename)
3741

3842

3943
def model_fn(model_dir):

tests/integ/test_xgboost.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,10 @@ def test_training_with_network_isolation(
7575

7676
train_input = xgboost.sagemaker_session.upload_data(
7777
path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"),
78-
key_prefix="integ-test-data/xgboost_abalone/abalone"
78+
key_prefix="integ-test-data/xgboost_abalone/abalone",
7979
)
8080
job_name = unique_name_from_base(base_job_name)
81-
xgboost.fit(
82-
inputs={"train": train_input},
83-
job_name=job_name
84-
)
81+
xgboost.fit(inputs={"train": train_input}, job_name=job_name)
8582
assert sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=job_name)[
8683
"EnableNetworkIsolation"
8784
]

0 commit comments

Comments
 (0)