Language Modeling Minimal Start

Simple Transformers currently supports 3 pre-training objectives.

  • Masked Language Modeling (MLM) - Used with bert, camembert, distilbert, roberta
  • Causal Language Modeling (CLM) - Used with gpt2, openai-gpt
  • ELECTRA - Used with electra

Because of this, you need to specify the pre-training objective when training or fine-tuning a Language Model. By default, MLM is used. Setting mlm: False in the model args dict switches the pre-training objective to CLM. Although ELECTRA uses its own unique pre-training objective, the inputs to the generator model are masked in the same way as with the other MLM models. Therefore, mlm can be left as True (the default) in the args dict for ELECTRA models.

Language Model Fine-Tuning

Refer to the Language Model Fine-Tuning section in the Language Model Specifics section for details.

Refer to Language Model Data Formats for the correct input data formats.

Fine-Tuning a BERT model (MLM)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import logging

from simpletransformers.language_modeling import (
    LanguageModelingModel,
    LanguageModelingArgs,
)


# Emit INFO-level logs in general, but limit the "transformers" logger to
# WARNING and above.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Start from the default language-modeling args and override a few settings.
# mlm is left at its default (True), which selects the MLM objective.
model_args = LanguageModelingArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.num_train_epochs = 1
model_args.dataset_type = "simple"

# Paths to the training and evaluation text files.
train_file = "data/train.txt"
test_file = "data/test.txt"

# Load the "bert-base-cased" model for fine-tuning.
model = LanguageModelingModel(
    "bert", "bert-base-cased", args=model_args
)

# Train the model
model.train_model(train_file, eval_file=test_file)

# Evaluate the model
result = model.eval_model(test_file)

Fine-Tuning a GPT-2 model (CLM)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import logging

from simpletransformers.language_modeling import (
    LanguageModelingModel,
    LanguageModelingArgs,
)


# Emit INFO-level logs in general, but limit the "transformers" logger to
# WARNING and above.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Start from the default language-modeling args and override a few settings.
# (Without this instantiation the assignments below raise NameError.)
model_args = LanguageModelingArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.num_train_epochs = 1
model_args.dataset_type = "simple"
model_args.mlm = False  # mlm must be False for CLM

# Paths to the training and evaluation text files.
train_file = "data/train.txt"
test_file = "data/test.txt"

# Load the "gpt2-medium" model for fine-tuning.
model = LanguageModelingModel(
    "gpt2", "gpt2-medium", args=model_args
)

# Train the model
model.train_model(train_file, eval_file=test_file)

# Evaluate the model
result = model.eval_model(test_file)

Fine-Tuning an ELECTRA model

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import logging

from simpletransformers.language_modeling import (
    LanguageModelingModel,
    LanguageModelingArgs,
)


# Emit INFO-level logs in general, but limit the "transformers" logger to
# WARNING and above.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Start from the default language-modeling args and override a few settings.
# (Without this instantiation the assignments below raise NameError.)
# mlm is left at its default (True): the generator inputs are masked in the
# same way as for the MLM models.
model_args = LanguageModelingArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.num_train_epochs = 1
model_args.dataset_type = "simple"

# Paths to the training and evaluation text files.
train_file = "data/train.txt"
test_file = "data/test.txt"

# Google released separate generator/discriminator models
# NOTE(review): the generator is the "small" checkpoint while the
# discriminator is the "large" one - confirm this size pairing is intended.
model = LanguageModelingModel(
    "electra",
    "electra",
    generator_name="google/electra-small-generator",
    discriminator_name="google/electra-large-discriminator",
    args=model_args,
)

# Train the model
model.train_model(train_file, eval_file=test_file)

# Evaluate the model
result = model.eval_model(test_file)

Language Model Training From Scratch

Refer to the Training a Language Model From Scratch section in the Language Model Specifics section for details.

Refer to Language Model Data Formats for the correct input data formats.

When training a Language Model from scratch, the model_name parameter is set to None. In addition, the train_files argument is required (see here).

Training a BERT model (MLM) from scratch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import logging

from simpletransformers.language_modeling import (
    LanguageModelingModel,
    LanguageModelingArgs,
)


# Emit INFO-level logs in general, but limit the "transformers" logger to
# WARNING and above.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Start from the default language-modeling args and override a few settings.
# (Without this instantiation the assignments below raise NameError.)
# mlm is left at its default (True), which selects the MLM objective.
model_args = LanguageModelingArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.num_train_epochs = 1
model_args.dataset_type = "simple"
model_args.vocab_size = 30000

# Paths to the training and evaluation text files.
train_file = "data/train.txt"
test_file = "data/test.txt"

# model_name is None because the model is trained from scratch; the
# train_files argument is required in that case.
model = LanguageModelingModel(
    "bert", None, args=model_args, train_files=train_file
)

# Train the model
model.train_model(train_file, eval_file=test_file)

# Evaluate the model
result = model.eval_model(test_file)

Training a GPT-2 model (CLM) from scratch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import logging

from simpletransformers.language_modeling import (
    LanguageModelingModel,
    LanguageModelingArgs,
)


# Emit INFO-level logs in general, but limit the "transformers" logger to
# WARNING and above.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Start from the default language-modeling args and override a few settings.
# (Without this instantiation the assignments below raise NameError.)
model_args = LanguageModelingArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.num_train_epochs = 1
model_args.dataset_type = "simple"
model_args.mlm = False  # mlm must be False for CLM
model_args.vocab_size = 30000

# Paths to the training and evaluation text files.
train_file = "data/train.txt"
test_file = "data/test.txt"

# model_name is None because the model is trained from scratch; the
# train_files argument is required in that case.
model = LanguageModelingModel(
    "gpt2", None, args=model_args, train_files=train_file
)

# Train the model
model.train_model(train_file, eval_file=test_file)

# Evaluate the model
result = model.eval_model(test_file)

Training an ELECTRA model from scratch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import logging

from simpletransformers.language_modeling import (
    LanguageModelingModel,
    LanguageModelingArgs,
)


# Emit INFO-level logs in general, but limit the "transformers" logger to
# WARNING and above.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Start from the default language-modeling args and override a few settings.
# (Without this instantiation the assignments below raise NameError.)
# mlm is left at its default (True): the generator inputs are masked in the
# same way as for the MLM models.
model_args = LanguageModelingArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.num_train_epochs = 1
model_args.dataset_type = "simple"
model_args.vocab_size = 30000

# Paths to the training and evaluation text files.
train_file = "data/train.txt"
test_file = "data/test.txt"

# model_name is None because the model is trained from scratch; the
# train_files argument is required in that case.
model = LanguageModelingModel(
    "electra",
    None,
    args=model_args,
    train_files=train_file
)

# Train the model
model.train_model(train_file, eval_file=test_file)

# Evaluate the model
result = model.eval_model(test_file)

Guides

Updated: