trial_outcome.HINT
- class pytrial.tasks.trial_outcome.hint.HINT(disease_embedding_dim=50, protocol_output_dim=50, molecule_embedding_dim=50, global_embed_size=50, highway_num_layer=3, gnn_hidden_size=50, epoch=20, lr=0.0003, batch_size=32, weight_decay=0, prefix_name='phase_I', device='cuda:0')
Implements the Hierarchical Interaction Network (HINT) model for clinical trial outcome prediction [1].
- Parameters
disease_embedding_dim (int) – dimension of the disease code embedding, e.g., 50
protocol_output_dim (int) – dimension of the protocol (eligibility criteria) embedding, e.g., 50
molecule_embedding_dim (int) – dimension of the molecule embedding, e.g., 50
global_embed_size (int) – dimension of the trial component embedding, e.g., 50
highway_num_layer (int) – number of highway layers, e.g., 3
gnn_hidden_size (int) – hidden size of the GNN, e.g., 50
epoch (int) – number of training epochs, e.g., 5
lr (float) – learning rate of the optimizer (we use Adam) during training, e.g., 3e-4
batch_size (int) – batch size during training, e.g., 32
weight_decay (float) – weight decay coefficient, e.g., 0
prefix_name (str) – name of the trial phase, used as the prefix of the model name, e.g., phase_I, phase_II
device (str or torch.device) – target device for training the model, e.g., cuda:0 or cpu
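The highway_num_layer parameter sets how many highway layers are stacked. A highway layer keeps the input dimension and mixes a nonlinear transform H(x) with the unchanged input x through a learned sigmoid gate T(x): y = T(x) * H(x) + (1 - T(x)) * x. The sketch below is a pure-Python illustration of a standard highway layer (Srivastava et al., 2015) with an assumed ReLU transform; it is not PyTrial's implementation, and highway_layer and highway_stack are hypothetical names with randomly initialised weights.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector, b: Vector) -> Vector:
    """Compute W @ x + b for a square weight matrix W."""
    return [sum(wij * xj for wij, xj in zip(row, x)) + bi for row, bi in zip(w, b)]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def highway_layer(x: Vector, w_h: Matrix, b_h: Vector,
                  w_t: Matrix, b_t: Vector) -> Vector:
    """One highway layer: y = T(x) * H(x) + (1 - T(x)) * x."""
    h = [max(0.0, v) for v in matvec(w_h, x, b_h)]  # transform H(x), assumed ReLU
    t = [sigmoid(v) for v in matvec(w_t, x, b_t)]   # transform gate T(x)
    return [ti * hi + (1.0 - ti) * xi for ti, hi, xi in zip(t, h, x)]


def highway_stack(x: Vector, num_layers: int, seed: int = 0) -> Vector:
    """Apply num_layers highway layers with random (untrained) weights."""
    rng = random.Random(seed)
    dim = len(x)
    for _ in range(num_layers):
        w_h = [[rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(dim)]
        w_t = [[rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(dim)]
        b_h = [0.0] * dim
        b_t = [-2.0] * dim  # negative gate bias: each layer starts mostly carrying x through
        x = highway_layer(x, w_h, b_h, w_t, b_t)
    return x
```

For example, highway_stack([0.5] * 50, num_layers=3) returns another 50-dimensional vector, which is why stacking highway layers does not change the embedding size.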
Notes
[1] Fu et al. HINT: Hierarchical Interaction Network for Clinical Trial Outcome Prediction. Cell Patterns, 2022.
- fit(train_data, valid_data=None)
Train the HINT model to predict clinical trial outcomes (approval rate).
- Parameters
train_data (TrialOutcomeDatasetBase) – Training data, should be a TrialOutcomeDatasetBase object.
valid_data (TrialOutcomeDatasetBase, optional) – Validation data, should be a TrialOutcomeDatasetBase object. Defaults to None.
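The constructor arguments epoch, batch_size, lr and weight_decay govern a mini-batch Adam training loop. The sketch below illustrates how those four knobs interact on a toy logistic-regression model in pure Python. It is not HINT's training code: the model, the binary cross-entropy loss and the train_logistic_adam name are assumptions. Weight decay is added to the gradient as an L2 penalty, which is how torch.optim.Adam treats its weight_decay argument.

```python
import math
import random
from typing import List, Tuple

Sample = Tuple[List[float], int]  # (features, binary outcome label)


def train_logistic_adam(data: List[Sample], epoch: int = 20, lr: float = 3e-4,
                        batch_size: int = 32, weight_decay: float = 0.0,
                        seed: int = 0) -> Tuple[List[float], List[float]]:
    """Mini-batch Adam on binary cross-entropy; returns (weights, mean loss per epoch)."""
    rng = random.Random(seed)
    data = list(data)                 # copy so the caller's list is not shuffled in place
    dim = len(data[0][0])
    w = [0.0] * (dim + 1)             # last entry is the bias
    m = [0.0] * len(w)                # Adam first-moment estimate
    v = [0.0] * len(w)                # Adam second-moment estimate
    beta1, beta2, eps, step = 0.9, 0.999, 1e-8, 0
    history = []
    for _ in range(epoch):
        rng.shuffle(data)
        total = 0.0
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            grad = [0.0] * len(w)
            for x, y in batch:
                z = sum(wi * xi for wi, xi in zip(w, x)) + w[-1]
                p = 1.0 / (1.0 + math.exp(-z))
                total += -(y * math.log(p + 1e-12) + (1 - y) * math.log(1 - p + 1e-12))
                for i, xi in enumerate(x):
                    grad[i] += (p - y) * xi / len(batch)
                grad[-1] += (p - y) / len(batch)
            step += 1
            for i in range(len(w)):
                g = grad[i] + weight_decay * w[i]  # coupled L2 weight decay
                m[i] = beta1 * m[i] + (1 - beta1) * g
                v[i] = beta2 * v[i] + (1 - beta2) * g * g
                m_hat = m[i] / (1 - beta1 ** step)  # bias-corrected moments
                v_hat = v[i] / (1 - beta2 ** step)
                w[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
        history.append(total / len(data))
    return w, history
```

Each epoch performs ceil(len(data) / batch_size) Adam steps, so for a fixed number of epochs a smaller batch_size means more parameter updates.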
- load_model(checkpoint=None)
Load the learned HINT model from disk.
- Parameters
checkpoint (str) – Path to the checkpoint folder from which to load the learned model. The checkpoint file in this folder should be named model.ckpt.
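Because load_model expects a folder containing model.ckpt rather than the file itself, it can help to check the layout before loading. A minimal standard-library sketch; resolve_checkpoint is a hypothetical helper, not part of PyTrial, and it only encodes the folder convention documented above.

```python
import os


def resolve_checkpoint(checkpoint_dir: str, filename: str = "model.ckpt") -> str:
    """Return the path of the checkpoint file inside checkpoint_dir.

    Hypothetical helper: mirrors the documented convention that the folder
    passed to load_model contains a file named model.ckpt.
    """
    if not os.path.isdir(checkpoint_dir):
        raise NotADirectoryError(f"checkpoint folder not found: {checkpoint_dir}")
    path = os.path.join(checkpoint_dir, filename)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"expected {filename} in {checkpoint_dir}")
    return path
```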
- predict(test_data)
Make trial outcome predictions for the test data.
- Parameters
test_data (TrialOutcomeDatasetBase) – Test data, should be a TrialOutcomeDatasetBase object.
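The return format of predict is not documented here. Assuming it yields one predicted success probability per trial, aligned with binary outcome labels from the test set, ROC-AUC is one common way to score those predictions. The dependency-free sketch below computes it through the Mann-Whitney U statistic; roc_auc is a hypothetical name, and sklearn.metrics.roc_auc_score computes the same quantity.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a correct pair.
    O(n_pos * n_neg), which is fine for small test sets.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))
```

For instance, roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]) gives 0.75: three of the four positive/negative pairs are ranked correctly.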