trial_search.Trial2Vec
- class pytrial.tasks.trial_search.models.trial2vec.Trial2Vec(fields=None, ctx_fields=None, tag_field='nct_id', bert_name='emilyalsentzer/Bio_ClinicalBERT', emb_dim=128, logit_scale_init_value=0.07, max_seq_length=128, epochs=10, batch_size=64, learning_rate=2e-05, weight_decay=0.0001, warmup_ratio=0, evaluation_steps=10, num_workers=0, device='cuda:0', use_amp=False, experiment_id='test')
Bases:
pytrial.tasks.trial_search.models.base.TrialSearchBase
Implements the Trial2Vec model for trial document similarity search [1].
- Parameters
fields (list[str]) – A list of fields of documents used as the attribute fields by Trial2Vec model.
ctx_fields (list[str]) – A list of fields of documents used as the context fields by Trial2Vec model.
tag_field (str) – The field that identifies each trial document. Defaults to 'nct_id'.
bert_name (str (default='emilyalsentzer/Bio_ClinicalBERT')) – The base transformer-based encoder. Please find model names from the model hub of transformers (https://huggingface.co/models).
emb_dim (int, optional (default=128)) – Dimensionality of the embedding vectors.
logit_scale_init_value (float, optional (default=0.07)) – Initial value of the logit scale (the temperature).
max_seq_length (int (default=128)) – The maximum length of input tokens for the base encoder.
epochs (int, optional (default=10)) – Number of iterations (epochs) over the corpus.
batch_size (int, optional (default=64)) – Number of samples in each training batch.
learning_rate (float, optional (default=2e-5)) – The learning rate.
weight_decay (float, optional (default=1e-4)) – Weight decay applied for regularization.
warmup_ratio (float (default=0)) – The fraction of training steps used for warmup. If set to 0, no warmup is applied.
evaluation_steps (int (default=10)) – The number of training iterations between printing the training loss and, if an evaluator is given, running evaluation.
num_workers (int, optional (default=0)) – The number of workers used to train the model (more workers can speed up training on multicore machines).
device (str or torch.device (default='cuda:0')) – The device to put the model on.
use_amp (bool (default=False)) – Whether to use mixed-precision training.
experiment_id (str, optional (default='test')) – The name of current experiment.
Notes
[1] Wang, Z., & Sun, J. (2022). Trial2Vec: Zero-Shot Clinical Trial Document Similarity Search using Self-Supervision. Findings of EMNLP 2022.
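As a rough illustration of the constructor, the sketch below collects a few of the documented keyword arguments into a dict and passes them to Trial2Vec. The field names ('title', 'intervention_name', 'description') are hypothetical placeholders rather than columns guaranteed by PyTrial, and device='cpu' is only an assumption for machines without a GPU. The import sits inside the helper so the configuration can be inspected without PyTrial installed.

```python
# Keyword arguments named in the signature above. Apart from the
# hypothetical field names and the device, the values are the documented defaults.
config = {
    "fields": ["title", "intervention_name"],  # hypothetical attribute fields
    "ctx_fields": ["description"],             # hypothetical context fields
    "tag_field": "nct_id",
    "emb_dim": 128,
    "epochs": 10,
    "batch_size": 64,
    "learning_rate": 2e-5,
    "device": "cpu",  # assumption: no GPU; the documented default is 'cuda:0'
}


def build_model(cfg):
    """Instantiate Trial2Vec from a dict of keyword arguments."""
    # Imported lazily so the dict above can be examined without PyTrial.
    from pytrial.tasks.trial_search.models.trial2vec import Trial2Vec

    return Trial2Vec(**cfg)
```

Calling build_model(config) would then return a model object whose fit, encode, and predict methods are described in the entries that follow.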
- encode(inputs, batch_size=None, num_workers=None, return_dict=True, verbose=True)
Encode input documents and output the document embeddings.
- Parameters
inputs (dict) –
inputs = {
'x': pd.DataFrame,
'fields': list[str],
'ctx_fields': list[str],
'tag': str,
}
Shares the same input format as train_data in the fit function. If fields, ctx_fields, or tag are not given, the ones used during training are reused.
batch_size (int, optional) – The batch size when encoding trials.
num_workers (int, optional) – The number of workers when building the val dataloader.
return_dict (bool) – If True, return a dict[np.ndarray]. Otherwise, return an np.ndarray in the same order as the input documents.
verbose (bool) – Whether to show a progress bar.
- Returns
embs – Encoded trial-level embeddings, with the tag as key and the embedding as value.
- Return type
dict[np.ndarray]
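The helpers below are a minimal sketch of how the inputs dict might be assembled and how the two return shapes relate. They assume that "not given" means the key is simply left out of the dict, which the documentation does not state explicitly. Both function names are hypothetical and are not part of PyTrial.

```python
def make_inputs(x, fields=None, ctx_fields=None, tag=None):
    """Assemble the `inputs` dict, dropping any optional key left as None."""
    inputs = {"x": x}
    optional = {"fields": fields, "ctx_fields": ctx_fields, "tag": tag}
    # Assumption: an omitted key makes the model reuse its training-time value.
    inputs.update({k: v for k, v in optional.items() if v is not None})
    return inputs


def embeddings_in_order(embs, tags):
    """Turn a tag -> embedding dict into a list that follows `tags`."""
    return [embs[t] for t in tags]
```

With a trained model and a dataframe df, a call might look like embs = model.encode(make_inputs(df), return_dict=True), after which embeddings_in_order(embs, list_of_tags) recovers an ordering of your choice.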
- evaluate(test_data)
Evaluate the model on the given target trials and their corresponding candidate trials.
- Parameters
test_data (dict) –
test_data =
{
'x': pd.DataFrame,
'y': pd.DataFrame
}
The provided labeled dataset for test trials. Follow the format listed above.
- Returns
results – A dict of metrics and the values.
- Return type
dict[float]
Notes
x =
| target_trial | trial1 | trial2 | trial3 |
| nct01 | nct02 | nct03 | nct04 |
y =
| label1 | label2 | label3 |
| 0 | 0 | 1 |
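The layout in the Notes can be written out as plain rows before being wrapped in DataFrames. This is a sketch under two assumptions: that each label column lines up with the candidate column of the same index, and that 1 marks a relevant candidate. The helper name build_test_data is hypothetical.

```python
# One row per target trial: the target's ID followed by its candidates.
x_rows = [
    {"target_trial": "nct01", "trial1": "nct02", "trial2": "nct03", "trial3": "nct04"},
]
# Matching row of labels, one per candidate column (assumed: 1 = relevant).
y_rows = [
    {"label1": 0, "label2": 0, "label3": 1},
]


def build_test_data(x_rows, y_rows):
    """Wrap the rows in DataFrames under the 'x' and 'y' keys."""
    import pandas as pd  # lazy import: only needed when the frames are built

    return {"x": pd.DataFrame(x_rows), "y": pd.DataFrame(y_rows)}
```

The resulting dict has the shape that evaluate expects, so results = model.evaluate(build_test_data(x_rows, y_rows)) would be the corresponding call on a trained model.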
- fit(train_data, valid_data=None)
Train the trial2vec model to get document embeddings for trial search.
- Parameters
train_data (dict) –
train_data =
{
'x': pd.DataFrame,
'fields': list[str],
'ctx_fields': list[str],
'tag': str,
}
Training corpus for the model.
x: a dataframe of trial documents.
fields: optional, the fields of documents to use for training as key attributes. If not given, the model uses all fields in x.
ctx_fields: optional, the fields of documents which belong to context components. If not given, the model will only learn from fields.
tag: optional, the field in x that serves as unique identifiers. Typically it is the nct_id of each trial. If not given, the model takes integer tags.
valid_data (dict = {'x': pd.DataFrame, 'y': np.ndarray}) – Validation data used to identify the best checkpoint during training. To use it, the get_val_dataloader function needs to be rewritten.
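Since fields, ctx_fields, and tag all name columns of x, a quick check before calling fit can catch typos early. The function below is a hypothetical pre-flight helper, not something PyTrial provides, and the column names in the example are placeholders.

```python
def missing_columns(columns, fields=None, ctx_fields=None, tag=None):
    """Return the requested names that are absent from `columns`, in order."""
    requested = list(fields or []) + list(ctx_fields or [])
    if tag is not None:
        requested.append(tag)
    available = set(columns)
    return [name for name in requested if name not in available]


# In practice `columns` would come from list(df.columns).
columns = ["nct_id", "title", "description"]
missing = missing_columns(
    columns,
    fields=["title", "criteria"],  # 'criteria' is deliberately not in columns
    ctx_fields=["description"],
    tag="nct_id",
)
# missing == ["criteria"]
```

When the returned list is empty, the same names can be placed in train_data and passed to fit.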
- load_model(checkpoint)
Load model and the pre-encoded trial embeddings from the given checkpoint dir.
- Parameters
checkpoint (str) – The path that stores the pretrained model. If it is a directory, the only checkpoint file (*.pth.tar) in it is loaded. If it is a file path, the model is loaded from that file.
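The directory-or-file rule described for checkpoint can be pictured with the following sketch. It is not PyTrial's code: the function name is hypothetical, and raising an error when a directory holds zero or several *.pth.tar files is a choice made here for the illustration.

```python
from pathlib import Path


def resolve_checkpoint(checkpoint):
    """Return the checkpoint file implied by a file path or a directory."""
    path = Path(checkpoint)
    if path.is_file():
        return path
    if path.is_dir():
        matches = sorted(path.glob("*.pth.tar"))
        # Assumption: anything other than exactly one match is an error.
        if len(matches) != 1:
            raise FileNotFoundError(
                f"expected exactly one *.pth.tar in {path}, found {len(matches)}"
            )
        return matches[0]
    raise FileNotFoundError(f"no such file or directory: {path}")
```

Running such a check before model.load_model(checkpoint) makes it obvious which file a directory argument would point to.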
- predict(test_data, top_k=10, return_df=True, skip_pretrained=False)
Predict the top-k relevant trials for the input documents.
- Parameters
test_data (dict) –
test_data =
{
'x': pd.DataFrame,
'fields': list[str],
'ctx_fields': list[str],
'tag': str,
}
Shares the same input format as train_data in the fit function. If fields, ctx_fields, or tag are not given, the ones used during training are reused.
top_k (int) – Number of retrieved candidates.
return_df (bool) –
Whether to return dataframes for the computed similarity ranking.
If True, return (rank, sim);
otherwise, return rank_list=[[(doc1,sim1),(doc2,sim2)], [(doc1,sim1),…]].
skip_pretrained (bool) – Whether to skip encoding trials that are already in self.trial_embs. If True, those trials are not re-encoded and their embeddings are looked up from self.trial_embs.
- Returns
rank (pd.DataFrame) – A dataframe containing the top-ranked NCT IDs for each input trial.
sim (pd.DataFrame) – A dataframe containing the corresponding similarities.
rank_list (list[list[tuple]]) – For each input trial, a list of (doc, similarity) tuples for the top-ranked docs.
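To make the rank_list shape concrete, here is a small pure-Python sketch of top-k retrieval over tag-to-embedding dicts such as those returned by encode. It assumes cosine similarity and excludes the query trial from its own results. Both are assumptions made for this illustration and may differ from what predict does internally.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def top_k(query_embs, corpus_embs, k=10):
    """Rank corpus docs for each query; returns [[(doc, sim), ...], ...]."""
    rank_list = []
    for q_tag, q_vec in query_embs.items():
        scored = [
            (tag, cosine(q_vec, vec))
            for tag, vec in corpus_embs.items()
            if tag != q_tag  # assumption: do not return the query itself
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        rank_list.append(scored[:k])
    return rank_list


# Toy two-dimensional embeddings, for illustration only.
corpus = {"nct01": [1.0, 0.0], "nct02": [0.9, 0.1], "nct03": [0.0, 1.0]}
ranked = top_k({"nct01": [1.0, 0.0]}, corpus, k=2)
# ranked[0] lists nct02 first, then nct03
```

Each inner list corresponds to one query trial, which mirrors the rank_list format returned when return_df is False.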