|
5 | 5 |
|
6 | 6 | import xgboost as xgb
|
7 | 7 |
|
8 |
| -model_filename = 'xgboost-model' |
| 8 | +model_filename = "xgboost-model" |
9 | 9 |
|
10 |
| -if __name__ == '__main__': |
| 10 | +if __name__ == "__main__": |
11 | 11 | parser = argparse.ArgumentParser()
|
12 | 12 |
|
13 | 13 | # Sagemaker specific arguments. Defaults are set in the environment variables.
|
14 |
| - parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR','/opt/ml/model')) |
15 |
| - parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN','/opt/ml/input/data/abalone')) |
| 14 | + parser.add_argument( |
| 15 | + "--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model") |
| 16 | + ) |
| 17 | + parser.add_argument( |
| 18 | + "--train", |
| 19 | + type=str, |
| 20 | + default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/abalone"), |
| 21 | + ) |
16 | 22 |
|
17 | 23 | args, _ = parser.parse_known_args()
|
18 | 24 |
|
19 |
| - dtrain = get_dmatrix(args.train, 'libsvm') |
| 25 | + dtrain = get_dmatrix(args.train, "libsvm") |
20 | 26 |
|
21 | 27 | params = {
|
22 |
| - 'max_depth': 5, |
23 |
| - 'eta': 0.2, |
24 |
| - 'gamma': 4, |
25 |
| - 'min_child_weight': 6, |
26 |
| - 'subsample': 0.7, |
27 |
| - 'verbosity': 2, |
28 |
| - 'objective': 'reg:squarederror', |
29 |
| - 'tree_method': 'auto', |
30 |
| - 'predictor': 'auto', |
| 28 | + "max_depth": 5, |
| 29 | + "eta": 0.2, |
| 30 | + "gamma": 4, |
| 31 | + "min_child_weight": 6, |
| 32 | + "subsample": 0.7, |
| 33 | + "verbosity": 2, |
| 34 | + "objective": "reg:squarederror", |
| 35 | + "tree_method": "auto", |
| 36 | + "predictor": "auto", |
31 | 37 | }
|
32 | 38 |
|
33 |
| - booster = xgb.train(params=params, |
34 |
| - dtrain=dtrain, |
35 |
| - num_boost_round=50) |
36 |
| - booster.save_model(args.model_dir + '/' + model_filename) |
| 39 | + booster = xgb.train(params=params, dtrain=dtrain, num_boost_round=50) |
| 40 | + booster.save_model(args.model_dir + "/" + model_filename) |
37 | 41 |
|
38 | 42 |
|
39 | 43 | def model_fn(model_dir):
|
|
0 commit comments