indiv_outcome.tabular.TransTab

class pytrial.tasks.indiv_outcome.tabular.transtab.TransTab(mode=None, categorical_columns=None, numerical_columns=None, binary_columns=None, contrastive_pretrain=False, num_class=2, hidden_dim=128, num_layer=2, num_attention_head=8, hidden_dropout_prob=0, ffn_dim=256, activation='relu', learning_rate=0.0001, weight_decay=0.0001, batch_size=64, epochs=10, num_worker=0, device='cuda:0', experiment_id='test')[source]

Bases: pytrial.tasks.indiv_outcome.tabular.base.TabularIndivBase

Implements the TransTab model for individual patient outcome prediction from tabular data in clinical trials [1].

Parameters
  • mode (str) – The task objective, either binary or multiclass (multilabel and regression are not yet supported). Can be ignored if contrastive_pretrain is set to True.

  • categorical_columns (list) – a list of categorical feature names.

  • numerical_columns (list) – a list of numerical feature names.

  • binary_columns (list) – a list of binary feature names; accepts binary indicators such as (yes, no), (true, false), or (0, 1).

  • contrastive_pretrain (bool(default=False)) – whether or not to perform contrastive pretraining. If set to True, num_class will be ignored.

  • num_class (int) – number of output classes to be predicted.

  • hidden_dim (int) – the dimension of hidden embeddings.

  • num_layer (int) – the number of transformer layers used in the encoder.

  • num_attention_head (int) – the number of heads in the multi-head self-attention layers of the transformer encoder.

  • hidden_dropout_prob (float) – the dropout ratio in the transformer encoder.

  • ffn_dim (int) – the dimension of the feed-forward layer in each transformer layer.

  • activation (str) – the name of the activation function to use; supports "relu", "gelu", "selu", and "leakyrelu".

  • learning_rate (float) – Learning rate for gradient-based optimization; torch.optim.Adam is used by default.

  • weight_decay (float) – L2 regularization strength; must be a positive float. Smaller values specify weaker regularization.

  • batch_size (int) – Batch size when doing SGD optimization.

  • epochs (int) – Maximum number of training epochs.

  • num_worker (int) – Number of workers used for data loading during training.

  • device (str) – Target device on which to train the model, e.g. 'cuda:0' or 'cpu'.

  • experiment_id (str, optional (default='test')) – The name of the current experiment; determines the saved model checkpoint name.

Notes

[1] Wang, Z., & Sun, J. (2022). TransTab: Learning Transferable Tabular Transformers Across Tables. NeurIPS 2022.
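
Examples

A minimal sketch of instantiating the model; the column names below are illustrative placeholders, not part of the API, and the remaining arguments keep their documented defaults.

>>> from pytrial.tasks.indiv_outcome.tabular.transtab import TransTab
>>> # illustrative column groups; replace with the columns of your own table
>>> model = TransTab(
...     mode='binary',
...     categorical_columns=['gender', 'race'],
...     numerical_columns=['age', 'bmi'],
...     binary_columns=['smoker'],
...     num_class=2,
...     device='cpu',
... )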

fit(train_data, valid_data=None)[source]

Train the TransTab model to predict patient outcomes from tabular input data.

Parameters
  • train_data (list[dict]) – A list of training datasets; each element is a dict of the form {'x': TabularPatientBase or pd.DataFrame, 'y': pd.Series or np.ndarray}. TransTab can learn from multiple different tabular datasets.

  • valid_data (dict) – Validation data used for early stopping during training, a dict of the same form: {'x': TabularPatientBase or pd.DataFrame, 'y': pd.Series or np.ndarray}.
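
Examples

A sketch of a fit call on a toy table, assuming the model instantiated above; the DataFrame, labels, and column names are purely illustrative, and 'x' may equally be a TabularPatientBase.

>>> import pandas as pd
>>> df = pd.DataFrame({
...     'gender': ['F', 'M', 'F', 'M'],
...     'race': ['white', 'black', 'asian', 'white'],
...     'age': [34, 58, 41, 72],
...     'bmi': [22.5, 31.0, 27.3, 24.8],
...     'smoker': ['yes', 'no', 'no', 'yes'],
... })
>>> labels = pd.Series([0, 1, 0, 1])
>>> train_data = [{'x': df, 'y': labels}]  # a list, so several datasets can be passed
>>> valid_data = {'x': df, 'y': labels}    # reused here only for illustration
>>> model.fit(train_data=train_data, valid_data=valid_data)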

load_model(checkpoint)[source]

Load a learned TransTab model from the given checkpoint directory or file.

Parameters

checkpoint (str) –

The input dir that stores the pretrained model.

  • If a directory, the only checkpoint file matching *.pth.tar in that directory will be loaded.

  • If a filepath, will load from this file.
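
Examples

A brief usage sketch; the checkpoint path is an illustrative placeholder.

>>> model.load_model('./checkpoints/transtab')  # a dir holding one *.pth.tar, or a file path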

predict(test_data)[source]

Predict outcome probabilities with the learned model.

Parameters

test_data (TabularPatientBase or pd.DataFrame) – Contains all patient features.

Returns

ypred

  • For binary classification, return shape (n, );

  • For multiclass classification, return shape (n, n_class).

Return type

np.ndarray or torch.Tensor
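
Examples

A sketch of reading out predictions, continuing the illustrative df from the fit example; the 0.5 decision threshold is an assumption for the binary case, not part of the API.

>>> import numpy as np
>>> ypred = model.predict(df)  # shape (n,) for binary mode, (n, n_class) for multiclass
>>> # the return may be an np.ndarray or a torch.Tensor; a GPU tensor would need .cpu() first
>>> ypred = np.asarray(ypred)
>>> yhat = (ypred >= 0.5).astype(int)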

save_model(output_dir=None)[source]

Save the learned TransTab model to disk.

Parameters

output_dir (str or None) – The directory in which to save the learned model. If None, the model is saved to self.checkout_dir.
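
Examples

A brief sketch; the output directory is an illustrative placeholder.

>>> model.save_model(output_dir='./checkpoints/transtab')  # None saves to self.checkout_dir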

update(config)[source]

Update the feature extractor's column map for categorical, numerical, and binary columns, or update the number of classes for the output classifier layer.

Parameters

config (dict) – a dict of configurations: keys cat (list), num (list), and bin (list) specify the new column names; key num_class (int) specifies the number of classes when finetuning on a new dataset.
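
Examples

A sketch of reconfiguring the model before finetuning on a new table; the column names and class count are illustrative assumptions.

>>> model.update({
...     'cat': ['gender', 'ethnicity'],
...     'num': ['age', 'weight'],
...     'bin': ['smoker'],
...     'num_class': 3,
... })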