indiv_outcome.tabular.LogisticRegression
- class pytrial.tasks.indiv_outcome.tabular.logistic_regression.LogisticRegression(weight_decay=1, dual=False, epochs=100, experiment_id='test')
Bases: pytrial.tasks.indiv_outcome.tabular.base.TabularIndivBase
Implements a logistic regression model for tabular individual outcome prediction in clinical trials. Currently only binary classification is supported.
- Parameters
weight_decay (float) – Strength of the L2 regularization; must be a positive float. Smaller values specify weaker regularization.
dual (bool) – Whether to use the dual or primal formulation. The dual formulation is only implemented for the L2 penalty with the liblinear solver. Prefer dual=False when n_samples > n_features.
epochs (int) – Maximum number of iterations taken for the solvers to converge.
experiment_id (str, optional (default='test')) – The name of the current experiment; it determines the name of the saved model checkpoint.
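To make the roles of weight_decay and epochs concrete, here is a minimal NumPy sketch of L2-regularized binary logistic regression trained by plain batch gradient descent. It is not PyTrial's implementation: the solver, the step size `lr`, and the helper name `fit_l2_logreg` are assumptions for illustration. weight_decay is treated here as the coefficient of the L2 penalty (larger means stronger regularization, matching the parameter description), and epochs caps the number of iterations.

```python
import numpy as np


def fit_l2_logreg(X, y, weight_decay=1.0, epochs=100, lr=0.1):
    """Hypothetical sketch: batch gradient descent on the L2-regularized log-loss.

    Minimizes mean(log-loss) + (weight_decay / 2) * ||w||^2.
    `lr` is an assumed step size; the intercept b is not penalized.
    """
    n, d = X.shape
    w = np.zeros(d)
    b = 0.0
    for _ in range(epochs):  # `epochs` caps the number of iterations
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # predicted P(y = 1)
        err = p - y
        grad_w = X.T @ err / n + weight_decay * w  # data term + L2 penalty term
        grad_b = err.mean()
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


# Toy binary data: the label depends linearly on the first two features.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = (X[:, 0] - 0.5 * X[:, 1] > 0).astype(float)

w_weak, _ = fit_l2_logreg(X, y, weight_decay=0.01)
w_strong, _ = fit_l2_logreg(X, y, weight_decay=10.0)
print(np.linalg.norm(w_weak), np.linalg.norm(w_strong))
```

A stronger penalty shrinks the learned weights toward zero, which is what "smaller values specify weaker regularization" means in practice.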
- fit(train_data, valid_data=None)
Train the logistic regression model to predict patient outcomes from tabular input data.
- Parameters
train_data (dict) –
{
'x': TabularPatientBase or pd.DataFrame,
'y': pd.Series or np.ndarray
}
'x' contains all patient features;
'y' contains the label for each row.
valid_data (Ignored.) – Not used; present here for API consistency by convention.
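The train_data argument is a plain dict with an 'x' entry (features, one row per patient) and a 'y' entry (one label per row). The sketch below builds that dict from a pd.DataFrame and an np.ndarray and checks the two constraints implied above: rows and labels must line up, and labels must be binary. The helper name `build_outcome_data` and the toy feature columns are hypothetical; the commented-out fit call assumes the import path suggested by this page's title.

```python
import numpy as np
import pandas as pd


def build_outcome_data(features, labels):
    """Hypothetical helper: pack features/labels into {'x': ..., 'y': ...}.

    Checks one label per feature row and binary (0/1) labels, since only
    binary classification is supported.
    """
    x = pd.DataFrame(features)
    y = np.asarray(labels)
    if len(x) != len(y):
        raise ValueError(f"{len(x)} feature rows but {len(y)} labels")
    if not set(np.unique(y)).issubset({0, 1}):
        raise ValueError("labels must be binary (0/1)")
    return {"x": x, "y": y}


# Toy example: two illustrative feature columns for four patients.
train_data = build_outcome_data(
    {"age": [54, 61, 47, 70], "bmi": [27.1, 31.4, 22.8, 29.0]},
    [0, 1, 0, 1],
)

# Assuming PyTrial is installed, the dict would then be passed to fit:
# from pytrial.tasks.indiv_outcome.tabular import LogisticRegression
# model = LogisticRegression(weight_decay=1, epochs=100)
# model.fit(train_data)
```

The same dict layout is accepted by predict, where 'y' is ignored.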
- load_model(checkpoint=None)
Load the learned logistic regression model from disk.
- Parameters
checkpoint (str or None) –
If a directory, the only checkpoint file (.model) in it will be loaded;
If a file path, the model will be loaded from that file;
If None, the model will be loaded from self.checkout_dir.
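The three cases for checkpoint can be expressed as a small path-resolution function. This is a hypothetical re-creation of the documented rules, not PyTrial's code: `resolve_checkpoint` is an illustrative name, `default_dir` stands in for self.checkout_dir, the file name test.model is assumed, and the sketch assumes the None case follows the same "single .model file in a directory" rule.

```python
import glob
import os
import tempfile


def resolve_checkpoint(checkpoint, default_dir):
    """Hypothetical sketch: map the `checkpoint` argument to one file path.

    `default_dir` plays the role of self.checkout_dir.
    """
    if checkpoint is None:
        checkpoint = default_dir  # fall back to the experiment's checkpoint dir
    if os.path.isdir(checkpoint):
        matches = glob.glob(os.path.join(checkpoint, "*.model"))
        if len(matches) != 1:
            raise FileNotFoundError(
                f"expected exactly one .model file in {checkpoint}, found {len(matches)}"
            )
        return matches[0]
    return checkpoint  # anything else is treated as a direct file path


# Usage with a throwaway directory holding a single (empty) checkpoint file.
ckpt_dir = tempfile.mkdtemp()
ckpt_file = os.path.join(ckpt_dir, "test.model")
open(ckpt_file, "w").close()

from_dir = resolve_checkpoint(ckpt_dir, default_dir="unused")
from_default = resolve_checkpoint(None, default_dir=ckpt_dir)
from_file = resolve_checkpoint(ckpt_file, default_dir="unused")
```

Requiring exactly one .model file makes the directory case unambiguous; with several candidates, the sketch raises instead of guessing.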
- predict(test_data)
Predict outcome probabilities with the learned model and save them to self.result_dir.
- Parameters
test_data (dict) –
{
'x': TabularPatientBase or pd.DataFrame,
'y': pd.Series or np.ndarray
}
'x' contains all patient features;
'y' contains the label for each row. It is ignored by predict.
- Returns
ypred – The predicted probability for each patient.
For binary classification, returns shape (n,);
For multiclass classification, returns shape (n, n_class).
- Return type
np.ndarray
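The return-shape convention can be illustrated by reducing a per-class probability matrix to the documented output. In this sketch, `format_probabilities` is a hypothetical helper, and it assumes the underlying model yields an (n, n_class) matrix (as scikit-learn-style predict_proba does), from which the binary case keeps only the probability of class 1.

```python
import numpy as np


def format_probabilities(proba):
    """Hypothetical sketch: (n, 2) -> (n,) for binary, else keep (n, n_class).

    `proba` is an (n, n_class) matrix of class probabilities; for the binary
    case only P(class 1) is returned.
    """
    proba = np.asarray(proba, dtype=float)
    if proba.ndim != 2:
        raise ValueError("expected a 2-D (n, n_class) array")
    if proba.shape[1] == 2:
        return proba[:, 1]  # probability of the positive class
    return proba


binary = format_probabilities([[0.8, 0.2], [0.3, 0.7]])
multi = format_probabilities([[0.1, 0.6, 0.3], [0.5, 0.25, 0.25]])
print(binary.shape, multi.shape)  # (2,) (2, 3)
```

Returning a 1-D vector in the binary case lets the output feed directly into metrics that expect one score per patient.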