indiv_outcome.tabular.FTTransformer

class pytrial.tasks.indiv_outcome.tabular.ft_transformer.FTTransformer(num_feat, cat_feat, cat_cardinalities, output_dim, mode, hidden_dim=128, num_layer=2, attention_dropout=0, ffn_dim=256, ffn_dropout=0, residual_dropout=0, learning_rate=0.0001, weight_decay=0.0001, batch_size=64, epochs=10, num_worker=0, device='cuda:0', experiment_id='test')[source]

Bases: pytrial.tasks.indiv_outcome.tabular.base.TabularIndivBase

Implements the FT-Transformer model for tabular individual outcome prediction in clinical trials [1].

Parameters
  • num_feat (list[str]) – the list of numerical feature names.

  • cat_feat (list[str]) – the list of categorical feature names.

  • cat_cardinalities (list[int]) – A list of categorical features’ cardinalities.

  • output_dim (int) – Dimension of the outputs. For classification, it equals the number of classes.

  • mode (str) – The task objective: one of binary, multiclass, multilabel, or regression.

  • hidden_dim (int) – Hidden dimension of the transformer layers. Must be a multiple of n_heads=8.

  • num_layer (int) – Number of hidden layers.

  • attention_dropout (float) – the dropout for attention blocks. Usually, positive values work better (even when the number of features is low).

  • ffn_dim (int) – the input size for the second linear layer in Transformer.FFN. Note that it can differ from the output size of the first linear layer, since activations such as ReGLU or GEGLU change the size of their input. For example, if ffn_dim=10 and the activation is ReGLU (which is always true for the baseline and default configurations), then the output size of the first linear layer will be 20.

  • ffn_dropout (float) – the dropout rate after the first linear layer in Transformer.FFN.

  • residual_dropout (float) – the dropout rate for the output of each residual branch of all Transformer blocks.

  • learning_rate (float) – Learning rate for stochastic gradient-based optimization. torch.optim.Adam is used by default.

  • weight_decay (float) – L2 regularization strength; must be a positive float. Smaller values specify weaker regularization.

  • batch_size (int) – Batch size when doing SGD optimization.

  • epochs (int) – Maximum number of training epochs.

  • num_worker (int) – Number of workers used for data loading during training.

  • device (str) – Target device to train the model on, e.g., 'cuda:0' or 'cpu'.

  • experiment_id (str, optional (default='test')) – The name of the current experiment; determines the saved model checkpoint name.

Notes

[1] Gorishniy, Y., et al. (2021). Revisiting Deep Learning Models for Tabular Data. NeurIPS’21.
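Examples

A minimal construction sketch; the feature names, cardinalities, and the three-class setup below are hypothetical, and any keyword not shown falls back to the defaults in the signature above.

    from pytrial.tasks.indiv_outcome.tabular.ft_transformer import FTTransformer

    # hypothetical columns: two numerical and two categorical features
    model = FTTransformer(
        num_feat=['age', 'weight'],
        cat_feat=['gender', 'race'],
        cat_cardinalities=[2, 5],   # one cardinality per categorical feature
        output_dim=3,               # number of classes for classification
        mode='multiclass',
        hidden_dim=128,             # must be a multiple of n_heads=8
        epochs=10,
        device='cpu',
    )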

fit(train_data, valid_data=None)[source]

Train the FT-Transformer model to predict patient outcomes from tabular input data.

Parameters
  • train_data (dict) –

    { ‘x’: TabularPatientBase or pd.DataFrame, ‘y’: pd.Series or np.ndarray }

    • ’x’ contains all patient features;

    • ’y’ contains the label for each row.

  • valid_data (dict) – Same format as train_data. Validation data used during training for early stopping.
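A minimal fit sketch under the documented train_data format; the DataFrame and labels are hypothetical and reuse the column names from the construction example above.

    import numpy as np
    import pandas as pd

    # hypothetical patient features and multiclass labels
    df = pd.DataFrame({
        'age': [34, 51, 68, 45],
        'weight': [70.0, 82.5, 64.2, 90.1],
        'gender': [0, 1, 1, 0],
        'race': [0, 2, 4, 1],
    })
    labels = np.array([0, 2, 1, 1])   # one label per row

    model.fit({'x': df, 'y': labels})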

load_model(checkpoint)[source]

Load the model from the given checkpoint directory or file.

Parameters

checkpoint (str) –

The input path that stores the pretrained model.

  • If a directory, the only checkpoint file *.pth.tar will be loaded.

  • If a filepath, will load from this file.
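A brief usage sketch showing both documented forms; the paths are hypothetical.

    model.load_model('./checkpoints/test')                  # a dir containing one *.pth.tar file
    model.load_model('./checkpoints/test/model.pth.tar')    # an explicit checkpoint file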

predict(test_data)[source]

Make probability predictions with the learned model.

Parameters

test_data (Dict or TabularPatientBase or pd.DataFrame or torch.Tensor) –

{‘x’: TabularPatientBase or pd.DataFrame or torch.Tensor}

’x’ contains all patient features.

Returns

ypred – Predicted outcome probabilities for each patient.

  • For binary classification, return shape (n, );

  • For multiclass classification, return shape (n, n_class).

Return type

np.ndarray or torch.Tensor
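A brief prediction sketch reusing the hypothetical DataFrame from the fit example; the shape comment follows the documented return convention.

    ypred = model.predict({'x': df})
    print(ypred.shape)   # (n, n_class) for multiclass; (n,) for binary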

save_model(output_dir=None)[source]

Save the learned FT-Transformer model to disk.

Parameters

output_dir (str or None) – The directory to save the learned model to. If None, the model is saved to self.checkout_dir.
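A brief save sketch; the output directory is hypothetical.

    model.save_model(output_dir='./checkpoints/ft_transformer')
    model.save_model()   # output_dir=None saves to self.checkout_dir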