indiv_outcome.tabular.MLP

class pytrial.tasks.indiv_outcome.tabular.mlp.MLP(input_dim, output_dim, mode, hidden_dim=128, num_layer=2, learning_rate=0.0001, weight_decay=0.0001, batch_size=64, epochs=10, num_worker=0, device='cuda:0', experiment_id='test')

Implements a multi-layer perceptron (MLP) model for predicting individual patient outcomes in clinical trials from tabular data.

Parameters
  • input_dim (int) – Dimension of the input features.

  • output_dim (int) – Dimension of the outputs. For classification, it equals the number of classes.

  • mode (str) – The task objective: one of binary, multiclass, multilabel, or regression.

  • hidden_dim (int) – Dimension of the hidden layers of the network.

  • num_layer (int) – Number of hidden layers.

  • learning_rate (float) – Learning rate for the optimizer. torch.optim.Adam is used by default.

  • weight_decay (float) – Regularization strength for the L2 norm; must be a positive float. Smaller values specify weaker regularization.

  • batch_size (int) – Mini-batch size used during training.

  • epochs (int) – Maximum number of training epochs.

  • num_worker (int) – Number of workers used for data loading during training.

  • device (str) – Target device on which to train the model, such as cuda:0 or cpu.

  • experiment_id (str, optional (default='test')) – The name of the current experiment. It determines the name of the saved model checkpoint.
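To make the interplay of input_dim, hidden_dim, num_layer, and output_dim concrete, the sketch below lists the linear-layer shapes of a plain MLP. It assumes num_layer hidden layers of width hidden_dim followed by a single output layer. The helpers mlp_layer_shapes and count_parameters are hypothetical and not part of pytrial, and the library's actual architecture may differ.

```python
def mlp_layer_shapes(input_dim, output_dim, hidden_dim=128, num_layer=2):
    """Return (in_features, out_features) for each linear layer.

    Assumption: `num_layer` hidden layers of width `hidden_dim`,
    followed by a single output layer. Illustrative only.
    """
    if num_layer < 1:
        raise ValueError("num_layer must be at least 1")
    widths = [input_dim] + [hidden_dim] * num_layer + [output_dim]
    return list(zip(widths[:-1], widths[1:]))


def count_parameters(shapes):
    """Count weights plus biases for the given layer shapes."""
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)


# Example: 20 input features, 3 classes, default hidden settings.
shapes = mlp_layer_shapes(input_dim=20, output_dim=3)
print(shapes)                    # [(20, 128), (128, 128), (128, 3)]
print(count_parameters(shapes))  # 19587
```

Under this assumption, raising hidden_dim grows the parameter count roughly quadratically, while raising num_layer grows it linearly.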

fit(train_data, valid_data=None)

Train the MLP model to predict patient outcomes from tabular input data.

Parameters
  • train_data (dict) –

    { ‘x’: TabularPatientBase or pd.DataFrame, ‘y’: pd.Series or np.ndarray }

    • 'x' contains all patient features;

    • 'y' contains the label for each row.

  • valid_data (dict, same format as train_data) – Validation data used for early stopping during training.
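The reference states that valid_data drives early stopping but does not say how. One common scheme, sketched here as an assumption rather than as pytrial's actual logic, stops once the validation loss has failed to improve for a fixed number of consecutive epochs. The class EarlyStopper and its patience argument are hypothetical names, and the loss values are made up for illustration.

```python
class EarlyStopper:
    """Track validation loss and signal when training should stop."""

    def __init__(self, patience=3):
        self.patience = patience          # allowed epochs without improvement
        self.best_loss = float("inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, valid_loss):
        """Record one epoch; return True if training should stop."""
        if valid_loss < self.best_loss:
            self.best_loss = valid_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses, one per epoch.
losses = [0.70, 0.62, 0.58, 0.59, 0.60, 0.61, 0.55]
stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_epoch, stopper.best_loss)  # 5 2 0.58
```

With patience=3 the loop stops at epoch 5 and never sees the later improvement at epoch 6, which is the usual trade-off of a small patience value.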

load_model(checkpoint)

Load model and the pre-encoded trial embeddings from the given checkpoint dir.

Parameters

checkpoint (str) – The path to the pretrained model.

  • If a directory, the only checkpoint file *.pth.tar in it will be loaded.

  • If a file path, the model will be loaded from this file.
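The directory-or-file rule above can be sketched with the standard library alone. The helper resolve_checkpoint is hypothetical and not part of pytrial; raising an error when a directory holds zero or several *.pth.tar files is an assumption about how "the only checkpoint file" should be enforced.

```python
import glob
import os
import tempfile


def resolve_checkpoint(checkpoint):
    """Return the checkpoint file path for a directory or file input."""
    if os.path.isdir(checkpoint):
        matches = sorted(glob.glob(os.path.join(checkpoint, "*.pth.tar")))
        if len(matches) != 1:
            # Assumption: anything other than exactly one match is an error.
            raise FileNotFoundError(
                f"expected exactly one *.pth.tar in {checkpoint!r}, "
                f"found {len(matches)}"
            )
        return matches[0]
    if os.path.isfile(checkpoint):
        return checkpoint
    raise FileNotFoundError(f"no such checkpoint: {checkpoint!r}")


# Demonstration with a throwaway directory and an empty placeholder file.
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_file = os.path.join(tmp_dir, "mlp.pth.tar")
    open(ckpt_file, "wb").close()
    from_dir = resolve_checkpoint(tmp_dir)
    from_file = resolve_checkpoint(ckpt_file)

print(from_dir == from_file)  # True
```

Both call styles resolve to the same file, so a caller can pass either the experiment directory or an explicit path.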

predict(test_data)

Predict outcome probabilities with the trained model.

Parameters

test_data (Dict or TabularPatientBase or pd.DataFrame or torch.Tensor) –

{'x': TabularPatientBase or pd.DataFrame}

'x' contains all patient features.

Returns

ypred – Prediction probability for each patient.

  • For binary classification, return shape (n, );

  • For multiclass classification, return shape (n, n_class).

Return type

np.ndarray or torch.Tensor
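Because predict returns probabilities rather than labels, a caller typically post-processes ypred. The sketch below shows one way to do so for the two documented shapes, using plain Python lists as stand-ins for arrays. The helper to_labels and the 0.5 threshold are assumptions for illustration and not part of pytrial.

```python
def to_labels(ypred, threshold=0.5):
    """Convert predicted probabilities to hard labels.

    - Binary: `ypred` has shape (n,), so each entry is thresholded.
    - Multiclass: `ypred` has shape (n, n_class), so each row is arg-maxed.
    """
    labels = []
    for row in ypred:
        if isinstance(row, (list, tuple)):
            # Index of the largest probability in this row.
            labels.append(max(range(len(row)), key=row.__getitem__))
        else:
            labels.append(int(row >= threshold))
    return labels


# Illustrative probabilities, not real model output.
binary_probs = [0.1, 0.8, 0.5]
multiclass_probs = [[0.2, 0.7, 0.1], [0.6, 0.3, 0.1]]
print(to_labels(binary_probs))      # [0, 1, 1]
print(to_labels(multiclass_probs))  # [1, 0]
```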

save_model(output_dir=None)

Save the learned MLP model to disk.

Parameters

output_dir (str or None) – The directory in which to save the learned model. If None, the model is saved to self.checkout_dir.