trial_search.Trial2Vec

class pytrial.tasks.trial_search.models.trial2vec.Trial2Vec(fields=None, ctx_fields=None, tag_field='nct_id', bert_name='emilyalsentzer/Bio_ClinicalBERT', emb_dim=128, logit_scale_init_value=0.07, max_seq_length=128, epochs=10, batch_size=64, learning_rate=2e-05, weight_decay=0.0001, warmup_ratio=0, evaluation_steps=10, num_workers=0, device='cuda:0', use_amp=False, experiment_id='test')

Bases: pytrial.tasks.trial_search.models.base.TrialSearchBase

Implements the Trial2Vec model for trial document similarity search [1].

Parameters
  • fields (list[str]) – A list of document fields that the Trial2Vec model uses as attribute fields.

  • ctx_fields (list[str]) – A list of document fields that the Trial2Vec model uses as context fields.

  • tag_field (str) – The field used as the unique identifier (tag) of each trial document. Defaults to 'nct_id'.

  • bert_name (str (default='emilyalsentzer/Bio_ClinicalBERT')) – The name of the base transformer encoder. Model names can be found on the Hugging Face model hub (https://huggingface.co/models).

  • emb_dim (int, optional (default=128)) – Dimensionality of the embedding vectors.

  • logit_scale_init_value (float, optional (default=0.07)) – The initial value of the logit scale (temperature).

  • max_seq_length (int (default=128)) – The maximum number of input tokens for the base encoder.

  • epochs (int, optional (default=10)) – Number of iterations (epochs) over the corpus.

  • batch_size (int, optional (default=64)) – Number of samples in each training batch.

  • learning_rate (float, optional (default=2e-5)) – The learning rate.

  • weight_decay (float, optional (default=1e-4)) – Weight decay applied for regularization.

  • warmup_ratio (float (default=0)) – The fraction of training steps used for learning-rate warmup. If set to 0, no warmup is applied.

  • evaluation_steps (int (default=10)) – The number of training iterations between printing the training loss and, if an evaluator is given, running evaluation.

  • num_workers (int, optional (default=0)) – The number of workers used to load data during training. More workers can speed up training on multicore machines.

  • device (str or torch.device (default='cuda:0')) – The device to put the model on.

  • use_amp (bool (default=False)) – Whether to use mixed precision training.

  • experiment_id (str, optional (default='test')) – The name of the current experiment.
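A minimal sketch of how a warmup ratio is commonly turned into a number of warmup steps, assuming the learning rate rises linearly from 0 to learning_rate over the first warmup_ratio fraction of all training steps. The helper warmup_schedule and its arguments are hypothetical; the scheduler Trial2Vec actually uses, including any decay after warmup, may differ.

```python
import math


def warmup_schedule(num_samples, batch_size, epochs, warmup_ratio, base_lr):
    """Hypothetical helper: linear warmup derived from a warmup ratio."""
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(total_steps * warmup_ratio)  # warmup_ratio=0 -> no warmup

    def lr_at(step):
        # Linear ramp from 0 to base_lr during warmup, constant afterwards
        # (assumption: any post-warmup decay is not modeled here).
        if step < warmup_steps:
            return base_lr * step / warmup_steps
        return base_lr

    return total_steps, warmup_steps, lr_at


# 1,000 documents with the defaults batch_size=64, epochs=10, learning_rate=2e-5.
total, warmup, lr_at = warmup_schedule(1000, 64, 10, warmup_ratio=0.1, base_lr=2e-5)
print(total, warmup)  # 160 16
```

With warmup_ratio=0 the warmup branch never runs, so training starts directly at the full learning rate.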

Notes

[1] Wang, Z., & Sun, J. (2022). Trial2Vec: Zero-Shot Clinical Trial Document Similarity Search using Self-Supervision. Findings of EMNLP 2022.

encode(inputs, batch_size=None, num_workers=None, return_dict=True, verbose=True)

Encode input documents and output the document embeddings.

Parameters
  • inputs (dict) –

    inputs = {
        'x': pd.DataFrame,
        'fields': list[str],
        'ctx_fields': list[str],
        'tag': str,
    }

    Uses the same format as train_data in fit. If fields, ctx_fields, or tag is not given, the values used during training are reused.

  • batch_size (int, optional) – The batch size when encoding trials.

  • num_workers (int, optional) – The number of workers used when building the dataloader for encoding.

  • return_dict (bool) – If True, return a dict mapping each tag to its embedding. Otherwise, return an np.ndarray whose rows follow the order of the input documents.

  • verbose (bool) – Whether to show a progress bar.

Returns

embs – Encoded trial-level embeddings: a dict mapping each tag to its embedding if return_dict is True, otherwise an np.ndarray in input order.

Return type

dict[np.ndarray] or np.ndarray
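The sketch below builds an inputs dict for encode and shows how a return_dict=True result maps to the array that return_dict=False is documented to return (one row per document, in input order). The model is not called: the embeddings are random stand-ins of size 128 (the emb_dim default in the signature), and every column name other than nct_id is hypothetical.

```python
import numpy as np
import pandas as pd

# Hypothetical trial documents; only 'nct_id' comes from the documented default tag_field.
df = pd.DataFrame({
    "nct_id": ["NCT00000001", "NCT00000002", "NCT00000003"],
    "title": ["Drug A in asthma", "Drug B in COPD", "Drug C in asthma"],
    "condition": ["asthma", "COPD", "asthma"],
    "description": ["Phase 2 trial of drug A.", "Phase 3 trial of drug B.", "Phase 2 trial of drug C."],
})
inputs = {"x": df, "fields": ["title", "condition"], "ctx_fields": ["description"], "tag": "nct_id"}

# Random stand-in for model.encode(inputs, return_dict=True): tag -> embedding.
rng = np.random.default_rng(0)
embs = {tag: rng.standard_normal(128) for tag in df["nct_id"]}

# Equivalent of return_dict=False: one row per input document, in input order.
emb_matrix = np.stack([embs[tag] for tag in inputs["x"][inputs["tag"]]])
print(emb_matrix.shape)  # (3, 128)
```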

evaluate(test_data)

Evaluate the model on target trials and their labeled candidate trials.

Parameters

test_data (dict) –

test_data = {
    'x': pd.DataFrame,
    'y': pd.DataFrame,
}

The labeled dataset of test trials, in the format above; see Notes for the layout of x and y.

Returns

results – A dict mapping metric names to their values.

Return type

dict[float]

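Notes

x =

    target_trial | trial1 | trial2 | trial3
    nct01        | nct02  | nct03  | nct04

y =

    label1 | label2 | label3
    0      | 0      | 1

Each row of x pairs a target trial with its candidate trials, and the matching row of y gives the label of each candidate (here, only trial3 is labeled relevant).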
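The metrics that evaluate reports are not listed in this reference. As an illustration only, the sketch below scores the Notes example with precision@k, a common ranking metric, ranking each row's candidates by cosine similarity to the target trial. The toy embeddings and the precision_at_k helper are hypothetical and are not part of the pytrial API.

```python
import numpy as np
import pandas as pd

# test_data in the layout shown in Notes: one target trial and its candidates per row.
x = pd.DataFrame([["nct01", "nct02", "nct03", "nct04"]],
                 columns=["target_trial", "trial1", "trial2", "trial3"])
y = pd.DataFrame([[0, 0, 1]], columns=["label1", "label2", "label3"])

# Toy 2-D embeddings keyed by tag (stand-ins for encode() output).
embs = {"nct01": np.array([1.0, 0.0]), "nct02": np.array([0.0, 1.0]),
        "nct03": np.array([-1.0, 0.0]), "nct04": np.array([1.0, 0.1])}


def precision_at_k(x, y, embs, k):
    """Hypothetical helper: mean precision@k over the rows of x / y."""
    scores = []
    for (_, row), (_, labels) in zip(x.iterrows(), y.iterrows()):
        target = embs[row["target_trial"]]
        cands = np.stack([embs[t] for t in row.iloc[1:]])
        # Cosine similarity between each candidate and the target trial.
        sims = cands @ target / (np.linalg.norm(cands, axis=1) * np.linalg.norm(target))
        top = np.argsort(-sims)[:k]  # indices of the k most similar candidates
        scores.append(labels.to_numpy()[top].mean())
    return float(np.mean(scores))


print(precision_at_k(x, y, embs, k=1))  # 1.0
```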

fit(train_data, valid_data=None)

Train the Trial2Vec model to obtain document embeddings for trial search.

Parameters
  • train_data (dict) –

    train_data = {
        'x': pd.DataFrame,
        'fields': list[str],
        'ctx_fields': list[str],
        'tag': str,
    }

    Training corpus for the model.

    • x: a dataframe of trial documents.

    • fields: optional, the document fields used as key attributes for training. If not given, the model uses all fields in x.

    • ctx_fields: optional, the document fields that belong to the context components. If not given, the model learns only from fields.

    • tag: optional, the field in x that serves as the unique identifier of each trial, typically its nct_id. If not given, the model uses integer tags.

  • valid_data (dict={'x': pd.DataFrame, 'y': np.ndarray}, optional) – Validation data used to select the best checkpoint during training. Using it requires overriding the get_val_dataloader function.
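The defaults listed for train_data can be mirrored by a small helper. resolve_train_data is hypothetical and illustrative only, not part of the pytrial API; it follows the documented rules: all columns of x as fields, no context fields, and integer tags when the corresponding keys are absent.

```python
import pandas as pd


def resolve_train_data(train_data):
    """Hypothetical helper mirroring the documented defaults of fit()'s train_data."""
    x = train_data["x"]
    fields = train_data.get("fields") or list(x.columns)   # default: all fields in x
    ctx_fields = train_data.get("ctx_fields") or []          # default: learn from fields only
    tag = train_data.get("tag")
    tags = list(x[tag]) if tag else list(range(len(x)))      # default: integer tags
    return fields, ctx_fields, tags


x = pd.DataFrame({"title": ["Drug A in asthma", "Drug B in COPD"],
                  "condition": ["asthma", "COPD"]})
print(resolve_train_data({"x": x}))  # (['title', 'condition'], [], [0, 1])
```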

from_pretrained(input_dir=None)

Download the pretrained Trial2Vec model.

load_model(checkpoint)

Load model and the pre-encoded trial embeddings from the given checkpoint dir.

Parameters

checkpoint (str) – The directory or file path of the pretrained model. If a directory, the single *.pth.tar checkpoint file inside it is loaded; if a file path, that file is loaded.
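The directory-or-file rule above can be sketched with pathlib. The resolve_checkpoint helper is hypothetical; it raises an error when a directory does not contain exactly one *.pth.tar file, which is an assumption, since how load_model itself handles that case is not stated here.

```python
from pathlib import Path


def resolve_checkpoint(checkpoint):
    """Hypothetical helper: map a checkpoint dir or file path to the file to load."""
    path = Path(checkpoint)
    if path.is_dir():
        # A directory must hold exactly one *.pth.tar checkpoint file.
        candidates = sorted(path.glob("*.pth.tar"))
        if len(candidates) != 1:
            raise FileNotFoundError(
                f"expected exactly one *.pth.tar in {path}, found {len(candidates)}"
            )
        return candidates[0]
    # Otherwise treat the argument as the checkpoint file itself.
    return path
```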

predict(test_data, top_k=10, return_df=True, skip_pretrained=False)

Retrieve the top-k most relevant trials for each input document.

Parameters
  • test_data (dict) –

    test_data = {
        'x': pd.DataFrame,
        'fields': list[str],
        'ctx_fields': list[str],
        'tag': str,
    }

    Uses the same format as train_data in fit. If fields, ctx_fields, or tag is not given, the values used during training are reused.

  • top_k (int) – Number of retrieved candidates.

  • return_df (bool) –

    Whether to return the similarity ranking as dataframes.

    • If set True, return (rank, sim);

    • else, return rank_list=[[(doc1,sim1),(doc2,sim2)], [(doc1,sim1),…]].

  • skip_pretrained (bool) – Whether to skip encoding trials that are already in self.trial_embs. If True, those trials are not re-encoded; their embeddings are looked up from self.trial_embs instead.

Returns

  • rank (pd.DataFrame) – A dataframe containing the top-ranked NCT ids for each input document.

  • sim (pd.DataFrame) – A dataframe containing the corresponding similarities.

  • rank_list (list[list[tuple]]) – For each input document, a list of (doc, similarity) tuples for the top-ranked documents.
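A minimal sketch of top-k retrieval that produces outputs shaped like rank, sim, and rank_list, assuming cosine similarity between trial embeddings (the scoring function Trial2Vec uses is not stated here). The top_k_search helper and the toy vectors are hypothetical, and the sketch does not exclude a query trial that also appears in the corpus.

```python
import numpy as np
import pandas as pd


def top_k_search(query_embs, corpus_embs, top_k=10):
    """Hypothetical helper: cosine-similarity top-k over tag -> embedding dicts."""
    q_tags, c_tags = list(query_embs), list(corpus_embs)
    q = np.stack([query_embs[t] for t in q_tags])
    c = np.stack([corpus_embs[t] for t in c_tags])
    # L2-normalize so the dot product equals cosine similarity.
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    c = c / np.linalg.norm(c, axis=1, keepdims=True)
    sims = q @ c.T  # shape: (num_queries, num_corpus)
    order = np.argsort(-sims, axis=1)[:, :top_k]  # highest similarity first
    top_sims = np.take_along_axis(sims, order, axis=1)

    rank = pd.DataFrame([[c_tags[j] for j in row] for row in order], index=q_tags)
    sim = pd.DataFrame(top_sims, index=q_tags)
    rank_list = [
        [(c_tags[j], float(s)) for j, s in zip(row, row_sims)]
        for row, row_sims in zip(order, top_sims)
    ]
    return rank, sim, rank_list


corpus = {"NCT01": np.array([1.0, 0.1]), "NCT02": np.array([0.0, 1.0]), "NCT03": np.array([-1.0, 0.0])}
query = {"NCT99": np.array([1.0, 0.0])}
rank, sim, rank_list = top_k_search(query, corpus, top_k=2)
print(rank.loc["NCT99"].tolist())  # ['NCT01', 'NCT02']
```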

save_model(output_dir)

Save the model states to the given directory.

Parameters

output_dir (str) – The directory in which to save the model states.

update_emb(emb_dict)

Update the stored trial embeddings by adding new entries or modifying existing ones.

Parameters

emb_dict (dict[np.ndarray]) – A dict mapping trial tags to the trial embeddings to add or update.
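A minimal sketch of the add-or-modify semantics, assuming the stored embeddings behave like a dict keyed by trial tag (predict refers to self.trial_embs, but its exact type is an assumption here): existing tags are overwritten and new tags are added.

```python
import numpy as np

# Stand-in for the model's stored embeddings (assumed to be a tag -> np.ndarray dict).
trial_embs = {"NCT01": np.zeros(4), "NCT02": np.ones(4)}

# emb_dict as passed to update_emb: NCT02 is modified, NCT03 is added.
emb_dict = {"NCT02": np.full(4, 2.0), "NCT03": np.full(4, 3.0)}
trial_embs.update(emb_dict)

print(sorted(trial_embs))  # ['NCT01', 'NCT02', 'NCT03']
```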