trial_outcome.XGBoost

class pytrial.tasks.trial_outcome.xgboost.XGBoost(n_estimators=100, max_depth=3, reg_lambda=0, eval_metric='auc')

Implements an XGBoost model for clinical trial outcome prediction.

Parameters
  • n_estimators (int) – number of boosted trees (boosting rounds), default 100.

  • max_depth (int) – maximum depth of each tree, default 3.

  • reg_lambda (float) – L2 regularization term on leaf weights, default 0 (no regularization).

  • eval_metric ({'auc', 'logloss'}) – evaluation metric for the validation data, default 'auc'.
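To illustrate what reg_lambda controls: in XGBoost's formulation, the optimal weight of a leaf is w* = -G / (H + λ), where G and H are the sums of the first and second derivatives of the loss over the samples in that leaf and λ is the L2 penalty (reg_lambda). The minimal sketch below computes this for the binary logistic loss, which is assumed here as a plausible objective for success/failure outcomes; it is not taken from PyTrial's code, and the helper name logistic_leaf_weight is hypothetical. Note that this class defaults to reg_lambda=0, whereas XGBoost's own default for lambda is 1.

```python
import math


def logistic_leaf_weight(labels, margins, reg_lambda=0.0):
    """Hypothetical helper: optimal weight of one leaf under binary logistic loss.

    labels:  0/1 outcomes of the samples that fall into the leaf
    margins: current raw scores (log-odds) of those samples
    """
    grad_sum = 0.0  # G: sum of first derivatives of the loss
    hess_sum = 0.0  # H: sum of second derivatives of the loss
    for y, m in zip(labels, margins):
        p = 1.0 / (1.0 + math.exp(-m))  # predicted probability of success
        grad_sum += p - y               # d(loss)/d(margin)
        hess_sum += p * (1.0 - p)       # d^2(loss)/d(margin)^2
    # The L2 penalty enters the denominator and shrinks the weight toward 0.
    return -grad_sum / (hess_sum + reg_lambda)


# Two successes and one failure, starting from margin 0 (p = 0.5 for each).
labels, margins = [1, 1, 0], [0.0, 0.0, 0.0]
for lam in (0.0, 1.0):
    print(lam, round(logistic_leaf_weight(labels, margins, lam), 4))
# 0.0 0.6667
# 1.0 0.2857
```

With reg_lambda=0 the leaf takes the full Newton step; raising it shrinks every leaf weight, which tends to make each additional tree contribute less.
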

fit(train_data, valid_data=None)

Train the model to predict clinical trial outcomes.

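Parameters
  • train_data (TrialOutcomeDatasetBase) – Training data, should be a TrialOutcomeDatasetBase object.

  • valid_data (TrialOutcomeDatasetBase or None) – Validation data, evaluated with eval_metric. Optional; defaults to None.
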
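The two eval_metric options score the validation predictions differently: AUC depends only on how well the predicted scores rank successful trials above failed ones, while log loss also penalizes poorly calibrated probabilities. The standard-library sketch below computes both on a toy validation set; the function names roc_auc and log_loss are illustrative and not part of PyTrial, and in practice the metric is presumably computed by the underlying XGBoost library.

```python
import math


def roc_auc(labels, scores):
    """AUC as the probability that a random positive outranks a random negative.

    Ties count as half a win. O(n_pos * n_neg), fine for illustration.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def log_loss(labels, probs, eps=1e-15):
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)


# Toy validation set: true outcomes and predicted success probabilities.
y_valid = [1, 0, 1, 0]
p_valid = [0.9, 0.4, 0.35, 0.1]
print(roc_auc(y_valid, p_valid))              # 0.75
print(round(log_loss(y_valid, p_valid), 4))   # 0.4428
```

Here one success (0.35) is ranked below one failure (0.4), so AUC drops to 0.75; log loss is additionally driven up by the low probability assigned to that success.
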
load_model(checkpoint=None)

Load a trained model from disk.

Parameters

checkpoint (str) – The folder to load the trained model from. The checkpoint file in this folder should be named model.ckpt.

predict(test_data)

Make clinical trial outcome predictions.

Parameters

test_data (TrialOutcomeDatasetBase) – Testing data, should be a TrialOutcomeDatasetBase object.

save_model(output_dir=None)

Save the trained model to disk.

Parameters

output_dir (str or None) – The output folder to save the trained model to. If None, the model is saved to save_model/model.ckpt.
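The checkpoint layout described for save_model and load_model can be summarized as: the file is always named model.ckpt inside the given folder, and save_model falls back to the folder save_model when output_dir is None. The sketch below encodes only that documented convention; checkpoint_path is a hypothetical helper, not part of PyTrial, and since the behavior of load_model(checkpoint=None) is not stated above, the fallback should be assumed to apply to saving only.

```python
import os


def checkpoint_path(folder=None, default_dir="save_model"):
    """Hypothetical helper: location of model.ckpt for a given folder.

    Mirrors the documented convention: the checkpoint file is named
    model.ckpt, and a folder of None falls back to save_model/.
    """
    base = folder if folder is not None else default_dir
    return os.path.join(base, "model.ckpt")


print(checkpoint_path())            # save_model/model.ckpt (POSIX separator)
print(checkpoint_path("runs/xgb"))  # runs/xgb/model.ckpt
```

Passing the same folder to save_model and later to load_model keeps both calls pointing at the same model.ckpt file.
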