Retrieval models (RetrievalModel) are models used to retrieve relevant documents from a corpus given a query.
Currently, only DPR models are supported.
Using a retrieval model in Simple Transformers follows the standard pattern.
- Initialize a
- Train the model with
- Evaluate the model with
- Make predictions on (unlabelled) data with
Note: You must have Faiss (GPU or CPU) installed to use RetrievalModel. Faiss installation instructions can be found here.
The __init__ arguments for a RetrievalModel are a little different from the common format used by the other models. Please refer here for more information.
Why hard negatives are needed
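In dense passage retrieval, the model is typically trained with the in-batch negatives technique, which makes training much more computationally efficient: the passages already encoded for a batch are reused as negatives for the other queries, so no extra passages need to be encoded. The process is outlined below.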
For a batch consisting of query and positive passage pairs:
- Compute the query encodings for each query in the batch.
- Compute the passage encodings for each positive passage in the batch.
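- Calculate the similarity (the dot product of the encodings, as in the DPR paper) between each query and every passage in the batch.
- Minimize the negative log-likelihood of the positive passage for each query, treating the other passages in the batch as its negatives.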
For more information, refer to the DPR paper.
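The four steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the Simple Transformers implementation: the helpers `dot` and `in_batch_negatives_loss` are hypothetical, the "encodings" are toy 2-dimensional vectors instead of encoder outputs, and similarity is the dot product used in the DPR paper.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def in_batch_negatives_loss(query_embs, passage_embs):
    """Mean negative log-likelihood of each query's positive passage.

    passage_embs[i] is the positive passage for query_embs[i]; every other
    passage in the batch acts as a negative for query i.
    """
    losses = []
    for i, q in enumerate(query_embs):
        # Row i of the B x B similarity matrix: query i against every passage.
        scores = [dot(q, p) for p in passage_embs]
        # Numerically stable log of the softmax normaliser for this row.
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        # The positive passage sits on the diagonal, at column i.
        losses.append(log_z - scores[i])
    return sum(losses) / len(losses)


# Toy encodings for a batch of 3 (query, positive passage) pairs.
queries = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
passages = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
print(round(in_batch_negatives_loss(queries, passages), 4))
```

Only B passages are encoded for B queries, yet each query still sees B - 1 negatives; that reuse is where the efficiency comes from.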
While this method is computationally efficient, it is not ideal for training a good retrieval model because the negatives are effectively random (batches are randomly sampled), so most of them are easy to distinguish from the positive passage. The model can be improved further by training with hard negatives, i.e. passages that are similar to the positive passage but do not answer the query.
Here, each batch contains triplets of a query, its positive passage, and a hard negative passage. Each query embedding is then compared against the embeddings of all positive passages in the batch (its own positive plus the in-batch negatives from the other queries) as well as the hard negatives of every query in the batch.
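A minimal sketch of how the triplet batch changes the scoring, assuming one hard negative per query and dot-product similarity. `hard_negative_loss` is a hypothetical helper and the vectors are toy values: the hard negatives are simply appended to the candidate list, so each query is scored against 2B passages while its target stays at column i.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def hard_negative_loss(query_embs, positive_embs, hard_negative_embs):
    """Mean NLL of each query's positive, scored against all positives and hard negatives.

    Candidates are [positive_0 .. positive_{B-1}, hard_0 .. hard_{B-1}], so each
    query gets a row of 2B scores and query i's target is column i.
    """
    candidates = list(positive_embs) + list(hard_negative_embs)
    losses = []
    for i, q in enumerate(query_embs):
        scores = [dot(q, c) for c in candidates]
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        losses.append(log_z - scores[i])
    return sum(losses) / len(losses)


queries = [[1.0, 0.0], [0.0, 1.0]]
positives = [[2.0, 0.0], [0.0, 2.0]]
# Hard negatives lie close to their query's positive; easy ones point elsewhere.
hard_negs = [[1.8, 0.2], [0.2, 1.8]]
easy_negs = [[-2.0, 0.0], [0.0, -2.0]]

print("in-batch only:", hard_negative_loss(queries, positives, []))
print("with easy negatives:", hard_negative_loss(queries, positives, easy_negs))
print("with hard negatives:", hard_negative_loss(queries, positives, hard_negs))
```

With the same model (the same vectors), near-duplicate negatives produce a noticeably larger loss than random-looking ones. In other words, they still carry a useful training signal after the easy negatives have been learned.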
How to train with hard negatives
In order to train a RetrievalModel with hard negatives, the training data must contain a "hard_negatives" column holding one hard negative passage for each query.
Note: You must set the hard negatives option to True in the model args for the model to include the hard negatives in training. Each query adds an extra passage to the batch, so you may need to decrease the batch size to avoid running out of memory.
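To see why memory grows, count what one batch has to encode. This sketch assumes exactly one hard negative per query, as described above. `passages_per_batch` and `scores_per_batch` are hypothetical helpers. Real memory use also depends on sequence length, model size, and optimizer state, so treat this as a rough guide.

```python
def passages_per_batch(batch_size, hard_negatives):
    """Passages the context encoder must embed for one training batch."""
    # One positive per query, plus one hard negative per query when enabled.
    return batch_size * (2 if hard_negatives else 1)


def scores_per_batch(batch_size, hard_negatives):
    """Entries in the query-by-passage similarity matrix for one batch."""
    return batch_size * passages_per_batch(batch_size, hard_negatives)


for bs, hn in [(32, False), (32, True), (16, True)]:
    print(f"batch_size={bs} hard_negatives={hn}: "
          f"{passages_per_batch(bs, hn)} passages, {scores_per_batch(bs, hn)} scores")
```

Halving the batch size from 32 to 16 brings the number of passages encoded per batch back to 32, which is the usual first adjustment when hard negatives cause out-of-memory errors.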
Hard negative passages may be obtained with external methods (such as BM25 sparse retrieval). Alternatively, Simple Transformers provides a method, build_hard_negatives(), that generates hard negatives from a given passage dataset. For example, if you are fine-tuning a DPR model on your own data, you can use build_hard_negatives() with a pre-trained DPR model to generate hard negatives from your corpus.
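As an illustration of the external route, here is a minimal BM25 miner in plain Python. It ranks the corpus for each query and keeps the highest-scoring passage that is not the query's gold passage. `BM25` and `mine_hard_negatives` are hypothetical helpers written for this sketch, not part of Simple Transformers. The parameters k1=1.5 and b=0.75 are common BM25 defaults, and the gold passages below are chosen by hand for the toy data. Mining with a pre-trained DPR model follows the same retrieve-then-filter idea, with the BM25 score replaced by the dot product of DPR encodings.

```python
import math
import re
from collections import Counter


def tokenize(text):
    """Lowercase and split on non-alphanumeric characters."""
    return re.findall(r"[a-z0-9]+", text.lower())


class BM25:
    """Minimal Okapi BM25 scorer over an in-memory list of passages."""

    def __init__(self, passages, k1=1.5, b=0.75):
        self.k1, self.b = k1, b
        self.docs = [Counter(tokenize(p)) for p in passages]
        self.doc_lens = [sum(d.values()) for d in self.docs]
        self.avg_len = sum(self.doc_lens) / len(self.docs)
        n = len(self.docs)
        df = Counter(term for d in self.docs for term in d)
        # IDF with +1 inside the log so it never goes negative.
        self.idf = {t: math.log(1 + (n - f + 0.5) / (f + 0.5)) for t, f in df.items()}

    def score(self, query, index):
        doc, length = self.docs[index], self.doc_lens[index]
        total = 0.0
        for term in tokenize(query):
            tf = doc.get(term, 0)
            if tf == 0:
                continue
            norm = tf + self.k1 * (1 - self.b + self.b * length / self.avg_len)
            total += self.idf[term] * tf * (self.k1 + 1) / norm
        return total


def mine_hard_negatives(queries, gold_passages, corpus):
    """For each query, return the top BM25 passage that is not its gold passage."""
    bm25 = BM25(corpus)
    mined = []
    for query, gold in zip(queries, gold_passages):
        ranked = sorted(range(len(corpus)), key=lambda i: bm25.score(query, i), reverse=True)
        mined.append(next((corpus[i] for i in ranked if corpus[i] != gold), None))
    return mined


queries = [
    "Where does solar energy come from?",
    "What is anthropology a study of?",
    "In what fields of science is the genome studied?",
]
corpus = [
    "Solar energy is radiant light and heat...",
    "describes the workings of societies around the world...",
    "The genome includes both the genes and the non-coding sequences of the DNA/RNA....",
    "the genome is the genetic material of an organism",
    "Its main subdivisions are social anthropology and cultural anthropology",
    "Neptune is the eighth and farthest known planet from the Sun in the Solar System",
]
# Gold (positive) passage for each query, picked by hand for this toy example.
golds = [corpus[0], corpus[4], corpus[2]]

for query, negative in zip(queries, mine_hard_negatives(queries, golds, corpus)):
    print(query, "->", negative)
```

Mined passages can be false negatives: relevant passages that simply were not marked as gold. Spot-check a sample before training on them.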
```python
import logging

import pandas as pd
from simpletransformers.retrieval import RetrievalModel, RetrievalArgs

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

queries = [
    "Where does solar energy come from?",
    "What is anthropology a study of?",
    "In what fields of science is the genome studied?",
]

# Note that the passages have been manually truncated for this example.
# Typically, you would want to use the full passage.
passages = [
    "Solar energy is radiant light and heat...",
    "describes the workings of societies around the world...",
    "The genome includes both the genes and the non-coding sequences of the DNA/RNA....",
    "the genome is the genetic material of an organism",
    "Its main subdivisions are social anthropology and cultural anthropology",
    "Neptune is the eighth and farthest known planet from the Sun in the Solar System",
]

model_type = "dpr"
context_name = "facebook/dpr-ctx_encoder-single-nq-base"
query_name = "facebook/dpr-question_encoder-single-nq-base"

model_args = RetrievalArgs()

# Create a RetrievalModel
model = RetrievalModel(
    model_type=model_type,
    context_encoder_name=context_name,
    query_encoder_name=query_name,
    args=model_args,
)

# The hard negatives will be written to the output dir by default.
hard_df = model.build_hard_negatives(
    queries=queries,
    passage_dataset=passages,
    retrieve_n_docs=1,
)

print(hard_df)
```
You can combine the hard negatives with the queries and their positive passages to create training data with hard negatives.
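A minimal sketch of that combination step in plain Python. It assumes you have already pulled the hard-negative passages out of `hard_df` as a list in the same order as the queries; this sketch makes no assumption about `hard_df`'s column names. The "hard_negatives" key follows the requirement above. The "query_text" and "gold_passage" keys are assumptions, so check the RetrievalModel training-data format before relying on them. `build_training_rows` is a hypothetical helper. It also drops rows whose mined negative is empty or identical to the positive, because a retriever can return the gold passage itself.

```python
def build_training_rows(queries, gold_passages, hard_negatives):
    """Zip queries, positives and hard negatives into training rows.

    Rows whose hard negative is missing or identical to the positive are dropped,
    since such a "negative" would push the model away from the correct answer.
    """
    if not (len(queries) == len(gold_passages) == len(hard_negatives)):
        raise ValueError("queries, gold_passages and hard_negatives must have equal length")
    rows = []
    for query, gold, hard in zip(queries, gold_passages, hard_negatives):
        if not hard or hard.strip() == gold.strip():
            continue
        # Assumed column names except "hard_negatives"; adjust to your training format.
        rows.append({"query_text": query, "gold_passage": gold, "hard_negatives": hard})
    return rows


rows = build_training_rows(
    ["Where does solar energy come from?", "What is anthropology a study of?"],
    [
        "Solar energy is radiant light and heat...",
        "Its main subdivisions are social anthropology and cultural anthropology",
    ],
    [
        "Neptune is the eighth and farthest known planet from the Sun in the Solar System",
        # Identical to the positive, so this row is dropped.
        "Its main subdivisions are social anthropology and cultural anthropology",
    ],
)
print(len(rows), rows[0]["hard_negatives"])
```

Passing `rows` to `pd.DataFrame(rows)` gives a DataFrame with one column per key, which is the tabular shape the training data takes.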