trial_outcome.XGBoost
- class pytrial.tasks.trial_outcome.xgboost.XGBoost(n_estimators=100, max_depth=3, reg_lambda=0, eval_metric='auc')
Implements an XGBoost model for clinical trial outcome prediction.
- Parameters
n_estimators (int) – number of boosted trees (boosting rounds); default 100.
max_depth (int) – maximum depth of each tree; default 3.
reg_lambda (float) – L2 regularization term on leaf weights; default 0 (no L2 penalty).
eval_metric ({'auc', 'logloss'}) – evaluation metric for validation data; default 'auc'.
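To illustrate what reg_lambda controls, here is a minimal stdlib-only sketch of the standard XGBoost leaf-weight formula w* = -G / (H + λ) for binary log loss, where G and H are the sums of the loss gradients and Hessians over the samples that land in a leaf. The function name leaf_weight is hypothetical, the sketch ignores the learning-rate scaling and the other regularizers XGBoost applies, and it is not PyTrial's code.

```python
def leaf_weight(labels, probs, reg_lambda=0.0):
    """Optimal leaf weight w* = -G / (H + reg_lambda) under binary log loss.

    labels: 0/1 targets of the samples that fall in this leaf.
    probs: current predicted probabilities for those samples (before this tree).
    """
    # First derivative of log loss w.r.t. the raw margin is p - y.
    grad_sum = sum(p - y for y, p in zip(labels, probs))
    # Second derivative of log loss w.r.t. the raw margin is p * (1 - p).
    hess_sum = sum(p * (1.0 - p) for p in probs)
    return -grad_sum / (hess_sum + reg_lambda)


# Four samples in one leaf, all currently predicted at 0.5; three are positive.
labels = [1, 1, 0, 1]
probs = [0.5, 0.5, 0.5, 0.5]
for lam in (0.0, 1.0, 10.0):
    # Larger reg_lambda pulls the leaf weight toward zero.
    print(f"reg_lambda={lam}: w* = {leaf_weight(labels, probs, lam):.4f}")
# prints 1.0000, 0.5000, 0.0909
```

With reg_lambda=0 (the default in the signature above) the weight reduces to -G/H, so leaves with a small Hessian sum can receive large weights; raising reg_lambda shrinks them toward zero. Note that XGBoost's own library default for this parameter is 1.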
- fit(train_data, valid_data=None)
Trains the model to predict clinical trial outcomes.
- Parameters
train_data (TrialOutcomeDatasetBase) – Training data, given as a TrialOutcomeDatasetBase object.
valid_data (TrialOutcomeDatasetBase, optional) – Validation data, given as a TrialOutcomeDatasetBase object; default None.
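The constructor documents eval_metric as the metric applied to validation data, which is what valid_data supplies. The plain-Python sketch below shows what the two supported choices, 'auc' and 'logloss', measure. The helper names roc_auc and log_loss are hypothetical; this is for illustration only and is not the code PyTrial or XGBoost runs internally.

```python
import math


def roc_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0  # positive scored above negative: correct ordering
            elif p == n:
                wins += 0.5  # tie: half credit
    return wins / (len(pos) * len(neg))


def log_loss(labels, probs, eps=1e-15):
    """Mean binary cross-entropy; probabilities are clipped away from 0 and 1."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)


# Hypothetical validation labels (1 = trial succeeded) and predicted probabilities.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(labels, scores))   # 0.75: 3 of 4 pairs ordered correctly
print(log_loss(labels, scores))
```

AUC depends only on how the predictions rank positives against negatives, while log loss also penalizes poorly calibrated probabilities, which is one reason to pick one metric over the other for validation.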
- load_model(checkpoint=None)
Loads the learned model from disk.
- Parameters
checkpoint (str) – Path to the checkpoint folder from which to load the learned model. The checkpoint file inside this folder is expected to be named model.ckpt.
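Because checkpoint names a folder rather than the file itself, it can help to check the documented layout before loading. The sketch below resolves the expected model.ckpt path inside a folder and fails early if it is missing. The helper resolve_checkpoint is hypothetical and only encodes the naming convention stated above; it does not read the file, and the checkpoint's on-disk format is not assumed.

```python
import os
import tempfile

CKPT_NAME = "model.ckpt"  # file name that load_model expects inside the folder


def resolve_checkpoint(folder):
    """Return the path of model.ckpt inside `folder`, raising if it is absent."""
    path = os.path.join(folder, CKPT_NAME)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"expected checkpoint file at {path}")
    return path


# Usage: pass the folder, not the file path, as `checkpoint`.
with tempfile.TemporaryDirectory() as folder:
    # Create an empty placeholder so the layout check succeeds.
    open(os.path.join(folder, CKPT_NAME), "wb").close()
    print(resolve_checkpoint(folder))
```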
- predict(test_data)
Makes clinical trial outcome predictions.
- Parameters
test_data (TrialOutcomeDatasetBase) – Test data, given as a TrialOutcomeDatasetBase object.