Skip to content

Add TEFN model #406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions config_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import argparse
import json
from pathlib import Path


def get_args():
parser = argparse.ArgumentParser(description='TEFN')

# model define
parser.add_argument('--expand', type=int, default=2, help='expansion factor for Mamba')
parser.add_argument('--d_conv', type=int, default=4, help='conv kernel size for Mamba')
parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock')
parser.add_argument('--num_kernels', type=int, default=6, help='for Inception')
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=7, help='output size')
parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
parser.add_argument('--factor', type=int, default=1, help='attn factor')
parser.add_argument('--distil', action='store_false',
help='whether to use distilling in encoder, using this argument means not using distilling',
default=True)
parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
parser.add_argument('--embed', type=str, default='timeF',
help='time features encoding, options:[timeF, fixed, learned]')
parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
parser.add_argument('--channel_independence', type=int, default=1,
help='0: channel dependence 1: channel independence for FreTS model')
parser.add_argument('--decomp_method', type=str, default='moving_avg',
help='method of series decompsition, only support moving_avg or dft_decomp')
parser.add_argument('--use_norm', type=int, default=1, help='whether to use normalize; True 1 False 0')
parser.add_argument('--down_sampling_layers', type=int, default=0, help='num of down sampling layers')
parser.add_argument('--down_sampling_window', type=int, default=1, help='down sampling window size')
parser.add_argument('--down_sampling_method', type=str, default=None,
help='down sampling method, only support avg, max, conv')
parser.add_argument('--seg_len', type=int, default=48,
help='the length of segmen-wise iteration of SegRNN')

# optimization
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
parser.add_argument('--itr', type=int, default=1, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
parser.add_argument('--loss', type=str, default='MSE', help='loss function')
parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)

# GPU
parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
parser.add_argument('--gpu', type=int, default=0, help='gpu')
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')

# de-stationary projector params
parser.add_argument('--p_hidden_dims', type=int, nargs='+', default=[128, 128],
help='hidden layer dimensions of projector (List)')
parser.add_argument('--p_hidden_layers', type=int, default=2, help='number of hidden layers in projector')

# metrics (dtw)
parser.add_argument('--use_dtw', type=bool, default=False,
help='the controller of using dtw metric (dtw is time consuming, not suggested unless necessary)')

# Augmentation
parser.add_argument('--augmentation_ratio', type=int, default=0, help="How many times to augment")
parser.add_argument('--seed', type=int, default=2, help="Randomization seed")
parser.add_argument('--jitter', default=False, action="store_true", help="Jitter preset augmentation")
parser.add_argument('--scaling', default=False, action="store_true", help="Scaling preset augmentation")
parser.add_argument('--permutation', default=False, action="store_true",
help="Equal Length Permutation preset augmentation")
parser.add_argument('--randompermutation', default=False, action="store_true",
help="Random Length Permutation preset augmentation")
parser.add_argument('--magwarp', default=False, action="store_true", help="Magnitude warp preset augmentation")
parser.add_argument('--timewarp', default=False, action="store_true", help="Time warp preset augmentation")
parser.add_argument('--windowslice', default=False, action="store_true", help="Window slice preset augmentation")
parser.add_argument('--windowwarp', default=False, action="store_true", help="Window warp preset augmentation")
parser.add_argument('--rotation', default=False, action="store_true", help="Rotation preset augmentation")
parser.add_argument('--spawner', default=False, action="store_true", help="SPAWNER preset augmentation")
parser.add_argument('--dtwwarp', default=False, action="store_true", help="DTW warp preset augmentation")
parser.add_argument('--shapedtwwarp', default=False, action="store_true", help="Shape DTW warp preset augmentation")
parser.add_argument('--wdba', default=False, action="store_true", help="Weighted DBA preset augmentation")
parser.add_argument('--discdtw', default=False, action="store_true",
help="Discrimitive DTW warp preset augmentation")
parser.add_argument('--discsdtw', default=False, action="store_true",
help="Discrimitive shapeDTW warp preset augmentation")
parser.add_argument('--extra_tag', type=str, default="", help="Anything extra")

args = parser.parse_args()
return args


def load_config(config_path):
args = get_args()
with open(config_path, 'r') as f:
args_tmp = json.loads(f.read())
for k, v in args_tmp.items():
setattr(args, k, v)
return args


def get_all_configs(dir):
dir = Path(dir)
configs = [file for file in dir.rglob("*.json")]
return configs


def update_config(config, args):
with config.open('w') as f:
json.dump(args.__dict__, f)
f.close()


if __name__ == '__main__':
configs = get_all_configs("./configs")
for config in configs:
args = load_config(config)
update_config(config, args)
1 change: 1 addition & 0 deletions configs/long_term_forecast/ECL_script/TEFN_p192.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"expand": 2, "d_conv": 4, "top_k": 5, "num_kernels": 6, "enc_in": 321, "dec_in": 321, "c_out": 321, "d_model": 256, "n_heads": 8, "e_layers": 2, "d_layers": 1, "d_ff": 512, "moving_avg": 25, "factor": 3, "distil": true, "dropout": 0.1, "embed": "timeF", "activation": "gelu", "output_attention": false, "channel_independence": 0, "decomp_method": "moving_avg", "use_norm": 1, "down_sampling_layers": 0, "down_sampling_window": 1, "down_sampling_method": null, "seg_len": 48, "num_workers": 10, "itr": 1, "train_epochs": 10, "batch_size": 32, "patience": 3, "learning_rate": 0.01, "des": "Exp", "loss": "MSE", "lradj": "type1", "use_amp": false, "use_gpu": true, "gpu": 0, "use_multi_gpu": false, "devices": "0,1,2,3", "p_hidden_dims": [128, 128], "p_hidden_layers": 2, "use_dtw": false, "augmentation_ratio": 0, "seed": 2, "jitter": false, "scaling": false, "permutation": false, "randompermutation": false, "magwarp": false, "timewarp": false, "windowslice": false, "windowwarp": false, "rotation": false, "spawner": false, "dtwwarp": false, "shapedtwwarp": false, "wdba": false, "discdtw": false, "discsdtw": false, "extra_tag": "", "task_name": "long_term_forecast", "is_training": 1, "model_id": "ECL_96_192", "model": "TEFN", "data": "custom", "root_path": "./dataset/electricity/", "data_path": "electricity.csv", "features": "M", "target": "OT", "freq": "h", "checkpoints": "./out/checkpoints/", "seq_len": 96, "label_len": 48, "pred_len": 192, "seasonal_patterns": "Monthly", "inverse": false, "mask_rate": 0.25, "anomaly_ratio": 0.25, "conv_kernel": null}
1 change: 1 addition & 0 deletions configs/long_term_forecast/ECL_script/TEFN_p336.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"expand": 2, "d_conv": 4, "top_k": 5, "num_kernels": 6, "enc_in": 321, "dec_in": 321, "c_out": 321, "d_model": 256, "n_heads": 8, "e_layers": 1, "d_layers": 1, "d_ff": 512, "moving_avg": 25, "factor": 3, "distil": true, "dropout": 0.1, "embed": "timeF", "activation": "gelu", "output_attention": false, "channel_independence": 0, "decomp_method": "moving_avg", "use_norm": 1, "down_sampling_layers": 0, "down_sampling_window": 1, "down_sampling_method": null, "seg_len": 48, "num_workers": 10, "itr": 1, "train_epochs": 10, "batch_size": 32, "patience": 3, "learning_rate": 0.05, "des": "Exp", "loss": "MSE", "lradj": "type1", "use_amp": false, "use_gpu": true, "gpu": 0, "use_multi_gpu": false, "devices": "0,1,2,3", "p_hidden_dims": [128, 128], "p_hidden_layers": 2, "use_dtw": false, "augmentation_ratio": 0, "seed": 2, "jitter": false, "scaling": false, "permutation": false, "randompermutation": false, "magwarp": false, "timewarp": false, "windowslice": false, "windowwarp": false, "rotation": false, "spawner": false, "dtwwarp": false, "shapedtwwarp": false, "wdba": false, "discdtw": false, "discsdtw": false, "extra_tag": "", "task_name": "long_term_forecast", "is_training": 1, "model_id": "ECL_96_336", "model": "TEFN", "data": "custom", "root_path": "./dataset/electricity/", "data_path": "electricity.csv", "features": "M", "target": "OT", "freq": "h", "checkpoints": "./out/checkpoints/", "seq_len": 96, "label_len": 48, "pred_len": 336, "seasonal_patterns": "Monthly", "inverse": false, "mask_rate": 0.25, "anomaly_ratio": 0.25, "conv_kernel": null}
1 change: 1 addition & 0 deletions configs/long_term_forecast/ECL_script/TEFN_p720.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"expand": 2, "d_conv": 4, "top_k": 5, "num_kernels": 6, "enc_in": 321, "dec_in": 321, "c_out": 321, "d_model": 256, "n_heads": 8, "e_layers": 3, "d_layers": 1, "d_ff": 512, "moving_avg": 25, "factor": 3, "distil": true, "dropout": 0.1, "embed": "timeF", "activation": "gelu", "output_attention": false, "channel_independence": 0, "decomp_method": "moving_avg", "use_norm": 1, "down_sampling_layers": 0, "down_sampling_window": 1, "down_sampling_method": null, "seg_len": 48, "num_workers": 10, "itr": 1, "train_epochs": 10, "batch_size": 32, "patience": 3, "learning_rate": 0.01, "des": "Exp", "loss": "MSE", "lradj": "type1", "use_amp": false, "use_gpu": true, "gpu": 0, "use_multi_gpu": false, "devices": "0,1,2,3", "p_hidden_dims": [128, 128], "p_hidden_layers": 2, "use_dtw": false, "augmentation_ratio": 0, "seed": 2, "jitter": false, "scaling": false, "permutation": false, "randompermutation": false, "magwarp": false, "timewarp": false, "windowslice": false, "windowwarp": false, "rotation": false, "spawner": false, "dtwwarp": false, "shapedtwwarp": false, "wdba": false, "discdtw": false, "discsdtw": false, "extra_tag": "", "task_name": "long_term_forecast", "is_training": 1, "model_id": "ECL_96_720", "model": "TEFN", "data": "custom", "root_path": "./dataset/electricity/", "data_path": "electricity.csv", "features": "M", "target": "OT", "freq": "h", "checkpoints": "./out/checkpoints/", "seq_len": 96, "label_len": 48, "pred_len": 720, "seasonal_patterns": "Monthly", "inverse": false, "mask_rate": 0.25, "anomaly_ratio": 0.25, "conv_kernel": null}
1 change: 1 addition & 0 deletions configs/long_term_forecast/ECL_script/TEFN_p96.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"expand": 2, "d_conv": 4, "top_k": 5, "num_kernels": 6, "enc_in": 321, "dec_in": 321, "c_out": 321, "d_model": 256, "n_heads": 8, "e_layers": 0, "d_layers": 1, "d_ff": 512, "moving_avg": 25, "factor": 3, "distil": true, "dropout": 0.1, "embed": "timeF", "activation": "gelu", "output_attention": false, "channel_independence": 0, "decomp_method": "moving_avg", "use_norm": 1, "down_sampling_layers": 0, "down_sampling_window": 1, "down_sampling_method": null, "seg_len": 48, "num_workers": 10, "itr": 1, "train_epochs": 10, "batch_size": 32, "patience": 3, "learning_rate": 0.05, "des": "Exp", "loss": "MSE", "lradj": "type1", "use_amp": false, "use_gpu": true, "gpu": 0, "use_multi_gpu": false, "devices": "0,1,2,3", "p_hidden_dims": [128, 128], "p_hidden_layers": 2, "use_dtw": false, "augmentation_ratio": 0, "seed": 2, "jitter": false, "scaling": false, "permutation": false, "randompermutation": false, "magwarp": false, "timewarp": false, "windowslice": false, "windowwarp": false, "rotation": false, "spawner": false, "dtwwarp": false, "shapedtwwarp": false, "wdba": false, "discdtw": false, "discsdtw": false, "extra_tag": "", "task_name": "long_term_forecast", "is_training": 1, "model_id": "ECL_96_96", "model": "TEFN", "data": "custom", "root_path": "./dataset/electricity/", "data_path": "electricity.csv", "features": "M", "target": "OT", "freq": "h", "checkpoints": "./out/checkpoints/", "seq_len": 96, "label_len": 48, "pred_len": 96, "seasonal_patterns": "Monthly", "inverse": false, "mask_rate": 0.25, "anomaly_ratio": 0.25, "conv_kernel": null}
1 change: 1 addition & 0 deletions configs/long_term_forecast/ETT_script/TEFN_ETTh1_p192.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"expand": 2, "d_conv": 4, "top_k": 5, "num_kernels": 6, "enc_in": 7, "dec_in": 7, "c_out": 7, "d_model": 16, "n_heads": 8, "e_layers": 4, "d_layers": 1, "d_ff": 32, "moving_avg": 25, "factor": 3, "distil": true, "dropout": 0.1, "embed": "timeF", "activation": "gelu", "output_attention": false, "channel_independence": 0, "decomp_method": "moving_avg", "use_norm": 1, "down_sampling_layers": 0, "down_sampling_window": 1, "down_sampling_method": null, "seg_len": 48, "num_workers": 10, "itr": 1, "train_epochs": 10, "batch_size": 32, "patience": 3, "learning_rate": 0.05, "des": "Exp", "loss": "MSE", "lradj": "type1", "use_amp": false, "use_gpu": true, "gpu": 0, "use_multi_gpu": false, "devices": "0,1,2,3", "p_hidden_dims": [128, 128], "p_hidden_layers": 2, "use_dtw": false, "augmentation_ratio": 0, "seed": 2, "jitter": false, "scaling": false, "permutation": false, "randompermutation": false, "magwarp": false, "timewarp": false, "windowslice": false, "windowwarp": false, "rotation": false, "spawner": false, "dtwwarp": false, "shapedtwwarp": false, "wdba": false, "discdtw": false, "discsdtw": false, "extra_tag": "", "task_name": "long_term_forecast", "is_training": 1, "model_id": "ETTh1_96_192", "model": "TEFN", "data": "ETTh1", "root_path": "./dataset/ETT-small/", "data_path": "ETTh1.csv", "features": "M", "target": "OT", "freq": "h", "checkpoints": "./out/checkpoints/", "seq_len": 96, "label_len": 48, "pred_len": 192, "seasonal_patterns": "Monthly", "inverse": false, "mask_rate": 0.25, "anomaly_ratio": 0.25, "conv_kernel": null}
1 change: 1 addition & 0 deletions configs/long_term_forecast/ETT_script/TEFN_ETTh1_p336.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"expand": 2, "d_conv": 4, "top_k": 5, "num_kernels": 6, "enc_in": 7, "dec_in": 7, "c_out": 7, "d_model": 16, "n_heads": 8, "e_layers": 1, "d_layers": 1, "d_ff": 32, "moving_avg": 25, "factor": 3, "distil": true, "dropout": 0.1, "embed": "timeF", "activation": "gelu", "output_attention": false, "channel_independence": 0, "decomp_method": "moving_avg", "use_norm": 1, "down_sampling_layers": 0, "down_sampling_window": 1, "down_sampling_method": null, "seg_len": 48, "num_workers": 10, "itr": 1, "train_epochs": 10, "batch_size": 32, "patience": 3, "learning_rate": 0.01, "des": "Exp", "loss": "MSE", "lradj": "type1", "use_amp": false, "use_gpu": true, "gpu": 0, "use_multi_gpu": false, "devices": "0,1,2,3", "p_hidden_dims": [128, 128], "p_hidden_layers": 2, "use_dtw": false, "augmentation_ratio": 0, "seed": 2, "jitter": false, "scaling": false, "permutation": false, "randompermutation": false, "magwarp": false, "timewarp": false, "windowslice": false, "windowwarp": false, "rotation": false, "spawner": false, "dtwwarp": false, "shapedtwwarp": false, "wdba": false, "discdtw": false, "discsdtw": false, "extra_tag": "", "task_name": "long_term_forecast", "is_training": 1, "model_id": "ETTh1_96_336", "model": "TEFN", "data": "ETTh1", "root_path": "./dataset/ETT-small/", "data_path": "ETTh1.csv", "features": "M", "target": "OT", "freq": "h", "checkpoints": "./out/checkpoints/", "seq_len": 96, "label_len": 48, "pred_len": 336, "seasonal_patterns": "Monthly", "inverse": false, "mask_rate": 0.25, "anomaly_ratio": 0.25, "conv_kernel": null}
1 change: 1 addition & 0 deletions configs/long_term_forecast/ETT_script/TEFN_ETTh1_p720.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"expand": 2, "d_conv": 4, "top_k": 5, "num_kernels": 6, "enc_in": 7, "dec_in": 7, "c_out": 7, "d_model": 16, "n_heads": 8, "e_layers": 1, "d_layers": 1, "d_ff": 32, "moving_avg": 25, "factor": 3, "distil": true, "dropout": 0.1, "embed": "timeF", "activation": "gelu", "output_attention": false, "channel_independence": 0, "decomp_method": "moving_avg", "use_norm": 1, "down_sampling_layers": 0, "down_sampling_window": 1, "down_sampling_method": null, "seg_len": 48, "num_workers": 10, "itr": 1, "train_epochs": 10, "batch_size": 32, "patience": 3, "learning_rate": 0.05, "des": "Exp", "loss": "MSE", "lradj": "type1", "use_amp": false, "use_gpu": true, "gpu": 0, "use_multi_gpu": false, "devices": "0,1,2,3", "p_hidden_dims": [128, 128], "p_hidden_layers": 2, "use_dtw": false, "augmentation_ratio": 0, "seed": 2, "jitter": false, "scaling": false, "permutation": false, "randompermutation": false, "magwarp": false, "timewarp": false, "windowslice": false, "windowwarp": false, "rotation": false, "spawner": false, "dtwwarp": false, "shapedtwwarp": false, "wdba": false, "discdtw": false, "discsdtw": false, "extra_tag": "", "task_name": "long_term_forecast", "is_training": 1, "model_id": "ETTh1_96_720", "model": "TEFN", "data": "ETTh1", "root_path": "./dataset/ETT-small/", "data_path": "ETTh1.csv", "features": "M", "target": "OT", "freq": "h", "checkpoints": "./out/checkpoints/", "seq_len": 96, "label_len": 48, "pred_len": 720, "seasonal_patterns": "Monthly", "inverse": false, "mask_rate": 0.25, "anomaly_ratio": 0.25, "conv_kernel": null}
Loading