indiv_outcome.tabular.XGBoost
- class pytrial.tasks.indiv_outcome.tabular.xgboost.XGBoost(mode, n_estimators=100, max_depth=8, n_jobs=0, reg_alpha=0, reg_lambda=0, num_class=None, experiment_id='test')[source]
Implements an XGBoost model for tabular individual outcome prediction in clinical trials.
- Parameters
n_estimators (int) – Number of boosting rounds.
max_depth (int) – Maximum tree depth for base learners.
mode (str) – The task objective: one of binary, multiclass, multilabel, or regression. Early stopping is not supported when mode is multilabel.
n_jobs (int) – Number of parallel threads used to run xgboost. When used with other Scikit-Learn algorithms like grid search, you may choose which algorithm to parallelize and balance the threads. Creating thread contention will significantly slow down both algorithms.
reg_alpha (float) – L1 regularization term on weights (xgb’s alpha).
reg_lambda (float) – L2 regularization term on weights (xgb’s lambda).
experiment_id (str, optional (default='test')) – The name of the current experiment. It determines the name of the saved model checkpoint.
- fit(train_data, valid_data=None)[source]
Train the XGBoost model to predict patient outcomes from tabular input data.
- Parameters
train_data (dict) –
{ ‘x’: TabularPatientBase or pd.DataFrame, ‘y’: pd.Series or np.ndarray }
’x’ contains all patient features;
’y’ contains the label for each row.
valid_data (dict, same format as train_data) – Validation set used for early stopping.
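The input layout described above can be sketched as follows. This is a minimal illustration, not part of the library: `make_split`, `fit_example`, the feature names, and the synthetic labels are hypothetical, and the hyperparameter values are arbitrary. Only the import path, the constructor arguments, and the `fit(train_data, valid_data)` signature come from this page.

```python
import numpy as np
import pandas as pd


def make_split(n_rows=200, n_features=5, valid_frac=0.2, seed=0):
    """Build toy train/valid dicts in the {'x': ..., 'y': ...} layout (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    x = pd.DataFrame(
        rng.normal(size=(n_rows, n_features)),
        columns=[f"feat_{i}" for i in range(n_features)],
    )
    # Synthetic binary label: 1 when the first two features sum to more than 0.
    y = pd.Series((x["feat_0"] + x["feat_1"] > 0).astype(int), name="label")

    # The first n_valid rows form the validation set, the rest the training set.
    n_valid = int(n_rows * valid_frac)
    train = {
        "x": x.iloc[n_valid:].reset_index(drop=True),
        "y": y.iloc[n_valid:].reset_index(drop=True),
    }
    valid = {
        "x": x.iloc[:n_valid].reset_index(drop=True),
        "y": y.iloc[:n_valid].reset_index(drop=True),
    }
    return train, valid


def fit_example():
    """Fit the model on the toy split; requires pytrial to be installed."""
    # Imported lazily so make_split() stays usable without pytrial.
    from pytrial.tasks.indiv_outcome.tabular.xgboost import XGBoost

    train, valid = make_split()
    model = XGBoost(mode="binary", n_estimators=50, max_depth=4)
    model.fit(train, valid_data=valid)  # validation set used for early stopping
    return model
```

Calling `fit_example()` returns the fitted model. Passing `valid_data=None` (the default) trains without a validation set.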
- load_model(checkpoint=None)[source]
Load the learned XGBoost model from the disk.
- Parameters
checkpoint (str or None) –
If a directory, the only checkpoint file (.model) in that directory will be loaded;
If a filepath, the model will be loaded from this file;
If None, the model will be loaded from self.checkout_dir.
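The three cases above can be read as a small path-resolution rule. The sketch below is an illustration only and is not pytrial's actual implementation: `resolve_checkpoint` and its `default_dir` argument (standing in for `self.checkout_dir`) are hypothetical. It also assumes that ".model" refers to a file extension and that the default directory is resolved the same way as an explicit directory.

```python
import os


def resolve_checkpoint(checkpoint, default_dir):
    """Return the path of the checkpoint file to load (hypothetical helper)."""
    if checkpoint is None:
        # Assumption: fall back to the default checkpoint directory.
        checkpoint = default_dir
    if os.path.isdir(checkpoint):
        # Assumption: a directory must hold exactly one *.model file.
        candidates = sorted(
            name for name in os.listdir(checkpoint) if name.endswith(".model")
        )
        if len(candidates) != 1:
            raise FileNotFoundError(
                f"expected exactly one .model file in {checkpoint!r}, "
                f"found {len(candidates)}"
            )
        return os.path.join(checkpoint, candidates[0])
    # Otherwise treat the argument as a path to the checkpoint file itself.
    return checkpoint
```

Raising an error when a directory holds zero or several `.model` files is a design choice of this sketch; the page does not say how the library handles that case.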
- predict(test_data)[source]
Predict outcome probabilities with the learned model. The predictions are saved to self.result_dir.
- Parameters
test_data (dict) –
{ ‘x’: TabularPatientBase or pd.DataFrame, ‘y’: pd.Series or np.ndarray }
’x’ contains all patient features;
’y’ contains the label for each row. Ignored by the prediction function.
- Returns
ypred – For binary classification, an array of shape (n, ); for multiclass classification, an array of shape (n, n_class).
- Return type
np.ndarray
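Because the returned array holds probabilities, turning it into class labels is left to the caller. A minimal sketch of one way to do this for the two documented shapes follows. `to_labels` and the 0.5 threshold are hypothetical choices, not part of the library, and the multilabel and regression outputs are not covered because this page does not describe their shapes.

```python
import numpy as np


def to_labels(ypred, threshold=0.5):
    """Convert predicted probabilities to class labels (hypothetical helper)."""
    ypred = np.asarray(ypred)
    if ypred.ndim == 1:
        # Binary case, shape (n,): label 1 when the probability reaches the threshold.
        return (ypred >= threshold).astype(int)
    if ypred.ndim == 2:
        # Multiclass case, shape (n, n_class): pick the most probable class per row.
        return ypred.argmax(axis=1)
    raise ValueError(f"unsupported prediction shape: {ypred.shape}")
```

With a fitted model, this would be used as `labels = to_labels(model.predict(test_data))`.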