# Source code for pytrial.tasks.trial_simulation.tabular.ct_gan

'''
Implement CTGAN model for tabular simulation
prediction in clinical trials.
'''
import os
import warnings
import joblib

from .base import TabularSimulationBase
from .copula_gan import CTGANSynthesizer

from pytrial.data.patient_data import TabularPatientBase
from pytrial.utils.check import check_checkpoint_file, check_model_dir, check_model_config_file, make_dir_if_not_exist

# Module-level side effect: importing this module installs an 'ignore' filter
# with no category/module/message restriction, i.e. it applies to every warning
# raised in the process afterwards.
warnings.filterwarnings('ignore')


class BuildModel:
    '''
    Factory for the underlying synthesizer.

    Calling ``BuildModel(config)`` does not yield a ``BuildModel`` instance:
    ``__new__`` returns a ``CTGANSynthesizer`` configured from ``config``.
    '''
    def __new__(self, config) -> CTGANSynthesizer:
        '''
        Create a ``CTGANSynthesizer`` from a configuration mapping.

        Parameters
        ----------
        config: dict
            Must provide every hyper-parameter key listed below; a missing
            key raises ``KeyError``.

        Returns
        -------
        CTGANSynthesizer
            The newly constructed synthesizer.
        '''
        # NOTE: inside ``__new__`` the first argument is the class itself,
        # even though it is spelled ``self`` here.
        hyper_parameters = (
            'embedding_dim',
            'generator_dim',
            'discriminator_dim',
            'generator_lr',
            'generator_decay',
            'discriminator_lr',
            'discriminator_decay',
            'batch_size',
            'discriminator_steps',
            'log_frequency',
            'verbose',
            'epochs',
            'pac',
            'cuda',
        )
        # Only these keys are forwarded; any extra entries in ``config``
        # (e.g. bookkeeping fields) are left out of the constructor call.
        synthesizer_kwargs = {key: config[key] for key in hyper_parameters}
        return CTGANSynthesizer(**synthesizer_kwargs)


class CTGAN(TabularSimulationBase):
    '''
    Implement CTGAN model for patient level tabular data generation [1]_.

    Parameters
    ----------
    embedding_dim: int
        Size of the random sample passed to the Generator. Defaults to 128.

    generator_dim: tuple or list of ints
        Size of the output samples for each one of the Residuals. A Residual Layer
        will be created for each one of the values provided. Defaults to (256, 256).

    discriminator_dim: tuple or list of ints
        Size of the output samples for each one of the Discriminator Layers. A Linear Layer
        will be created for each one of the values provided. Defaults to (256, 256).

    generator_lr: float
        Learning rate for the generator. Defaults to 2e-4.

    generator_decay: float
        Generator weight decay for the Adam Optimizer. Defaults to 1e-6.

    discriminator_lr: float
        Learning rate for the discriminator. Defaults to 2e-4.

    discriminator_decay: float
        Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.

    batch_size: int
        Number of data samples to process in each step. Defaults to 500.

    discriminator_steps: int
        Number of discriminator updates to do for each generator update.
        From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
        default is 5. Default used is 1 to match original CTGAN implementation.

    log_frequency: bool
        Whether to use log frequency of categorical levels in conditional
        sampling. Defaults to ``True``.

    verbose: bool
        Whether to have print statements for progress results. Defaults to ``True``.

    epochs: int
        Number of training epochs. Defaults to 50.

    pac: int
        Number of samples to group together when applying the discriminator.
        Defaults to 10.

    cuda: bool or str
        If ``True``, use CUDA. If a ``str``, use the indicated device.
        If ``False``, do not use cuda at all. Defaults to ``False``.

    experiment_id: str, optional (default='trial_simulation.tabular.ctgan')
        The name of current experiment. Decide the saved model checkpoint name.

    Notes
    -----
    .. [1] Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019).
        Modeling tabular data using conditional gan.
        Advances in Neural Information Processing Systems, 32.
    '''
    def __init__(
        self,
        embedding_dim=128,
        generator_dim=(256, 256),
        discriminator_dim=(256, 256),
        generator_lr=2e-4,
        generator_decay=1e-6,
        discriminator_lr=2e-4,
        discriminator_decay=1e-6,
        batch_size=500,
        discriminator_steps=1,
        log_frequency=True,
        verbose=True,
        epochs=50,
        pac=10,
        cuda=False,  # can be set to "True" if applicable
        experiment_id='trial_simulation.tabular.ctgan',
        ) -> None:
        super().__init__(experiment_id=experiment_id)
        # Everything needed to rebuild the synthesizer, plus two bookkeeping
        # entries ('experiment_id', 'model_name') that BuildModel does not read.
        self.config = {
            'embedding_dim': embedding_dim,
            'generator_dim': generator_dim,
            'discriminator_dim': discriminator_dim,
            'generator_lr': generator_lr,
            'generator_decay': generator_decay,
            'discriminator_lr': discriminator_lr,
            'discriminator_decay': discriminator_decay,
            'batch_size': batch_size,
            'discriminator_steps': discriminator_steps,
            'log_frequency': log_frequency,
            'verbose': verbose,
            'epochs': epochs,
            'pac': pac,
            'cuda': cuda,
            'experiment_id': experiment_id,
            'model_name': 'ct_gan',
        }
        # The configuration is handed to the base-class saver right at
        # construction time.
        self._save_config(self.config)

    def fit(self, train_data):
        '''
        Train CTGAN model to simulate tabular patient data.

        A fresh synthesizer is built from ``self.config`` on every call, so
        calling ``fit`` again replaces the previously trained model.

        Parameters
        ----------
        train_data: TabularPatientBase or dict
            The training data. A ``dict`` is wrapped into a
            ``TabularPatientBase`` with ``transform=True`` before use.
        '''
        self._input_data_check(train_data)
        self._build_model()

        if isinstance(train_data, TabularPatientBase):
            # NOTE(review): the original comment here read "transform=True";
            # presumably the dataset is expected to have been created with
            # its transform enabled -- confirm against TabularPatientBase.
            self.metadata = train_data.metadata
            self.raw_dataset = train_data

        if isinstance(train_data, dict):
            wrapped_dataset = TabularPatientBase(train_data, transform=True)
            self.metadata = wrapped_dataset.metadata
            self.raw_dataset = wrapped_dataset

        # Train on the reverse-transformed ("transformed back") table.
        dataset = self.raw_dataset.reverse_transform()

        # A column is passed to the synthesizer as discrete when its name,
        # with any '.value' text removed, is declared 'categorical' in the
        # metadata's 'sdtypes' mapping.
        sdtypes = self.metadata['sdtypes']
        categoricals = []
        for column in dataset.columns:
            base_name = column.replace('.value', '')
            if base_name in sdtypes and sdtypes[base_name] == 'categorical':
                categoricals.append(column)

        self._fit_model(dataset, categoricals)

    def predict(self, n=200):
        '''
        Simulate a new tabular dataset with ``n`` synthetic records.

        Parameters
        ----------
        n: int
            Number of synthetic records going to generate. Defaults to 200.

        Returns
        -------
        ypred: dataset, same as the input dataset
            A new tabular data simulated by the model.
        '''
        # The sampled output is returned as-is; no transform back is applied.
        ypred = self.model.sample(n)
        return ypred

    def save_model(self, output_dir=None):
        '''
        Save the learned CTGAN model to the disk.

        Writes the configuration and a ``ctgan.model`` checkpoint (dumped
        with ``joblib``) into the target directory.

        Parameters
        ----------
        output_dir: str or None
            The dir to save the learned model.
            If set None, will save model to `self.checkout_dir`.
        '''
        if output_dir is None:
            output_dir = self.checkout_dir

        make_dir_if_not_exist(output_dir)
        self._save_config(self.config, output_dir=output_dir)
        ckpt_path = os.path.join(output_dir, 'ctgan.model')
        joblib.dump(self.model, ckpt_path)

    def load_model(self, checkpoint=None):
        '''
        Load a learned CTGAN model and its configuration from the disk.

        Parameters
        ----------
        checkpoint: str or None
            If a directory, the only checkpoint file `.model` will be loaded.
            If a filepath, will load from this file;
            If None, will load from `self.checkout_dir`.

        Notes
        -----
        The checkpoint is read with ``joblib.load``, which unpickles the file.
        Only load checkpoints that come from a trusted source.
        '''
        if checkpoint is None:
            checkpoint = self.checkout_dir

        checkpoint_filename = check_checkpoint_file(checkpoint, suffix='model')
        config_filename = check_model_config_file(checkpoint)
        # SECURITY: joblib.load is pickle-based and can execute arbitrary code
        # embedded in a malicious file -- never point this at untrusted input.
        self.model = joblib.load(checkpoint_filename)
        self.config = self._load_config(config_filename)

    def _build_model(self):
        '''Create a new, untrained synthesizer from ``self.config``.'''
        self.model = BuildModel(self.config)

    def _fit_model(self, data, discrete_columns):
        '''Fit the synthesizer on ``data``, treating ``discrete_columns`` as discrete.'''
        self.model.fit(data, discrete_columns=discrete_columns)