Seq2Seq Specifics
Sequence-to-Sequence models (`Seq2SeqModel`) are models where both the input and targets are text sequences. For example, translation and summarization are sequence-to-sequence tasks.
Currently, five main types of Sequence-to-Sequence models are available.
- BART (Summarization)
- MBART (Translation)
- MarianMT (Translation)
- Encoder-Decoder (Generic)
- RAG (Retrieval-Augmented Generation - e.g. Question Answering)
Note that these models are not restricted to the specified task. The task is merely given as an example.
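For instance, loading a MarianMT model for translation looks much the same as loading any other model type. A minimal sketch, assuming the `Helsinki-NLP/opus-mt-en-de` checkpoint from the Hugging Face Hub:

```python
from simpletransformers.seq2seq import Seq2SeqModel

# Minimal sketch: load a pretrained MarianMT English-to-German model.
# The "marian" type string and the Helsinki-NLP/opus-mt-en-de checkpoint
# are assumptions based on Hugging Face Hub naming; any Marian checkpoint
# should work the same way.
model = Seq2SeqModel(
    encoder_decoder_type="marian",
    encoder_decoder_name="Helsinki-NLP/opus-mt-en-de",
    use_cuda=False,  # set True if a GPU is available
)

print(model.predict(["This is a sentence to translate."]))
```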
Usage Steps
Using a sequence-to-sequence model in Simple Transformers follows the standard pattern.
- Initialize a `Seq2SeqModel`
- Train the model with `train_model()`
- Evaluate the model with `eval_model()`
- Make predictions on (unlabelled) data with `predict()`
Note: You must have Faiss (GPU or CPU) installed to use RAG Models. Faiss installation instructions can be found here.
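Once Faiss is available, a RAG model is initialized like the other model types. A rough sketch; the `"rag-token"` type string and the `facebook/rag-token-nq` checkpoint are assumptions based on the Hugging Face naming, so verify them against the Seq2SeqModel documentation:

```python
from simpletransformers.seq2seq import Seq2SeqModel

# Sketch only: "rag-token" and facebook/rag-token-nq are assumed names.
# RAG models require Faiss (see the note above).
rag_model = Seq2SeqModel(
    encoder_decoder_type="rag-token",
    encoder_decoder_name="facebook/rag-token-nq",
    use_cuda=False,
)
```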
Initializing a Seq2SeqModel
The `__init__` arguments for a `Seq2SeqModel` are a little different from the common format found in the other models. Please refer here for more information.
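As a rough illustration of the difference: pretrained sequence-to-sequence models take a single `encoder_decoder_type`/`encoder_decoder_name` pair, while a generic Encoder-Decoder model is assembled from separate encoder and decoder checkpoints. The `encoder_type`/`encoder_name`/`decoder_name` argument names below should be verified against the linked documentation:

```python
from simpletransformers.seq2seq import Seq2SeqModel

# Pretrained seq2seq model (e.g. BART): a single type/name pair.
bart_model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large",
    use_cuda=False,
)

# Generic Encoder-Decoder: separate encoder and decoder checkpoints.
# These argument names are assumptions; check them against the docs.
generic_model = Seq2SeqModel(
    encoder_type="roberta",
    encoder_name="roberta-base",
    decoder_name="bert-base-cased",
    use_cuda=False,
)
```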
Evaluating Generated Sequences
You can evaluate the models’ generated sequences using custom metric functions (including evaluation during training). However, due to the way Seq2Seq outputs are generated, this may be significantly slower than evaluation with other models.
Note: You must set `evaluate_generated_text` to `True` to evaluate generated sequences.
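A metric function receives the list of reference strings (`labels`) and the list of generated strings (`preds`) and returns a single value. As an illustrative sketch, a corpus-level BLEU metric could be passed to `eval_model()`; note that `sacrebleu` is a separate dependency, not part of Simple Transformers:

```python
from sacrebleu import corpus_bleu  # assumption: sacrebleu is installed separately


def bleu(labels, preds):
    # corpus_bleu expects hypotheses and a list of reference streams
    return corpus_bleu(preds, [labels]).score


# results = model.eval_model(eval_df, bleu=bleu)
```

The complete example below instead uses a simple exact-match count (`count_matches`) and enables evaluation during training.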
```python
import logging

import pandas as pd

from simpletransformers.seq2seq import (
    Seq2SeqModel,
    Seq2SeqArgs,
)


logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Training data: pairs of [input_text, target_text]
train_data = [
    [
        "Perseus “Percy” Jackson is the main protagonist and the narrator of the Percy Jackson and the Olympians series.",
        "Percy is the protagonist of Percy Jackson and the Olympians",
    ],
    [
        "Annabeth Chase is one of the main protagonists in Percy Jackson and the Olympians.",
        "Annabeth is a protagonist in Percy Jackson and the Olympians.",
    ],
]

train_df = pd.DataFrame(
    train_data, columns=["input_text", "target_text"]
)

eval_data = [
    [
        "Grover Underwood is a satyr and the Lord of the Wild. He is the satyr who found the demigods Thalia Grace, Nico and Bianca di Angelo, Percy Jackson, Annabeth Chase, and Luke Castellan.",
        "Grover is a satyr who found many important demigods.",
    ],
    [
        "Thalia Grace is the daughter of Zeus, sister of Jason Grace. After several years as a pine tree on Half-Blood Hill, she got a new job leading the Hunters of Artemis.",
        "Thalia is the daughter of Zeus and leader of the Hunters of Artemis.",
    ],
]

eval_df = pd.DataFrame(
    eval_data, columns=["input_text", "target_text"]
)

model_args = Seq2SeqArgs()
model_args.num_train_epochs = 10
model_args.no_save = True
model_args.evaluate_generated_text = True
model_args.evaluate_during_training = True
model_args.evaluate_during_training_verbose = True

# Initialize model
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large",
    args=model_args,
    use_cuda=True,
)


# Custom metric: counts exact matches between references and predictions
def count_matches(labels, preds):
    print(labels)
    print(preds)
    return sum(
        [
            1 if label == pred else 0
            for label, pred in zip(labels, preds)
        ]
    )


# Train the model
model.train_model(
    train_df, eval_data=eval_df, matches=count_matches
)

# Evaluate the model
results = model.eval_model(eval_df)

# Use the model for prediction
print(
    model.predict(
        [
            "Tyson is a Cyclops, a son of Poseidon, and Percy Jackson’s half brother. He is the current general of the Cyclopes army."
        ]
    )
)
```
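`predict()` returns the generated sequences as a list, one entry per input string. With the default settings this is a list of strings; if `num_return_sequences` in the model args is set greater than 1, each entry should instead contain the candidate sequences for that input.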