trial_outcome.HINT

class pytrial.tasks.trial_outcome.hint.HINT(disease_embedding_dim=50, protocol_output_dim=50, molecule_embedding_dim=50, global_embed_size=50, highway_num_layer=3, gnn_hidden_size=50, epoch=20, lr=0.0003, batch_size=32, weight_decay=0, prefix_name='phase_I', device='cuda:0')

Implements the Hierarchical Interaction Network (HINT) model for clinical trial outcome prediction [1].

Parameters
  • disease_embedding_dim (int) – dimension of disease code embedding, e.g., 50

  • protocol_output_dim (int) – dimension of protocol (eligibility criteria) embedding, e.g., 50

  • molecule_embedding_dim (int) – dimension of molecule embedding, e.g., 50

  • global_embed_size (int) – dimension of trial component embedding, e.g., 50

  • highway_num_layer (int) – number of highway layers, e.g., 3

  • gnn_hidden_size (int) – dimension of GNN hidden size, e.g., 50

  • epoch (int) – number of training epochs, e.g., 5

  • lr (float) – learning rate of the Adam optimizer during training, e.g., 3e-4

  • batch_size (int) – batch size during training, e.g., 32

  • weight_decay (float) – weight decay coefficient, e.g., 0.

  • prefix_name (str) – name of trial phase as prefix name of the model, e.g., phase_I, phase_II

  • device (str or torch.device) – target device on which to train the model, e.g., cuda:0 or cpu
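The defaults in the signature above can be collected into a plain dictionary and overridden selectively. The sketch below is illustrative only: `HINT_DEFAULTS`, `make_hint_kwargs`, and `build_hint` are hypothetical helper names, not part of pytrial. The import inside `build_hint` uses the module path from the class signature and is only executed if that helper is called.

```python
# Defaults copied from the documented HINT constructor signature.
HINT_DEFAULTS = {
    "disease_embedding_dim": 50,
    "protocol_output_dim": 50,
    "molecule_embedding_dim": 50,
    "global_embed_size": 50,
    "highway_num_layer": 3,
    "gnn_hidden_size": 50,
    "epoch": 20,
    "lr": 3e-4,
    "batch_size": 32,
    "weight_decay": 0,
    "prefix_name": "phase_I",
    "device": "cuda:0",
}


def make_hint_kwargs(**overrides):
    """Hypothetical helper: merge overrides into the documented defaults."""
    unknown = set(overrides) - set(HINT_DEFAULTS)
    if unknown:
        # Catch misspelled argument names before they reach the constructor.
        raise TypeError(f"unknown HINT argument(s): {sorted(unknown)}")
    return {**HINT_DEFAULTS, **overrides}


def build_hint(**overrides):
    """Hypothetical helper: instantiate HINT (requires pytrial to be installed)."""
    # Import path taken from the class signature above.
    from pytrial.tasks.trial_outcome.hint import HINT
    return HINT(**make_hint_kwargs(**overrides))


# A CPU-only configuration for a short phase II run.
kwargs = make_hint_kwargs(epoch=5, prefix_name="phase_II", device="cpu")
print(kwargs["epoch"], kwargs["prefix_name"], kwargs["device"])
```

With pytrial installed, `build_hint(epoch=5, device="cpu")` would pass the merged arguments to the constructor.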

Notes

[1] Fu et al. HINT: Hierarchical Interaction Network for Clinical Trial Outcome Prediction. Cell Patterns, 2022.

fit(train_data, valid_data=None)

Train the HINT model to predict clinical trial outcome (approval rate).

Parameters

train_data – Training data.

valid_data – Validation data; optional, defaults to None.

load_model(checkpoint=None)

Load the learned HINT model from the disk.

Parameters

checkpoint (str) – The folder from which to load the learned model. The checkpoint file in this folder must be named model.ckpt.

predict(test_data)

Make trial outcome prediction for test data.

Parameters

test_data (TrialOutcomeDatasetBase) – Testing data; must be a TrialOutcomeDatasetBase object.
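The call order of fit and predict can be sketched with a small wrapper. This is a minimal illustration, not the library's own workflow: `train_and_predict` and `_StubModel` are hypothetical names, and the stub stands in for HINT so the snippet runs without pytrial. With the real model, test_data must be a TrialOutcomeDatasetBase object; the return format of predict is not documented here, so the wrapper simply passes it through.

```python
def train_and_predict(model, train_data, test_data, valid_data=None):
    """Fit the model, then return whatever model.predict returns."""
    model.fit(train_data, valid_data=valid_data)
    return model.predict(test_data)


class _StubModel:
    """Hypothetical stand-in for HINT exposing the same two method names."""

    def __init__(self):
        self.calls = []

    def fit(self, train_data, valid_data=None):
        self.calls.append("fit")

    def predict(self, test_data):
        self.calls.append("predict")
        # Placeholder scores only; this is not real HINT output.
        return [0.5 for _ in test_data]


stub = _StubModel()
scores = train_and_predict(stub, train_data=["t1", "t2"], test_data=["t3"])
print(stub.calls)
```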

save_model(output_dir=None)

Save the learned HINT model to the disk.

Parameters

output_dir (str or None) – The output folder in which to save the learned model. If None, the model is saved to checkpoints/model.ckpt.
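The checkpoint naming described for save_model and load_model can be sketched as a path helper. This is a rough sketch under one assumption: that save_model writes model.ckpt inside output_dir when a folder is given, which is inferred from the load_model description rather than stated outright. `expected_checkpoint_path` and `save_then_reload` are hypothetical helper names.

```python
import os

DEFAULT_DIR = "checkpoints"  # folder used when output_dir is None (per the docs)
CKPT_NAME = "model.ckpt"     # file name load_model expects inside the folder


def expected_checkpoint_path(output_dir=None):
    """Return the path where the checkpoint file is expected to live."""
    folder = DEFAULT_DIR if output_dir is None else output_dir
    return os.path.join(folder, CKPT_NAME)


def save_then_reload(model, output_dir=None):
    """Save a model, confirm the checkpoint file exists, then load it back."""
    model.save_model(output_dir=output_dir)
    path = expected_checkpoint_path(output_dir)
    if not os.path.isfile(path):
        # Assumption violated: the file is not where this sketch expects it.
        raise FileNotFoundError(path)
    # load_model takes the folder, not the file itself.
    model.load_model(checkpoint=os.path.dirname(path))
    return path


print(expected_checkpoint_path())
print(expected_checkpoint_path("runs/phase_I"))
```

Checking for the file before calling load_model gives a clearer error than a failed load when the folder layout differs from the assumption.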