# Source code for pytrial.tasks.trial_simulation.tabular.tvae

'''
Implement the TVAE model for simulating tabular patient data
in clinical trials.
'''
import os
import warnings
import joblib
import torch
from torch.nn.functional import cross_entropy
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from ctgan import TVAESynthesizer as TVAESynthesizerBase
from ctgan.synthesizers.tvae import Encoder, Decoder

from .base import TabularSimulationBase
from .copula_gan import DataTransformer
from pytrial.data.patient_data import TabularPatientBase
from pytrial.utils.check import check_checkpoint_file, check_model_dir, check_model_config_file, make_dir_if_not_exist

# Install a blanket "ignore" filter when this module is imported: no category
# or module restriction is given, so it applies to every warning raised in
# the process from this point on, not only to warnings from this module.
warnings.filterwarnings('ignore')

def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
    '''
    Compute the two TVAE loss terms for one batch.

    The columns of ``x`` / ``recon_x`` are walked span by span, following
    ``output_info`` (an iterable of columns, each an iterable of span
    descriptors exposing ``dim`` and ``activation_fn``):

    - a span whose ``activation_fn`` is ``'softmax'`` contributes a summed
      cross-entropy between ``recon_x[:, st:ed]`` (used as logits) and the
      argmax of ``x[:, st:ed]`` (used as the target class);
    - any other span contributes a squared error between ``x[:, st]`` and
      ``tanh(recon_x[:, st])`` scaled by ``sigmas[st]``, plus
      ``log(sigmas[st])`` times the batch size. Only the first column of
      such a span is read.

    Parameters
    ----------
    recon_x: torch.Tensor
        Decoder output, indexed as ``[row, column]``.

    x: torch.Tensor
        Batch being reconstructed, indexed as ``[row, column]``.

    sigmas: torch.Tensor
        Per-column scale values, indexed by column offset.

    mu: torch.Tensor
        Encoder mean output.

    logvar: torch.Tensor
        Encoder log-variance output.

    output_info: iterable
        Span layout as described above. The span widths must add up to
        ``recon_x.size()[1]`` (checked with an ``assert``).

    factor: int or float
        Multiplier applied to the reconstruction term.

    Returns
    -------
    tuple
        ``(reconstruction_loss, kld)``, both divided by the batch size.
    '''
    batch_size = x.size()[0]
    terms = []
    offset = 0

    for column_info in output_info:
        for span_info in column_info:
            end = offset + span_info.dim
            if span_info.activation_fn == 'softmax':
                target = torch.argmax(x[:, offset:end], dim=-1)
                terms.append(
                    cross_entropy(recon_x[:, offset:end], target, reduction='sum'))
            else:
                scale = sigmas[offset]
                residual = x[:, offset] - torch.tanh(recon_x[:, offset])
                terms.append((residual ** 2 / 2 / (scale ** 2)).sum())
                terms.append(torch.log(scale) * batch_size)
            offset = end

    # Every reconstructed column must have been consumed by exactly one span.
    assert offset == recon_x.size()[1]

    kld = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
    return sum(terms) * factor / batch_size, kld / batch_size

class TVAESynthesizer(TVAESynthesizerBase):
    '''
    ``TVAESynthesizerBase`` subclass whose ``fit`` builds its transformer from
    the ``DataTransformer`` imported by this module.
    '''

    def fit(self, train_data, discrete_columns=()):
        """Fit the TVAE Synthesizer models to the training data.

        The data is encoded with a freshly fitted ``DataTransformer``, then an
        ``Encoder``/``Decoder`` pair is trained jointly with Adam for
        ``self.epochs`` passes over the shuffled data. The fitted transformer
        and the trained decoder are kept on ``self.transformer`` and
        ``self.decoder``; the encoder is local to this call.

        Args:
            train_data (numpy.ndarray or pandas.DataFrame):
                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
            discrete_columns (list-like):
                List of discrete columns to be used to generate the Conditional
                Vector. If ``train_data`` is a Numpy array, this list should
                contain the integer indices of the columns. Otherwise, if it is
                a ``pandas.DataFrame``, this list should contain the column names.
        """
        self.transformer = DataTransformer()
        self.transformer.fit(train_data, discrete_columns)
        encoded = self.transformer.transform(train_data)

        device = self._device
        tensor = torch.from_numpy(encoded.astype('float32')).to(device)
        loader = DataLoader(
            TensorDataset(tensor),
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
        )

        data_dim = self.transformer.output_dimensions
        encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(device)
        self.decoder = Decoder(self.embedding_dim, self.decompress_dims, data_dim).to(device)

        # A single optimizer updates encoder and decoder parameters together.
        trainable = [*encoder.parameters(), *self.decoder.parameters()]
        optimizer = Adam(trainable, weight_decay=self.l2scale)

        for _ in range(self.epochs):
            for batch in loader:
                optimizer.zero_grad()
                real = batch[0].to(device)

                # Sample the latent code as mu + eps * std with eps ~ N(0, I).
                mu, std, logvar = encoder(real)
                noise = torch.randn_like(std)
                latent = noise * std + mu

                reconstructed, sigmas = self.decoder(latent)
                reconstruction_loss, kld_loss = _loss_function(
                    reconstructed, real, sigmas, mu, logvar,
                    self.transformer.output_info_list, self.loss_factor
                )
                total_loss = reconstruction_loss + kld_loss
                total_loss.backward()
                optimizer.step()

                # Keep the decoder's sigma inside [0.01, 1.0] after every update.
                self.decoder.sigma.data.clamp_(0.01, 1.0)


class BuildModel:
    '''
    Factory that turns a configuration mapping into a ``TVAESynthesizer``.

    ``__new__`` returns the synthesizer itself, so ``BuildModel(config)``
    evaluates to a ``TVAESynthesizer`` and no ``BuildModel`` instance is
    ever created.

    Parameters
    ----------
    config: dict
        Must contain the keys ``embedding_dim``, ``compress_dims``,
        ``decompress_dims``, ``l2scale``, ``batch_size``, ``epochs``,
        ``loss_factor`` and ``cuda``. A missing key raises ``KeyError``;
        any other keys are ignored.
    '''
    def __new__(cls, config) -> TVAESynthesizer:
        # Only these entries are forwarded; they are read in this order.
        option_names = (
            'embedding_dim',
            'compress_dims',
            'decompress_dims',
            'l2scale',
            'batch_size',
            'epochs',
            'loss_factor',
            'cuda',
        )
        options = {name: config[name] for name in option_names}
        return TVAESynthesizer(**options)


class TVAE(TabularSimulationBase):
    '''
    Implement TVAE model for tabular patient data simulation [1]_.

    Parameters
    ----------
    embedding_dim: int
        Size of the latent embedding produced by the encoder and consumed
        by the decoder. Defaults to 128.

    compress_dims: tuple or list[int]
        Size of each hidden layer in the encoder. Defaults to (128, 128).

    decompress_dims: tuple or list[int]
        Size of each hidden layer in the decoder. Defaults to (128, 128).

    l2scale: float
        Regularization term. Defaults to 1e-5.

    batch_size: int
        Number of data samples to process in each step. Defaults to 500.

    epochs: int
        Number of training epochs. Defaults to 50.

    loss_factor: int
        Multiplier for the reconstruction error. Defaults to 2.

    cuda: bool or str
        - If ``True``, use CUDA. If a ``str``, use the indicated device.
        - If ``False``, do not use cuda at all.

        Defaults to False.

    experiment_id: str, optional
        The name of current experiment. Decide the saved model checkpoint name.

    Notes
    -----
    .. [1] Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019). Modeling tabular data using conditional gan. Advances in Neural Information Processing Systems, 32.
    '''
    def __init__(
        self,
        embedding_dim=128,
        compress_dims=(128, 128),
        decompress_dims=(128, 128),
        l2scale=1e-5,
        batch_size=500,
        epochs=50,
        loss_factor=2,
        cuda=False,
        experiment_id='trial_simulation.tabular.tvae',
        ) -> None:
        super().__init__(experiment_id=experiment_id)
        self.config = {
            'embedding_dim' : embedding_dim,
            'compress_dims' : compress_dims,
            'decompress_dims' : decompress_dims,
            'l2scale' : l2scale,
            'batch_size': batch_size,
            'epochs': epochs,
            'loss_factor' : loss_factor,
            'cuda' : cuda,
            'experiment_id': experiment_id,
            'model_name': 'tvae',
            }
        # The configuration is handed to ``_save_config`` as soon as the
        # object is constructed.
        self._save_config(self.config)

    def fit(self, train_data):
        '''
        Train TVAE model to simulate patient data with tabular input data.

        Parameters
        ----------
        train_data: TabularPatientBase or dict
            The training data for TVAE model. A ``dict`` is first wrapped
            into a ``TabularPatientBase`` (with ``transform=True``).
        '''
        self._input_data_check(train_data)
        self._build_model()

        if isinstance(train_data, dict):
            # Rebind to the wrapped object: ``metadata`` and
            # ``reverse_transform`` below do not exist on a plain dict.
            train_data = TabularPatientBase(train_data, transform=True)

        self.metadata = train_data.metadata
        self.raw_dataset = train_data
        dataset = self.raw_dataset.reverse_transform() # transform back

        # A column is treated as discrete when its name, with any '.value'
        # substring removed, is listed as 'categorical' in the metadata.
        fields_before_transform = self.metadata['sdtypes']
        categoricals = []
        for field in dataset.columns:
            field_name = field.replace('.value', '')
            if (field_name in fields_before_transform
                    and fields_before_transform[field_name] == 'categorical'):
                categoricals.append(field)

        self._fit_model(dataset, categoricals)

    def predict(self, n=200):
        '''
        Simulate new tabular data with `n` samples.

        Parameters
        ----------
        n: int
            The number of new data to simulate.

        Returns
        -------
        ypred:
            Whatever ``self.model.sample(n)`` returns, passed through
            unchanged (presumably a ``pandas.DataFrame`` laid out like the
            data given to the underlying model - TODO confirm). It is not
            wrapped into a ``TabularPatientBase`` here.
        '''
        return self.model.sample(n)

    def save_model(self, output_dir=None):
        '''
        Save the learned TVAE model to the disk.

        Parameters
        ----------
        output_dir: str or None
            The dir to save the learned model.
            If set None, will save model to `self.checkout_dir`.
        '''
        if output_dir is not None:
            make_dir_if_not_exist(output_dir)
        else:
            output_dir = self.checkout_dir

        self._save_config(self.config, output_dir=output_dir)
        ckpt_path = os.path.join(output_dir, 'tvae.model')
        joblib.dump(self.model, ckpt_path)

    def load_model(self, checkpoint=None):
        '''
        Load the learned TVAE model from the disk.

        Parameters
        ----------
        checkpoint: str or None
            The path to the checkpoint file.

            - If a directory, the only checkpoint file `.model` will be loaded.
            - If a filepath, will load from this file;
            - If None, will load from `self.checkout_dir`.
        '''
        if checkpoint is None:
            checkpoint = self.checkout_dir

        checkpoint_filename = check_checkpoint_file(checkpoint, suffix='model')
        config_filename = check_model_config_file(checkpoint)
        # NOTE(security): joblib.load unpickles the file and can therefore
        # execute arbitrary code; only load checkpoints from trusted sources.
        self.model = joblib.load(checkpoint_filename)
        self.config = self._load_config(config_filename)

    def _build_model(self):
        '''Create a fresh, untrained synthesizer from ``self.config``.'''
        self.model = BuildModel(self.config)

    def _fit_model(self, data, discrete_columns):
        '''Fit ``self.model`` on `data`, passing on `discrete_columns`.'''
        self.model.fit(data, discrete_columns=discrete_columns)