indiv_outcome.tabular.MLP
- class pytrial.tasks.indiv_outcome.tabular.mlp.MLP(input_dim, output_dim, mode, hidden_dim=128, num_layer=2, learning_rate=0.0001, weight_decay=0.0001, batch_size=64, epochs=10, num_worker=0, device='cuda:0', experiment_id='test')[source]
Implements a multi-layer perceptron (MLP) model for tabular individual outcome prediction in clinical trials.
- Parameters
input_dim (int) – Dimension of the input features.
output_dim (int) – Dimension of the outputs. For classification, this equals the number of classes.
mode (str) – The task objective: one of binary, multiclass, multilabel, or regression.
hidden_dim (int) – Hidden dimensions of neural networks.
num_layer (int) – Number of hidden layers.
learning_rate (float) – Learning rate for gradient-based optimization. torch.optim.Adam is used by default.
weight_decay (float) – Regularization strength for the L2 norm; must be a positive float. Smaller values specify weaker regularization.
batch_size (int) – Batch size when doing SGD optimization.
epochs (int) – Maximum number of training epochs (full passes over the training data).
num_worker (int) – Number of worker processes used for data loading during training.
device (str) – Device on which to train the model, e.g. cuda:0 or cpu.
experiment_id (str, optional (default='test')) – Name of the current experiment; determines the name of the saved model checkpoint.
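To make the roles of input_dim, hidden_dim, num_layer, and output_dim concrete, the following sketch lists the weight-matrix shapes of a plain fully connected network. It assumes num_layer counts hidden layers that each have width hidden_dim, followed by one output layer; the actual PyTrial architecture may differ (for example in activations, dropout, or normalization). The helper name mlp_layer_shapes is hypothetical and not part of PyTrial.

```python
def mlp_layer_shapes(input_dim, output_dim, hidden_dim=128, num_layer=2):
    """Return (in_features, out_features) pairs for each linear layer.

    Assumption: `num_layer` hidden layers of width `hidden_dim`,
    then a single output layer of width `output_dim`.
    """
    dims = [input_dim] + [hidden_dim] * num_layer + [output_dim]
    return list(zip(dims[:-1], dims[1:]))


# 20 input features, 3 output classes, default hidden settings.
shapes = mlp_layer_shapes(input_dim=20, output_dim=3)

# Weights plus biases for each linear layer under the assumption above.
n_params = sum(n_in * n_out + n_out for n_in, n_out in shapes)
```

Under this assumption, increasing hidden_dim grows the parameter count roughly quadratically, while increasing num_layer grows it linearly.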
- fit(train_data, valid_data=None)[source]
Train the MLP model to predict patient outcomes from tabular input data.
- Parameters
train_data (dict) –
{ ‘x’: TabularPatientBase or pd.DataFrame, ‘y’: pd.Series or np.ndarray }
‘x’ contains all patient features;
‘y’ contains the label for each row.
valid_data (dict, same format as train_data) – Validation data used for early stopping during training.
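A minimal sketch of the documented data layout and a fit/predict round trip, using the pd.DataFrame and pd.Series variants of ‘x’ and ‘y’. The helpers make_toy_data and train_and_predict, the random toy features, and the integer label encoding are illustrative assumptions; only the MLP constructor arguments, fit, and predict come from this page. The PyTrial import is done lazily inside the function, so the data helper runs without PyTrial installed.

```python
import numpy as np
import pandas as pd


def make_toy_data(n_rows=12, n_features=4, n_classes=3, seed=0):
    """Build a {'x': DataFrame, 'y': Series} dict in the documented layout."""
    rng = np.random.default_rng(seed)
    x = pd.DataFrame(
        rng.normal(size=(n_rows, n_features)),
        columns=[f"feat_{i}" for i in range(n_features)],
    )
    # Assumption: labels are integer class indices, one per row of x.
    y = pd.Series(rng.integers(0, n_classes, size=n_rows), name="label")
    return {"x": x, "y": y}


def train_and_predict(train_data, valid_data, test_data, n_classes=3):
    """Fit the MLP and return its predictions (requires pytrial and torch)."""
    from pytrial.tasks.indiv_outcome.tabular.mlp import MLP

    model = MLP(
        input_dim=train_data["x"].shape[1],
        output_dim=n_classes,  # number of classes for classification
        mode="multiclass",
        epochs=5,
        device="cpu",  # avoids the default 'cuda:0' on machines without a GPU
        experiment_id="toy_mlp",
    )
    model.fit(train_data, valid_data=valid_data)
    return model.predict({"x": test_data["x"]})


train = make_toy_data(seed=0)
valid = make_toy_data(seed=1)
# With pytrial installed, uncomment to train and predict:
# ypred = train_and_predict(train, valid, valid)
```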
- load_model(checkpoint)[source]
Load the trained model from the given checkpoint.
- Parameters
checkpoint (str) – Path to the pretrained model: either a directory or a checkpoint file.
If a directory, the only checkpoint file matching *.pth.tar in it is loaded;
if a file path, the model is loaded from that file.
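The directory-or-file rule can be sketched with the standard library alone. This is not PyTrial's implementation: resolve_checkpoint is a hypothetical helper, and raising an error when a directory holds zero or several *.pth.tar files is an assumption, since the page only says "the only checkpoint file".

```python
import glob
import os
import tempfile


def resolve_checkpoint(checkpoint):
    """Map a directory or file argument to a single checkpoint file path."""
    if os.path.isdir(checkpoint):
        matches = sorted(glob.glob(os.path.join(checkpoint, "*.pth.tar")))
        # Assumption: anything other than exactly one match is an error.
        if len(matches) != 1:
            raise ValueError(
                f"expected exactly one *.pth.tar file in {checkpoint!r}, "
                f"found {len(matches)}"
            )
        return matches[0]
    if os.path.isfile(checkpoint):
        return checkpoint
    raise FileNotFoundError(checkpoint)


# Demonstrate both input forms against a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth.tar")
    open(path, "wb").close()  # empty placeholder, not a real checkpoint
    from_dir = resolve_checkpoint(tmp)
    from_file = resolve_checkpoint(path)
    same_target = from_dir == from_file == path
```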
- predict(test_data)[source]
Predict outcome probabilities using the trained model.
- Parameters
test_data (Dict or TabularPatientBase or pd.DataFrame or torch.Tensor) –
{‘x’: TabularPatientBase or pd.DataFrame}
‘x’ contains all patient features.
- Returns
ypred – Predicted probability for each patient.
For binary classification, the returned shape is (n, );
for multiclass classification, the returned shape is (n, n_class).
- Return type
np.ndarray or torch.Tensor
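Since predict returns probabilities rather than class labels, a small post-processing step is usually needed. The sketch below handles the two documented shapes; probs_to_labels is a hypothetical helper, and the 0.5 threshold for the binary case is an assumption, not a PyTrial default. If the result is a torch.Tensor on a GPU, it would need to be moved to the CPU before conversion to NumPy.

```python
import numpy as np


def probs_to_labels(ypred, threshold=0.5):
    """Convert predicted probabilities to integer class labels."""
    ypred = np.asarray(ypred)
    if ypred.ndim == 1:
        # Binary case, shape (n,): compare against the threshold.
        return (ypred >= threshold).astype(int)
    # Multiclass case, shape (n, n_class): take the most probable class.
    return ypred.argmax(axis=1)


binary_probs = np.array([0.1, 0.8, 0.5])
multi_probs = np.array([[0.2, 0.7, 0.1], [0.6, 0.3, 0.1]])

binary_labels = probs_to_labels(binary_probs)
multi_labels = probs_to_labels(multi_probs)
```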