indiv_outcome.tabular.TransTab
- class pytrial.tasks.indiv_outcome.tabular.transtab.TransTab(mode=None, categorical_columns=None, numerical_columns=None, binary_columns=None, contrastive_pretrain=False, num_class=2, hidden_dim=128, num_layer=2, num_attention_head=8, hidden_dropout_prob=0, ffn_dim=256, activation='relu', learning_rate=0.0001, weight_decay=0.0001, batch_size=64, epochs=10, num_worker=0, device='cuda:0', experiment_id='test')
Bases: pytrial.tasks.indiv_outcome.tabular.base.TabularIndivBase
Implements the TransTab model for tabular individual outcome prediction in clinical trials [1].
- Parameters
mode (str) – The task’s objective, either binary or multiclass (multilabel and regression are marked as TODO). Can be ignored if contrastive_pretrain is set to True.
categorical_columns (list) – a list of categorical feature names.
numerical_columns (list) – a list of numerical feature names.
binary_columns (list) – a list of binary feature names; accepts binary indicators such as (yes, no), (true, false), and (0, 1).
contrastive_pretrain (bool, default=False) – whether to run contrastive pretraining. If set to True, num_class is ignored.
num_class (int) – number of output classes to be predicted.
hidden_dim (int) – the dimension of hidden embeddings.
num_layer (int) – the number of transformer layers used in the encoder.
num_attention_head (int) – the number of heads in the multi-head self-attention layer of each transformer layer.
hidden_dropout_prob (float) – the dropout ratio in the transformer encoder.
ffn_dim (int) – the dimension of feed-forward layer in the transformer layer.
activation (str) – the name of the activation function; supported values are "relu", "gelu", "selu", and "leakyrelu".
learning_rate (float) – Learning rate for gradient-based optimization. torch.optim.Adam is used by default.
weight_decay (float) – Regularization strength for l2 norm; must be a positive float. Smaller values specify weaker regularization.
batch_size (int) – Batch size when doing SGD optimization.
epochs (int) – Maximum number of training epochs.
num_worker (int) – Number of workers used to do dataloading during training.
device (str) – Target device on which to train the model, such as cuda:0 or cpu.
experiment_id (str, optional (default='test')) – The name of the current experiment; it determines the name of the saved model checkpoint.
Notes
[1] Wang, Z., & Sun, J. (2022). TransTab: Learning Transferable Tabular Transformers Across Tables. NeurIPS’22.
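The constructor arguments listed above can be collected in a plain dict before the model is created. The following is a minimal sketch, not part of the library: build_transtab_kwargs is a hypothetical helper, the column names are invented, the mode strings "binary" and "multiclass" are assumed to be the accepted literals, and the overlap check is a sanity check added here rather than a documented requirement. The pytrial import is left commented out so the sketch runs without the package installed.

```python
def build_transtab_kwargs(categorical, numerical, binary, num_class=2):
    """Hypothetical helper: assemble keyword arguments for TransTab."""
    groups = [set(categorical), set(numerical), set(binary)]
    # Sanity check (an assumption of this sketch): a column belongs to one group only.
    total = sum(len(g) for g in groups)
    if len(set().union(*groups)) != total:
        raise ValueError("a column name appears in more than one column group")
    return {
        # Assumed literals for `mode`; the docs name the options binary and multiclass.
        "mode": "binary" if num_class == 2 else "multiclass",
        "categorical_columns": list(categorical),
        "numerical_columns": list(numerical),
        "binary_columns": list(binary),
        "num_class": num_class,
        "hidden_dim": 128,      # documented default
        "num_layer": 2,         # documented default
        "device": "cpu",        # documented alternative to 'cuda:0'
        "experiment_id": "demo",
    }


# Invented column names, for illustration only.
kwargs = build_transtab_kwargs(
    categorical=["tumor_stage"],
    numerical=["age", "bmi"],
    binary=["smoker"],
)

# With pytrial installed, the dict could be passed straight to the constructor:
# from pytrial.tasks.indiv_outcome.tabular.transtab import TransTab
# model = TransTab(**kwargs)
```

Keeping the arguments in a dict makes it easy to log the configuration alongside experiment_id.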
- fit(train_data, valid_data=None)
Train the TransTab model to predict patient outcomes from tabular input data.
- Parameters
train_data (list[dict]) –
a list of patient datasets, where each entry is a dict of the form {
'x': TabularPatientBase or pd.DataFrame,
'y': pd.Series or np.ndarray
}.
Because the list may hold several entries, TransTab can learn from multiple different tabular datasets.
valid_data (dict) –
Validation data used during training for early stopping, of the form
{
'x': TabularPatientBase or pd.DataFrame,
'y': pd.Series or np.ndarray
}
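The container layout described for fit (a list of {'x', 'y'} dicts for training, a single dict for validation) can be checked before training starts. This is a minimal sketch under stated assumptions: check_fit_inputs and _check_xy are hypothetical helpers, not library functions, and plain Python lists stand in for pd.DataFrame and pd.Series so the example runs without extra dependencies. Only len() is used, which also works on DataFrames and Series.

```python
def _check_xy(entry, name):
    """Check that one entry is a dict with aligned 'x' and 'y' (hypothetical helper)."""
    if not isinstance(entry, dict) or not {"x", "y"} <= set(entry):
        raise ValueError(f"{name} must be a dict with keys 'x' and 'y'")
    if len(entry["x"]) != len(entry["y"]):
        raise ValueError(f"{name}: 'x' and 'y' have different numbers of rows")


def check_fit_inputs(train_data, valid_data=None):
    """Validate the documented layout; returns the number of training tables."""
    if not isinstance(train_data, list):
        raise TypeError("train_data must be a list of dicts")
    for i, entry in enumerate(train_data):
        _check_xy(entry, f"train_data[{i}]")
    if valid_data is not None:
        _check_xy(valid_data, "valid_data")
    return len(train_data)


# Stand-in tables with different columns; real code would use pd.DataFrame for 'x'.
trial_a = {"x": [{"age": 61, "sex": "F"}, {"age": 48, "sex": "M"}], "y": [1, 0]}
trial_b = {"x": [{"age": 70, "smoker": "yes"}], "y": [0]}

n_tables = check_fit_inputs([trial_a, trial_b], valid_data=trial_b)

# With a constructed model and real DataFrames, training would then be:
# model.fit([trial_a, trial_b], valid_data=trial_b)
```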
- load_model(checkpoint)
Load the learned TransTab model from the given checkpoint.
- Parameters
checkpoint (str) –
Path to the pretrained model.
If it is a directory, the single checkpoint file matching *.pth.tar in that directory is loaded.
If it is a file path, the model is loaded from that file.
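The directory-versus-file rule for checkpoint can be illustrated with a small path resolver. This is a sketch of the documented behaviour, not the library's actual code: resolve_checkpoint is a hypothetical helper, and raising an error when the directory does not hold exactly one *.pth.tar file is an assumption about how the ambiguous case might be handled.

```python
from pathlib import Path
import tempfile


def resolve_checkpoint(checkpoint):
    """Hypothetical helper: map a directory or file path to one checkpoint file."""
    path = Path(checkpoint)
    if path.is_dir():
        matches = sorted(path.glob("*.pth.tar"))
        # Assumption of this sketch: anything other than one match is an error.
        if len(matches) != 1:
            raise FileNotFoundError(
                f"expected exactly one *.pth.tar file in {path}, found {len(matches)}"
            )
        return matches[0]
    # Otherwise the argument is treated as the checkpoint file itself.
    return path


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "model.pth.tar"
    ckpt.touch()  # empty placeholder file, not a real checkpoint
    from_dir = resolve_checkpoint(tmp)
    from_file = resolve_checkpoint(ckpt)
    same = from_dir == from_file
    resolved_name = from_dir.name
```

Both call styles resolve to the same file here, which mirrors the two cases the parameter description distinguishes.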
- predict(test_data)
Predict outcome probabilities using the learned model.
- Parameters
test_data (TabularPatientBase or pd.DataFrame) – Contains all patient features.
- Returns
ypred –
For binary classification, return shape (n, );
For multiclass classification, return shape (n, n_class).
- Return type
np.ndarray or torch.Tensor
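The two return shapes can be turned into hard class labels in a uniform way. The sketch below makes several assumptions: to_labels is a hypothetical helper, the binary output is taken to be the probability of the positive class, and the 0.5 threshold is an arbitrary choice. Plain lists are used so the example has no dependencies; an np.ndarray or torch.Tensor can be converted with its tolist() method first.

```python
def to_labels(ypred, threshold=0.5):
    """Hypothetical helper: convert predicted probabilities to class labels."""
    labels = []
    for row in ypred:
        if isinstance(row, (list, tuple)):
            # Multiclass case, shape (n, n_class): index of the largest probability.
            labels.append(max(range(len(row)), key=row.__getitem__))
        else:
            # Binary case, shape (n,): compare against the threshold.
            labels.append(int(row >= threshold))
    return labels


# Made-up probabilities in the two documented shapes.
binary_pred = [0.91, 0.12, 0.50]
multi_pred = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]

binary_labels = to_labels(binary_pred)
multi_labels = to_labels(multi_pred)

# With a trained model, the input would come from predict, for example:
# ypred = model.predict(test_df)
# labels = to_labels(ypred.tolist())
```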
- save_model(output_dir=None)
Save the learned TransTab model to disk.
- Parameters
output_dir (str or None) – The directory in which to save the learned model. If None, the model is saved to self.checkout_dir.
- update(config)
Update the feature extractor’s column map for categorical, numerical, and binary columns, or update the number of classes of the output classifier layer.
- Parameters
config (dict) – a dict of configurations: the keys cat (list), num (list), and bin (list) specify the new column names; the key num_class (int) specifies the number of classes for finetuning on a new dataset.
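A config dict with the documented keys can be assembled as follows. This is a minimal sketch: make_update_config is a hypothetical helper, the column names are invented, and leaving out keys that are not being changed is an assumption about what update accepts.

```python
def make_update_config(cat_cols=None, num_cols=None, bin_cols=None, num_class=None):
    """Hypothetical helper: build a config dict using the documented key names."""
    config = {}
    if cat_cols is not None:
        config["cat"] = list(cat_cols)
    if num_cols is not None:
        config["num"] = list(num_cols)
    if bin_cols is not None:
        config["bin"] = list(bin_cols)
    if num_class is not None:
        if int(num_class) < 2:
            raise ValueError("num_class must be at least 2")
        config["num_class"] = int(num_class)
    return config


# Invented columns for a new dataset with three outcome classes.
config = make_update_config(
    cat_cols=["tumor_stage"],
    num_cols=["age", "bmi"],
    num_class=3,
)

# With a pretrained model in hand, the dict would be passed on before finetuning:
# model.update(config)
```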