T5 Specifics

The T5 Transformer is an Encoder-Decoder architecture where both the input and targets are text sequences. This gives it the flexibility to perform any Natural Language Processing task without having to modify the model architecture in any way. It also means that the same T5 model can be trained to perform multiple tasks simultaneously.

Please refer to the Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer paper for more details.

Tip: This Medium article explains how to train a T5 Model to perform a new task.

Tip: This Medium article explains how to train a single T5 Model to perform multiple tasks.

Specifying a Task

The T5 model is instructed to perform a particular task by adding a prefix to the start of an input sequence. The prefix for a given task can be any arbitrary text, as long as the same prefix is prepended whenever the model is expected to perform that task.

Example prefixes:

  1. binary classification
  2. predict sentiment
  3. answer question

By using multiple unique prefixes, we can train a single T5 model to perform multiple tasks. During inference, the model uses the prefix to determine which task to perform and generates the appropriate output.
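As a minimal sketch (assuming a T5Model named model that has already been trained on the prefixes above), prediction inputs are plain strings with the task prefix prepended in the form "prefix: input_text":

# Minimal sketch: `model` is assumed to be a T5Model already trained
# on the "binary classification" and "generate question" prefixes.
to_predict = [
    "binary classification: Luke blew up the Death Star",
    "generate question: Leia was a princess of Alderaan",
]

# The prefix tells the model which task to perform for each input.
preds = model.predict(to_predict)
print(preds)  # illustrative output: ["1", "Who was a princess of Alderaan?"]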

Hint: See the T5 Data Formats page for more details on how the inputs and outputs are structured.

Usage Steps

Using a T5 Model in Simple Transformers follows the standard pattern (a minimal sketch follows the list below).

  1. Initialize a T5Model
  2. Train the model with train_model()
  3. Evaluate the model with eval_model()
  4. Make predictions on (unlabelled) data with predict()
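A minimal sketch of these steps (train_df and eval_df are assumed to be pandas DataFrames with prefix, input_text, and target_text columns, as in the full example at the end of this page):

from simpletransformers.t5 import T5Model, T5Args

# 1. Initialize a T5Model (the t5-base checkpoint is used here for illustration)
model = T5Model("t5", "t5-base", args=T5Args(num_train_epochs=1))

# 2. Train the model
model.train_model(train_df)

# 3. Evaluate the model
result = model.eval_model(eval_df)

# 4. Make predictions on (unlabelled) data; inputs are "prefix: input_text" strings
preds = model.predict(["binary classification: Han shot first"])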

Supported Model Types

Model    Model code for T5Model
T5       t5
MT5      mt5
ByT5     byt5

Tip: The model code is used to specify the model_type in a Simple Transformers model.
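For example, to load a multilingual MT5 model (the google/mt5-base checkpoint is used here for illustration):

from simpletransformers.t5 import T5Model

# model_type "mt5" paired with a matching pretrained checkpoint
model = T5Model("mt5", "google/mt5-base")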

Evaluating Generated Sequences

You can evaluate a model's generated sequences using custom metric functions (including evaluation during training). However, because T5 outputs are full text sequences that must be generated token by token, this evaluation can be significantly slower than evaluation with other model types.

Note: You must set evaluate_generated_text to True to evaluate generated sequences.

import logging

import pandas as pd
from simpletransformers.t5 import T5Model, T5Args

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)


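# Training data: each record is [prefix, input_text, target_text]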
train_data = [
    ["binary classification", "Anakin was Luke's father", "1"],
    ["binary classification", "Luke was a Sith Lord", "0"],
    ["generate question", "Star Wars is an American epic space-opera media franchise created by George Lucas, which began with the eponymous 1977 film and quickly became a worldwide pop-culture phenomenon", "Who created the Star Wars franchise?"],
    ["generate question", "Anakin was Luke's father", "Who was Luke's father?"],
]
train_df = pd.DataFrame(train_data)
train_df.columns = ["prefix", "input_text", "target_text"]

eval_data = [
    ["binary classification", "Leia was Luke's sister", "1"],
    ["binary classification", "Han was a Sith Lord", "0"],
    ["generate question", "In 2020, the Star Wars franchise's total value was estimated at US$70 billion, and it is currently the fifth-highest-grossing media franchise of all time.", "What is the total value of the Star Wars franchise?"],
    ["generate question", "Leia was Luke's sister", "Who was Luke's sister?"],
]
eval_df = pd.DataFrame(eval_data)
eval_df.columns = ["prefix", "input_text", "target_text"]

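# Model configuration: evaluate_generated_text must be True to score
# generated sequences with custom metric functions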
model_args = T5Args()
model_args.num_train_epochs = 200
model_args.no_save = True
model_args.evaluate_generated_text = True
model_args.evaluate_during_training = True
model_args.evaluate_during_training_verbose = True

model = T5Model("t5", "t5-base", args=model_args)


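# Custom metric: counts exact matches between generated sequences and target texts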
def count_matches(labels, preds):
    print(labels)
    print(preds)
    return sum([1 if label == pred else 0 for label, pred in zip(labels, preds)])


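# Train, evaluating against eval_df during training (evaluate_during_training=True)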
model.train_model(train_df, eval_data=eval_df, matches=count_matches)

print(model.eval_model(eval_df, matches=count_matches))
