trial_patient_match.COMPOSE

class pytrial.tasks.trial_patient_match.models.compose.COMPOSE(order, vocab_size, max_visit=20, word_dim=768, conv_dim=128, mem_dim=320, mlp_dim=512, demo_dim=3, margin=0.3, batch_size=512, epochs=50, learning_rate=0.001, weight_decay=0, num_worker=0, device='cuda:0', experiment_id='test')[source]

Bases: pytrial.tasks.trial_patient_match.models.base.PatientTrialMatchBase

Leverages the COMPOSE model for patient-trial matching [1].

One patient's label = [[0,1,2,3], [2,3,4,5]], where the first list holds the indices of inclusion criteria that the patient satisfies and the second list holds the indices of exclusion criteria that the patient does not satisfy. For all other criteria it is unknown whether the patient satisfies them. To recruit a patient for a trial, the model therefore needs to (1) predict "match" for all of the trial's inclusion criteria and (2) predict "unmatch" for all of its exclusion criteria.

Prediction logits: 0 = unmatch, 1 = match, 2 = unknown. A small illustrative sketch of the label format and these codes is shown below.
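The following is a minimal sketch (not part of the library) of one patient's label against a single trial and of the per-criterion prediction codes; all index values are hypothetical.

    # Illustrative only: one patient's label for a single trial.
    label = [
        [0, 1, 2, 3],   # indices of inclusion criteria the patient satisfies
        [2, 3, 4, 5],   # indices of exclusion criteria the patient does not satisfy
    ]

    # Per-criterion prediction codes produced by the model.
    UNMATCH, MATCH, UNKNOWN = 0, 1, 2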

Parameters
  • order (list[str]) – The order of event types within each visit of the patient's EHR data, e.g., order=['diag','med','prod'].

  • vocab_size (int) – The vocabulary size of each of the patient's EHR event types, e.g., diag, med, prod.

  • max_visit (int) – Maximum number of visits to load from each patient's EHR.

  • word_dim (int) – The dimension of input word embeddings for encoding eligibility criteria.

  • conv_dim (int) – The dimension of convolutional layers for processing eligibility criteria embeddings.

  • mem_dim (int) – The hidden dimension of the EHR memory network (encode patient EHRs).

  • mlp_dim (int) – The hidden dimension of the MLP layers in the Query Network.

  • demo_dim (int) – The input dimension of patient demographic information, e.g., age, gender.

  • margin (float) – The margin used when computing nn.CosineEmbeddingLoss between the patient embedding and the inclusion/exclusion criteria embeddings; refer to Eq. (12) of the reference paper. A short sketch of this loss follows the parameter list.

  • epochs (int, optional (default=50)) – Number of iterations (epochs) over the training data.

  • batch_size (int, optional (default=512)) – Number of samples in each training batch.

  • learning_rate (float, optional (default=1e-3)) – The learning rate.

  • weight_decay (float (default=0)) – The weight decay during training.

  • num_worker (int) – Number of worker processes used for data loading during training.

  • device (str) – The model device.
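The sketch below only illustrates how the margin parameter enters nn.CosineEmbeddingLoss; the embedding sizes and targets are hypothetical and do not reproduce the model's actual training loop.

    import torch
    from torch import nn

    # Margin-based cosine embedding loss (cf. Eq. (12) of the reference paper).
    loss_fn = nn.CosineEmbeddingLoss(margin=0.3)

    patient_emb = torch.randn(4, 320)      # hypothetical patient embeddings
    criteria_emb = torch.randn(4, 320)     # hypothetical criteria embeddings
    target = torch.tensor([1, 1, -1, -1])  # 1: pull together (inclusion), -1: push apart (exclusion)

    loss = loss_fn(patient_emb, criteria_emb, target)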

Notes

[1] Gao, J., et al. (2020, August). COMPOSE: Cross-modal pseudo-siamese network for patient trial matching. In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (pp. 803-812).
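A minimal instantiation sketch. The event order and vocabulary sizes below are placeholders and must be derived from your own EHR data; in particular, one vocabulary size per event type is assumed here.

    from pytrial.tasks.trial_patient_match.models.compose import COMPOSE

    model = COMPOSE(
        order=['diag', 'med', 'prod'],  # event types within each visit
        vocab_size=[1000, 500, 300],    # assumed: one vocabulary size per event type
        max_visit=20,
        word_dim=768,
        conv_dim=128,
        mem_dim=320,
        mlp_dim=512,
        demo_dim=3,
        epochs=10,
        batch_size=512,
        learning_rate=1e-3,
        device='cuda:0',
    )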

fit(train_data, valid_data=None)[source]

Fit patient-trial matching model.

Parameters
  • train_data (dict) –

    A dict containing the patient and trial data:

    {
    'patient': pytrial.tasks.trial_patient_match.data.PatientData,
    'trial': pytrial.tasks.trial_patient_match.data.TrialData
    }

  • valid_data (dict) – A dict containing patient and trial data used for evaluation, in the same format as train_data.
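A hedged usage sketch; patient_data and trial_data are assumed to be pre-built PatientData and TrialData instances (their construction is not shown here).

    # Assumes `patient_data` and `trial_data` already exist as
    # pytrial.tasks.trial_patient_match.data.PatientData / TrialData objects.
    train_data = {'patient': patient_data, 'trial': trial_data}
    model.fit(train_data=train_data)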

load_model(checkpoint)[source]

Load the model and the pre-encoded trial embeddings from the given checkpoint directory.

Parameters

checkpoint (str) – The path that stores the pretrained model. If a directory is given, the single checkpoint file matching *.pth.tar inside it will be loaded. If a filepath is given, the model is loaded from that file.
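For example (the checkpoint path below is a placeholder):

    # Load from a checkpoint directory containing a single *.pth.tar file,
    # or point directly at the file itself.
    model.load_model('./checkpoints/compose')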

predict(test_data, return_dict=False)[source]

Predict how each patient matches the eligibility criteria (ECs) of each target trial. Predictions are made for every patient against every included trial; for example, given 1000 patients and 10 trials, the predictions have shape 1000 x 10 x num_eligibility_criteria.

Parameters
  • test_data (dict) –

    A dict containing the patient and trial data:

    {
    'patient': pytrial.tasks.trial_patient_match.data.PatientData,
    'trial': pytrial.tasks.trial_patient_match.data.TrialData
    }

  • return_dict (bool) – Whether to return the results in a dictionary; otherwise a tuple of results is returned.
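A hedged usage sketch; as with fit, patient_data and trial_data are assumed to be pre-built data objects, and the exact structure of the results depends on return_dict.

    test_data = {'patient': patient_data, 'trial': trial_data}

    # With N patients and M trials, the predictions cover
    # N x M x num_eligibility_criteria criterion-level outputs.
    results = model.predict(test_data, return_dict=True)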

save_model(output_dir)[source]

Save the learned patient-trial matching model to disk.

Parameters

output_dir (str or None) – The directory in which to save the learned model. If set to None, the model is saved to self.checkout_dir.
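For example (the output directory below is a placeholder):

    # Save the trained model; passing None instead would fall back to
    # the model's default checkpoint directory (self.checkout_dir).
    model.save_model('./checkpoints/compose')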