trial_outcome.MLP

class pytrial.tasks.trial_outcome.mlp.MLP(epoch=5, lr=0.001, batch_size=32, weight_decay=0.0)

Multilayer perceptron (MLP) model for clinical trial outcome prediction.

Parameters
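  • epoch (int) – Number of training epochs.

  • lr (float) – Learning rate.

  • batch_size (int) – Number of samples per training batch.

  • weight_decay (float) – Strength of the L2 regularization (weight decay); must be a non-negative float. The default 0.0 disables it.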
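To make the role of each hyperparameter concrete, the sketch below shows how `batch_size` sets the number of updates per epoch (so training performs `epoch` times that many updates) and how `lr` and `weight_decay` enter a single gradient-descent step. It is not PyTrial's training code. It assumes the "coupled" form of weight decay, where the gradient of a (weight_decay / 2) * ||w||^2 penalty is added to the loss gradient, which is how the `weight_decay` argument of PyTorch's `torch.optim.SGD` behaves. The function names are hypothetical.

```python
import math
from typing import List


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Mini-batches per epoch; the last batch may be smaller than batch_size."""
    return math.ceil(num_samples / batch_size)


def sgd_step(weights: List[float], grads: List[float], lr: float,
             weight_decay: float = 0.0) -> List[float]:
    """One gradient-descent step with a coupled L2 penalty (hypothetical helper).

    The penalty (weight_decay / 2) * ||w||^2 adds weight_decay * w to each
    gradient, so weight_decay = 0.0 (the default) means no regularization.
    """
    if weight_decay < 0:
        raise ValueError("weight_decay must be non-negative")
    return [w - lr * (g + weight_decay * w) for w, g in zip(weights, grads)]


# 1000 training trials with batch_size=32 need ceil(1000 / 32) = 32 steps per epoch.
print(steps_per_epoch(1000, 32))
# With a zero data gradient, weight decay alone shrinks each weight toward zero.
print(sgd_step([1.0, -2.0], [0.0, 0.0], lr=0.1, weight_decay=0.5))
```

Larger `weight_decay` values pull weights more strongly toward zero on every step, trading training fit for smoother models.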

fit(train_data, valid_data=None)

Train the model to predict clinical trial outcomes.

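Parameters

  • train_data (TrialOutcomeDatasetBase) – Training data, should be a TrialOutcomeDatasetBase object.

  • valid_data (TrialOutcomeDatasetBase or None) – Optional validation data, should be a TrialOutcomeDatasetBase object.
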
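As a rough illustration of what a training call of this shape does, the sketch below runs `epoch` passes over the training set, shuffles it, splits it into mini-batches of `batch_size`, and applies gradient updates with learning rate `lr` and L2 penalty `weight_decay`, reporting validation loss when validation data is given. It is a hypothetical, dependency-free one-hidden-layer MLP on plain feature vectors with binary labels (1 = trial success is an assumption). PyTrial's model consumes `TrialOutcomeDatasetBase` objects and its internals are not reproduced here. `TinyMLP`, `Sample`, and the toy data are invented for illustration.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

Sample = Tuple[List[float], int]  # hypothetical: (feature vector, 1 = success / 0 = failure)


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


class TinyMLP:
    """Hypothetical one-hidden-layer MLP (tanh hidden units, sigmoid output)."""

    def __init__(self, n_in: int, n_hidden: int = 8, epoch: int = 5, lr: float = 0.001,
                 batch_size: int = 32, weight_decay: float = 0.0, seed: int = 0):
        self.rng = random.Random(seed)
        self.W1 = [[self.rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_hidden)]
        self.b1 = [0.0] * n_hidden
        self.w2 = [self.rng.uniform(-0.5, 0.5) for _ in range(n_hidden)]
        self.b2 = 0.0
        self.epoch, self.lr = epoch, lr
        self.batch_size, self.weight_decay = batch_size, weight_decay

    def _forward(self, x: Sequence[float]) -> Tuple[List[float], float]:
        h = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(self.W1, self.b1)]
        return h, sigmoid(sum(w * hj for w, hj in zip(self.w2, h)) + self.b2)

    def loss(self, data: Sequence[Sample]) -> float:
        """Mean binary cross-entropy over the data."""
        eps = 1e-12
        total = 0.0
        for x, y in data:
            _, p = self._forward(x)
            total -= y * math.log(p + eps) + (1 - y) * math.log(1.0 - p + eps)
        return total / len(data)

    def _step(self, batch: Sequence[Sample]) -> None:
        """Accumulate gradients over one mini-batch, then apply one update."""
        n_hidden = len(self.w2)
        gW1 = [[0.0] * len(row) for row in self.W1]
        gb1, gw2, gb2 = [0.0] * n_hidden, [0.0] * n_hidden, 0.0
        for x, y in batch:
            h, p = self._forward(x)
            d_out = p - y  # gradient of BCE with respect to the output logit
            gb2 += d_out
            for j in range(n_hidden):
                gw2[j] += d_out * h[j]
                d_h = d_out * self.w2[j] * (1.0 - h[j] ** 2)  # back through tanh
                gb1[j] += d_h
                for i, xi in enumerate(x):
                    gW1[j][i] += d_h * xi
        m, lr, wd = len(batch), self.lr, self.weight_decay
        for j in range(n_hidden):
            for i in range(len(self.W1[j])):
                self.W1[j][i] -= lr * (gW1[j][i] / m + wd * self.W1[j][i])
            self.b1[j] -= lr * gb1[j] / m  # biases are not decayed in this sketch
            self.w2[j] -= lr * (gw2[j] / m + wd * self.w2[j])
        self.b2 -= lr * gb2 / m

    def fit(self, train_data: Sequence[Sample],
            valid_data: Optional[Sequence[Sample]] = None) -> "TinyMLP":
        data = list(train_data)  # copy so shuffling does not reorder the caller's list
        for ep in range(self.epoch):
            self.rng.shuffle(data)
            for start in range(0, len(data), self.batch_size):
                self._step(data[start:start + self.batch_size])
            if valid_data is not None:
                print(f"epoch {ep + 1}: valid loss {self.loss(valid_data):.4f}")
        return self


# Toy, linearly separable data: label 1 when x0 + x1 > 0.
data_rng = random.Random(1)
points = [[data_rng.uniform(-1, 1), data_rng.uniform(-1, 1)] for _ in range(120)]
samples = [(x, int(x[0] + x[1] > 0)) for x in points]
train, valid = samples[:100], samples[100:]

model = TinyMLP(n_in=2, epoch=20, lr=0.5, batch_size=10)
loss_before = model.loss(train)
model.fit(train, valid_data=valid)
loss_after = model.loss(train)
```

Validation data only drives the printed loss here; it never affects the weights, which mirrors its usual role of monitoring generalization during training.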
load_model(checkpoint=None)

Load a learned MLP model from disk.

Parameters

checkpoint (str) – Path to the checkpoint folder from which to load the learned model. The folder should contain a checkpoint file named model.ckpt.
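The folder convention above (the caller passes a folder, and the file inside is always `model.ckpt`) can be mimicked with the standard library. The helper below is a hypothetical illustration that stores parameters as JSON purely for readability. PyTrial's actual checkpoint format is not documented here, so `load_checkpoint` and the JSON layout are assumptions.

```python
import json
import os
import tempfile


def load_checkpoint(checkpoint: str) -> dict:
    """Read <checkpoint>/model.ckpt; the JSON format is an assumption of this sketch."""
    path = os.path.join(checkpoint, "model.ckpt")
    if not os.path.isfile(path):
        raise FileNotFoundError(f"no model.ckpt found in {checkpoint!r}")
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# Usage: create a folder containing model.ckpt, then load it back by folder name.
with tempfile.TemporaryDirectory() as folder:
    with open(os.path.join(folder, "model.ckpt"), "w", encoding="utf-8") as f:
        json.dump({"epoch": 5, "lr": 0.001}, f)
    state = load_checkpoint(folder)
```

Failing early with a clear error when `model.ckpt` is missing is friendlier than letting a deserializer fail on a wrong path.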

predict(test_data)

Make clinical trial outcome predictions.

Parameters

test_data (TrialOutcomeDatasetBase) – Testing data, should be a TrialOutcomeDatasetBase object.
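The return type of `predict` is not stated here. Assuming a binary outcome (trial success vs. failure) and a sigmoid output layer, prediction reduces to one forward pass per trial that yields a success probability, optionally thresholded into a label. The sketch below shows that computation for a one-hidden-layer MLP. `MLPParams`, `predict_proba`, `to_labels`, and the 0.5 threshold are hypothetical and not part of PyTrial.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class MLPParams:
    """Hypothetical parameter container for a one-hidden-layer MLP."""
    W1: List[List[float]]  # hidden_size x input_size weight matrix
    b1: List[float]        # hidden biases
    w2: List[float]        # hidden-to-output weights
    b2: float              # output bias


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict_proba(params: MLPParams, features: List[List[float]]) -> List[float]:
    """Success probability for each trial: tanh hidden layer, then sigmoid output."""
    probs = []
    for x in features:
        hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
                  for row, b in zip(params.W1, params.b1)]
        logit = sum(w * h for w, h in zip(params.w2, hidden)) + params.b2
        probs.append(sigmoid(logit))
    return probs


def to_labels(probs: List[float], threshold: float = 0.5) -> List[int]:
    """Map probabilities to 1 (predicted success) or 0 (predicted failure)."""
    return [int(p >= threshold) for p in probs]


params = MLPParams(W1=[[2.0, 0.0]], b1=[0.0], w2=[3.0], b2=0.0)
probs = predict_proba(params, [[0.0, 0.0], [1.0, 0.0], [-1.0, 0.0]])
labels = to_labels(probs)
```

With these hand-picked weights, the all-zero input lands exactly on the decision boundary (probability 0.5), while inputs of opposite sign receive complementary probabilities.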

save_model(output_dir=None)

Save the learned model to disk.

Parameters

output_dir (str or None) – The folder in which to save the learned model. If None, the model is saved to save_model/model.ckpt.
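The default-path rule above (fall back to `save_model/` when no folder is given, and always name the file `model.ckpt`) can be sketched as follows. JSON serialization and the helper names `checkpoint_path` and `save_checkpoint` are assumptions for illustration only; they do not describe how PyTrial serializes model weights.

```python
import json
import os
import tempfile
from typing import Optional

DEFAULT_DIR = "save_model"  # default folder named in the documentation above


def checkpoint_path(output_dir: Optional[str] = None) -> str:
    """Resolve <output_dir>/model.ckpt, using save_model/ when output_dir is None."""
    folder = DEFAULT_DIR if output_dir is None else output_dir
    return os.path.join(folder, "model.ckpt")


def save_checkpoint(state: dict, output_dir: Optional[str] = None) -> str:
    """Write state as JSON (an assumption of this sketch) and return the file path."""
    path = checkpoint_path(output_dir)
    os.makedirs(os.path.dirname(path), exist_ok=True)  # create the folder if needed
    with open(path, "w", encoding="utf-8") as f:
        json.dump(state, f)
    return path


# Usage: save into a temporary folder and read the file back.
with tempfile.TemporaryDirectory() as out:
    path = save_checkpoint({"epoch": 5, "weight_decay": 0.0}, output_dir=out)
    with open(path, encoding="utf-8") as f:
        restored = json.load(f)
```

Keeping the file name fixed means a folder saved this way can be passed unchanged to the loader, which only needs to know the folder.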