# Source code for pytrial.tasks.indiv_outcome.tabular.transtab

import pdb
from copy import deepcopy

import transtab
import numpy as np
import pandas as pd
import torch

from import TabularPatientBase
from pytrial.utils.check import (
    check_checkpoint_file, check_model_dir, check_model_config_file, make_dir_if_not_exist
from .base import TabularIndivBase, IndivTabDataset

class BuildModel:
    '''Factory building the transtab model and its collate function from a config dict.

    Calling ``BuildModel(config)`` does not return a ``BuildModel`` instance;
    ``__new__`` returns a ``(model, collate_fn)`` tuple instead.
    '''
    def __new__(cls, config):
        '''
        Parameters
        ----------
        config: dict
            keyword arguments forwarded to the transtab builder. Must contain the
            key ``'contrastive_pretrain'``; that key is popped, so ``config`` is
            mutated in place (callers should pass a copy).

        Returns
        -------
        (model, collate_fn): tuple
            - if ``contrastive_pretrain`` is falsy: the model from
              ``transtab.build_classifier(**config)`` and ``None``
              (no custom collate function);
            - otherwise: the pair returned by
              ``transtab.build_contrastive_learner(**config)``.
        '''
        contrastive_pretrain = config.pop('contrastive_pretrain')
        if not contrastive_pretrain:
            clf = transtab.build_classifier(**config)
            collate_fn = None
        else:
            # Only the contrastive learner needs its own collate function, so it
            # must not be built (and must not overwrite the classifier) in the
            # supervised case.
            clf, collate_fn = transtab.build_contrastive_learner(**config)
        return clf, collate_fn

class TransTab(TabularIndivBase):
    '''
    Implement transtab model for tabular individual outcome prediction in clinical trials [1]_.

    Parameters
    ----------
    mode: str
        The task's objectives, in `binary`, `multiclass`. # TODO: `multilabel`, or `regression`
        Can be ignored (left as None) if `contrastive_pretrain` is set True.

    categorical_columns: list
        a list of categorical feature names.

    numerical_columns: list
        a list of numerical feature names.

    binary_columns: list
        a list of binary feature names, accept binary indicators like (yes,no); (true,false); (0,1).

    contrastive_pretrain: bool(default=False)
        whether or not take a contrastive pretraining. If set true, `num_class` will be ignored.

    num_class: int
        number of output classes to be predicted.

    hidden_dim: int
        the dimension of hidden embeddings.

    num_layer: int
        the number of transformer layers used in the encoder.

    num_attention_head: int
        the number of heads of multihead self-attention layer in the transformers.

    hidden_dropout_prob: float
        the dropout ratio in the transformer encoder.

    ffn_dim: int
        the dimension of feed-forward layer in the transformer layer.

    activation: str
        the name of used activation functions, support ``"relu"``, ``"gelu"``, ``"selu"``, ``"leakyrelu"``.

    learning_rate: float
        Learning rate for optimization based on SGD. Use torch.optim.Adam by default.

    weight_decay: float
        Regularization strength for l2 norm; must be a positive float.
        Smaller values specify weaker regularization.

    batch_size: int
        Batch size when doing SGD optimization.

    epochs: int
        Maximum number of iterations taken for the solvers to converge.

    num_worker: int
        Number of workers used to do dataloading during training.

    device: str
        Target device to train the model, as `cuda:0` or `cpu`.

    experiment_id: str, optional (default='test')
        The name of current experiment. Decide the saved model checkpoint name.

    Notes
    -----
    .. [1] Wang, Z., & Sun, J. (2022). TransTab: Learning Transferable Tabular Transformers Across Tables. NeurIPS'22.
    '''
    def __init__(self,
        mode=None,
        categorical_columns=None,
        numerical_columns=None,
        binary_columns=None,
        contrastive_pretrain=False,
        num_class=2,
        hidden_dim=128,
        num_layer=2,
        num_attention_head=8,
        hidden_dropout_prob=0,
        ffn_dim=256,
        activation='relu',
        learning_rate=1e-4,
        weight_decay=1e-4,
        batch_size=64,
        epochs=10,
        num_worker=0,
        device='cuda:0',
        experiment_id='test',
        ):
        super().__init__(experiment_id)
        # `mode` defaults to None and may stay None for contrastive pretraining,
        # so only normalize it when it is given.
        if mode is not None:
            mode = mode.lower()
        if not contrastive_pretrain:
            assert mode in ['binary', 'multiclass', 'regression', 'multilabel'], 'Must specify `mode` for supervised classifcation.'

        self.config = {
            'categorical_columns':categorical_columns,
            'numerical_columns':numerical_columns,
            'binary_columns':binary_columns,
            'contrastive_pretrain':contrastive_pretrain,
            'num_class':num_class,
            'hidden_dim':hidden_dim,
            'num_layer':num_layer,
            'num_attention_head':num_attention_head,
            'hidden_dropout_prob':hidden_dropout_prob,
            'ffn_dim':ffn_dim,
            'activation':activation,
            'device':device,
            'mode':mode,
            'learning_rate':learning_rate,
            'weight_decay':weight_decay,
            'batch_size':batch_size,
            'epochs':epochs,
            'num_worker':num_worker,
            }
        self._save_config(self.config)
        self._build_model()
        self.device = device

    def fit(self, train_data, valid_data=None):
        '''Train TransTab model to predict patient outcome with tabular input data.

        Parameters
        ----------
        train_data: list[dict]
            a list of patient data, each patient is a dict of
            {
                'x': TabularPatientBase or pd.DataFrame,
                'y': pd.Series or np.ndarray
            }.
            TransTab can learn from multiple different tabular datasets.

        valid_data: dict
            Validation data during the training for early stopping.
            valid_data = {
                'x': TabularPatientBase or pd.DataFrame,
                'y': pd.Series or np.ndarray
            }
        '''
        self._input_data_check(train_data)
        if valid_data is not None:
            self._input_data_check(valid_data)
        self._fit_model(train_data=train_data, valid_data=valid_data)

    def predict(self, test_data):
        '''
        Make prediction probability based on the learned model.

        Parameters
        ----------
        test_data: TabularPatientBase or pd.DataFrame
            Contain all patient features.

        Returns
        -------
        ypred: np.ndarray or torch.Tensor
            - For binary classification, return shape (n, );
            - For multiclass classification, return shape (n, n_class).
        '''
        self._input_data_check(test_data)
        data = self._parse_input_data(test_data)
        ypred = transtab.predict(self.model, x_test=data)
        return ypred

    def save_model(self, output_dir=None):
        '''
        Save the learned transtab model (its config and the model itself) to the disk.

        Parameters
        ----------
        output_dir: str or None
            The dir to save the learned model.
            If set None, will save model to `self.checkout_dir`.
        '''
        if output_dir is not None:
            make_dir_if_not_exist(output_dir)
        else:
            output_dir = self.checkout_dir

        self._save_config(self.config, output_dir=output_dir)
        # Persist the transtab model itself as well, so that
        # `load_model(output_dir)` has a model to pass to `self.model.load`.
        self.model.save(output_dir)

    def load_model(self, checkpoint):
        '''
        Load the learned transtab model from the given checkpoint.

        Parameters
        ----------
        checkpoint: str
            Path of the saved model, e.g. a directory written by `save_model`.
            If a model config file is found for it (via `check_model_config_file`),
            `self.config` is updated from that file; the path is then passed to
            the transtab model's `load`.
        '''
        config_filename = check_model_config_file(checkpoint)
        if config_filename is not None:
            config = self._load_config(config_filename)
            self.config.update(config)
        self.model.load(checkpoint)

    def update(self, config):
        '''Update the configuration of feature extractor's column map for *cat*, *num*, and *bin* cols.
        Or update the number of classes for the output classifier layer.

        Parameters
        ----------
        config: dict
            a dict of configurations: keys `cat:list`, `num:list`, `bin:list` are to specify the new column names;
            key `num_class:int` is to specify the number of classes for finetuning on a new dataset.
        '''
        self.model.update(config)

    def _build_model(self):
        # BuildModel pops keys from the dict it receives, so hand it a copy to
        # keep `self.config` intact.
        config = deepcopy(self.config)
        self.model, self.collate_fn = BuildModel(config)

    def _fit_model(self, train_data, valid_data=None):
        '''Convert inputs to transtab's (x, y) format and run `transtab.train`.'''
        train_data = self._parse_input_data(train_data)
        if valid_data is not None:
            valid_data = self._parse_input_data(valid_data)

        # Early-stopping metric. Modes without a dedicated metric (contrastive
        # pretraining, where `mode` may be None, or any other mode) fall back to
        # the validation loss, for which smaller is better.
        mode = self.config['mode']
        eval_less_is_better = False
        if mode == 'binary':
            eval_metric = 'auc'
        elif mode == 'multiclass':
            eval_metric = 'acc'
        else:
            eval_metric = 'val_loss'
            eval_less_is_better = True

        # NOTE(review): assumes transtab.train names these hyper-parameters
        # `num_epoch`, `lr` and `num_workers` (per transtab's documentation);
        # forward this wrapper's settings under those names so they take effect.
        # Unknown keys are already tolerated, since the full config is passed
        # through as keyword arguments.
        train_kwargs = dict(self.config)
        train_kwargs['num_epoch'] = self.config['epochs']
        train_kwargs['lr'] = self.config['learning_rate']
        train_kwargs['num_workers'] = self.config['num_worker']

        transtab.train(
            self.model,
            train_data,
            valset=valid_data,
            eval_metric=eval_metric,
            eval_less_is_better=eval_less_is_better,
            output_dir=self.checkout_dir,
            collate_fn=self.collate_fn,
            **train_kwargs,
            )

    def _parse_input_data(self, inputs):
        '''
        Convert user inputs into the data format passed to transtab.

        Returns a list of (x, y) tuples for a list input, one (x, y) tuple for
        a dict input, or a bare pd.DataFrame for a DataFrame /
        TabularPatientBase input.

        Raises
        ------
        TypeError
            if `inputs` (or one of its 'x' entries) has an unsupported type.
        '''
        def _to_dataframe(x):
            if isinstance(x, pd.DataFrame):
                return x
            if isinstance(x, TabularPatientBase):
                return x.df
            raise TypeError('Get unaccepted input data format, expect `pd.DataFrame` or `TabularPatientBase`, get {} instead.'.format(type(x)))

        def _to_labels(y, x):
            # NOTE(review): assumes transtab slices labels like a pandas object;
            # 1-D arrays are wrapped in a Series aligned with the feature rows.
            if isinstance(y, np.ndarray) and y.ndim == 1:
                return pd.Series(y, index=x.index)
            return y

        if isinstance(inputs, list):
            data = []
            for item in inputs:
                x = _to_dataframe(item['x'])
                data.append((x, _to_labels(item['y'], x)))
            return data
        if isinstance(inputs, dict):
            x = _to_dataframe(inputs['x'])
            return (x, _to_labels(inputs['y'], x))
        if isinstance(inputs, (pd.DataFrame, TabularPatientBase)):
            return _to_dataframe(inputs)
        raise TypeError('Get unaccepted input data format, expect list, dict, `pd.DataFrame` or `TabularPatientBase`, get {} instead.'.format(type(inputs)))

    def _input_data_check(self, inputs):
        '''
        Check the training / testing data fits the formats.

        Target to (1) check if inputs valid, if not, give tips about the data problem.

        Only dict entries are validated; bare DataFrame / TabularPatientBase
        inputs pass through unchecked.

        Parameters
        ----------
        inputs: [{
                'x': TabularPatientBase or pd.DataFrame,
                'y': pd.Series or np.ndarray
            },...]

            'x' contain all patient features; 'y' contain labels for each row.
        '''
        def _check_input(item):
            if not isinstance(item, dict):
                return
            assert 'x' in item, 'No input patient data found in inputs.'
            x = item['x']
            # Report the type of the offending entry itself (not of the outer
            # container, which may be a list).
            assert isinstance(x, (pd.DataFrame, TabularPatientBase)), 'Get unaccepted input data format, expect `pd.DataFrame` or `TabularPatientBase`, get {} instead.'.format(type(x))
            if 'y' in item:
                assert isinstance(item['y'], (pd.Series, np.ndarray))
                assert not pd.isnull(item['y']).any(), 'Find NaN in input targets, please check.'

            if isinstance(x, pd.DataFrame):
                assert not x.isnull().values.any(), 'Find NaN in input dataframe, please check your input, or try to pass `TabularPatientBase` as inputs.'
            if isinstance(x, TabularPatientBase):
                assert not x.df.isnull().values.any(), 'Find NaN in input dataset, please check your input, or try to pass `TabularPatientBase` as inputs.'

        if isinstance(inputs, list):
            for item in inputs:
                _check_input(item)
        else:
            _check_input(inputs)