FastNLP Tutorial
(GitHub wiki page; last edited by Coet on Sep 16, 2018 · 8 revisions)
Data pipeline:
raw dataset --(loader)--> 2-D list of strings --(preprocessor)--> DataSet --(Batch)--> data_iterator --> (batch_x, batch_y)
# Step 1: load the raw dataset into a 2-D list of strings.
# The loader must be bound to the same name that is used on the next line
# (previously `data_loader` was assigned but `pos_loader` was used -> NameError).
pos_loader = POSDatasetLoader("./data/pos_tag_data.txt")
train_data = pos_loader.load_lines()
# train_data looks like:
# [
#     [["This", "is", "fast", "NLP"], ["label_1", "label_3", "label_2", "label_1"]],
#     ...
# ]

# Step 2: turn the 2-D list of strings into DataSet objects (train and dev).
# NOTE(review): `pickle_path` is not defined in this snippet; it must be set
# beforehand (presumably a directory for the preprocessor's pickled files).
p = SeqLabelPreprocess()
data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
# type(data_train) == DataSet
# type(data_dev) == DataSet
A DataSet is a list of Instances, and each Instance holds several Fields:

DataSet
[
    Instance(Field_1, Field_2, Field_3, ...),
    Instance(Field_1, Field_2, Field_3, ...),
    ...
]
# Step 3: wrap the training DataSet in a Batch iterator (mini-batches of 16,
# drawn with a RandomSampler, kept on the CPU). Each step yields a pair of
# field-name-indexed containers: inputs (batch_x) and targets (batch_y).
data_iterator = Batch(
    data_train,
    batch_size=16,
    sampler=RandomSampler(),
    use_cuda=False,
)
for batch_x, batch_y in data_iterator:
    # `network` and `get_loss` are placeholders for your own model and loss
    # function; they are not defined in this snippet.
    prediction = network(batch_x["word_seq"])
    get_loss(prediction, batch_y["label_seq"])
from fastNLP.fastnlp import FastNLP

# Directory passed to FastNLP as `model_dir` for the Chinese word segmentation
# (CWS) model.
# NOTE(review): this is an absolute path on the original author's machine;
# point it at your own copy of reproduction/chinese_word_segment/save/.
PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/"

# Create the high-level interface, then load the named model using the given
# config file and section.
nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES)
nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test")

# Raw, unsegmented Chinese sentences.
sentences = [
    "这是最好的基于深度学习的中文分词系统。",
    "大王叫我来巡山。",
    "我党多年来致力于改善人民生活水平。",
]
results = nlp.run(sentences)
# `results` holds one list per input sentence; each is a list of
# (character, tag) pairs. The tags are presumably B = word begin,
# E = word end, S = single-character word.
# [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'), ('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'), ('。', 'S')], [('大', 'B'), ('王', 'E'), ('叫', 'S'), ('我', 'S'), ('来', 'S'), ('巡', 'B'), ('山', 'E'), ('。', 'S')], [('我', 'B'), ('党', 'E'), ('多', 'S'), ('年', 'S'), ('来', 'S'), ('致', 'B'), ('力', 'E'), ('于', 'S'), ('改', 'B'), ('善', 'E'), ('人', 'B'), ('民', 'E'), ('生', 'B'), ('活', 'E'), ('水', 'B'), ('平', 'E'), ('。', 'S')]]
def train_and_test():
    """Train a SeqLabeling model, save it, reload it and evaluate it.

    Steps:
      1. Load the trainer and model config sections from "./data/config".
      2. Load the POS-tagging data and preprocess it into train/dev DataSets.
      3. Train a SeqLabeling model and save its weights.
      4. Rebuild the model, load the saved weights and evaluate on the dev set.

    NOTE(review): `pickle_path`, `model_name` and `config_dir` are not defined
    in this function. They are assumed to be module-level settings defined
    elsewhere in the tutorial; confirm before running.
    """
    # Load config sections from the config file. The empty ConfigSection
    # objects are populated by load_config and read like dicts below.
    trainer_args = ConfigSection()
    model_args = ConfigSection()
    ConfigLoader().load_config("./data/config", {
        "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

    # Load data with the data loader: raw file -> 2-D list of strings.
    # Bound to `pos_loader` because that name is used on the next line and
    # deleted below (binding it to `data_loader` raised NameError).
    pos_loader = POSDatasetLoader("./data/pos_tag_data.txt")
    train_data = pos_loader.load_lines()

    # Preprocessor: 2-D list of strings ----> DataSet (train/dev split).
    preprocess = SeqLabelPreprocess()
    data_train, data_dev = preprocess.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
    # The model's sizes are taken from what the preprocessor reports.
    model_args["vocab_size"] = preprocess.vocab_size
    model_args["num_classes"] = preprocess.num_classes

    # Define the trainer from the loaded trainer config.
    trainer = Trainer(
        epochs=trainer_args["epochs"],
        batch_size=trainer_args["batch_size"],
        validate=trainer_args["validate"],
        use_cuda=trainer_args["use_cuda"],
        pickle_path=pickle_path,
        save_best_dev=trainer_args["save_best_dev"],
        model_name=model_name,
        optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
    )

    # Define a model and train it, passing data_dev as the dev set.
    model = SeqLabeling(model_args)
    trainer.train(model, data_train, data_dev)
    print("Training finished!")

    # Save the trained weights to <pickle_path>/<model_name>.
    saver = ModelSaver(os.path.join(pickle_path, model_name))
    saver.save_pytorch(model)
    print("Model saved!")

    # Drop the in-memory objects so that the test below uses weights
    # reloaded from disk.
    del model, trainer, pos_loader

    # Rebuild the same architecture and load the saved weights into it.
    model = SeqLabeling(model_args)
    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
    print("model loaded!")

    # Load the test configuration.
    # NOTE(review): tester_args is loaded but never used; the Tester below is
    # configured with literal values. Its model_name ("seq_label_in_test.pkl")
    # also differs from the `model_name` used to save the model. Confirm intent.
    tester_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})

    # Define a tester.
    tester = Tester(save_output=False,
                    save_loss=False,
                    save_best_dev=False,
                    batch_size=4,
                    use_cuda=False,
                    pickle_path=pickle_path,
                    model_name="seq_label_in_test.pkl",
                    print_every_step=1
                    )

    # Evaluate on the dev split. This is the same data_dev that was passed to
    # trainer.train above, not a separate held-out test set.
    tester.test(model, data_dev)
    print(tester.show_metrics())