site_selection.PolicyGradientEntropy

class pytrial.tasks.site_selection.pgentropy.PolicyGradientEntropy(trial_dim=211, site_dim=124, embedding_dim=64, enrollment_only=True, K=10, lam=1, learning_rate=0.0001, weight_decay=0.0001, batch_size=64, epochs=10, num_worker=0, device='cuda:0', experiment_id='test')

Implements the Policy Gradient Entropy model for selecting clinical trial sites based on multi-modal site features, some of which may be missing [1].

Parameters
  • trial_dim (int) – Size of the trial representation

  • site_dim (int) – Size of the site representation

  • embedding_dim (int) – Size of the modality embeddings and all other intermediate embeddings

Notes

[1] Srinivasa, R. S., Qian, C., Theodorou, B., Spaeder, J., Xiao, C., Glass, L., & Sun, J. (2022). Clinical trial site matching with improved diversity using fair policy learning. arXiv preprint arXiv:2204.06501.
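The following is a minimal sketch of an entropy-regularized selection reward. It assumes, based on the cited paper and not on the PyTrial source, that a set of K selected sites is rewarded by its total enrollment plus `lam` times the Shannon entropy of the pooled patient demographic mix. Under that assumption, a larger `lam` favours more diverse site sets. The functions `shannon_entropy` and `selection_reward` and the per-site demographic counts are hypothetical illustrations.

```python
import math
from typing import Sequence


def shannon_entropy(counts: Sequence[float]) -> float:
    """Shannon entropy (natural log) of a distribution given as non-negative counts."""
    total = sum(counts)
    if total == 0:
        return 0.0
    return -sum((c / total) * math.log(c / total) for c in counts if c > 0)


def selection_reward(
    selected: Sequence[int],
    enrollment: Sequence[float],
    demographics: Sequence[Sequence[float]],
    lam: float = 1.0,
) -> float:
    """Hypothetical reward for one set of selected sites.

    selected:     indices of the K chosen sites
    enrollment:   enrollment[i] = patients enrolled at site i
    demographics: demographics[i][g] = patients from group g at site i
    lam:          weight of the diversity term (mirrors the `lam` argument)
    """
    total_enrollment = sum(enrollment[i] for i in selected)
    num_groups = len(demographics[0])
    # Pool group counts over the chosen sites, then measure how balanced they are.
    pooled = [sum(demographics[i][g] for i in selected) for g in range(num_groups)]
    return total_enrollment + lam * shannon_entropy(pooled)
```

Take enrollment [10, 8, 8] and per-site group counts [[10, 0], [4, 4], [0, 8]] as an example. Site sets {0, 1} and {0, 2} both enroll 18 patients. Any `lam > 0` prefers {0, 2}, because its pooled mix [10, 8] is more balanced than [14, 4].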

fit(train_data)

Train the model on historical trial-site enrollment records.

Parameters

train_data (TrialSiteSimple) – A TrialSiteSimple dataset containing trials, sites, and their enrollments.
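The following is a minimal REINFORCE sketch of how a site-selection policy could be fitted to historical enrollments. It assumes a tabular softmax policy with one logit per site for a single trial, in place of the neural trial and site encoders the class uses. K distinct sites are drawn one at a time, and the mean batch reward serves as a baseline. The names `sample_sites`, `train_reinforce`, and `reward_fn` are hypothetical. `epochs`, `batch_size`, and `lr` only loosely mirror the constructor's `epochs`, `batch_size`, and `learning_rate`. A `reward_fn` could add an entropy bonus weighted by `lam`; the usage below rewards enrollment only.

```python
import numpy as np


def sample_sites(logits, k, rng):
    """Draw k distinct sites one at a time from a softmax policy over logits.

    Returns the chosen indices and the gradient of log P(chosen) w.r.t. logits.
    """
    n = len(logits)
    available = np.ones(n, dtype=bool)
    chosen, grad = [], np.zeros(n)
    for _ in range(k):
        # Sites already picked are masked out so each site is chosen at most once.
        masked = np.where(available, logits, -np.inf)
        probs = np.exp(masked - masked.max())
        probs /= probs.sum()
        i = int(rng.choice(n, p=probs))
        chosen.append(i)
        # Gradient of log softmax at i: one_hot(i) - probs (probs is 0 for masked sites).
        grad -= probs
        grad[i] += 1.0
        available[i] = False
    return chosen, grad


def train_reinforce(reward_fn, num_sites, k, epochs=200, batch_size=16, lr=0.1, seed=0):
    """REINFORCE with a mean-reward baseline on a tabular (one logit per site) policy."""
    rng = np.random.default_rng(seed)
    logits = np.zeros(num_sites)
    for _ in range(epochs):
        batch = [sample_sites(logits, k, rng) for _ in range(batch_size)]
        rewards = np.array([reward_fn(chosen) for chosen, _ in batch])
        advantages = rewards - rewards.mean()  # baseline reduces gradient variance
        update = sum(a * grad for a, (_, grad) in zip(advantages, batch)) / batch_size
        logits = logits + lr * update  # gradient ascent on expected reward
    return logits
```

Suppose four sites enroll [1, 5, 2, 8] patients, k=2, and the reward is `enrollment[chosen].sum()`. Training should then concentrate the policy on sites 1 and 3.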

load_model(checkpoint)

Load the model and the pre-encoded trial embeddings from the given checkpoint.

Parameters

checkpoint (str) – The directory or file path that stores the pretrained model. If a directory, the single checkpoint file matching *.pth.tar inside it is loaded; if a file path, that file is loaded.
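The directory-or-file rule above can be sketched with the standard library. `resolve_checkpoint` is a hypothetical helper, not part of PyTrial. Raising `FileNotFoundError` when a directory holds zero or several `*.pth.tar` files is also an assumption; the documentation only says "the only" checkpoint file is loaded.

```python
import glob
import os


def resolve_checkpoint(checkpoint: str) -> str:
    """Return the checkpoint file to load (hypothetical helper).

    A directory must contain exactly one '*.pth.tar' file; any other path is
    treated as the checkpoint file itself.
    """
    if os.path.isdir(checkpoint):
        matches = glob.glob(os.path.join(checkpoint, "*.pth.tar"))
        if len(matches) != 1:
            raise FileNotFoundError(
                f"expected exactly one *.pth.tar file in {checkpoint!r}, found {len(matches)}"
            )
        return matches[0]
    return checkpoint
```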

predict(test_data)

Make site-selection predictions for test_data.
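The final ranking step of prediction can be sketched as below. This assumes that `K` is the number of sites to select and that, at inference, the trained policy scores every candidate site and the K highest-scoring sites are kept deterministically rather than sampled. `select_top_k` is hypothetical, and the actual return format of `predict` is not documented here.

```python
from typing import List, Sequence


def select_top_k(scores: Sequence[float], k: int) -> List[int]:
    """Rank candidate sites by policy score and keep the k best.

    Ties are broken in favour of the lower site index so the result is deterministic.
    """
    ranked = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    return ranked[:k]
```

For example, `select_top_k([0.2, 1.5, -0.3, 0.9], k=2)` returns `[1, 3]`.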

save_model(output_dir)

Save the learned site-selection model to disk.

Parameters

output_dir (str or None) – The directory in which to save the learned model. If None, the model is saved to self.checkout_dir.