trial_simulation.tabular.CTGAN
- class pytrial.tasks.trial_simulation.tabular.ct_gan.CTGAN(embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=0.0002, generator_decay=1e-06, discriminator_lr=0.0002, discriminator_decay=1e-06, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=True, epochs=50, pac=10, cuda=False, experiment_id='trial_simulation.tabular.ctgan')[source]
Bases:
pytrial.tasks.trial_simulation.tabular.base.TabularSimulationBase
Implements the CTGAN model for patient-level tabular data generation [1].
- Parameters
embedding_dim (int) – Size of the random sample passed to the Generator. Defaults to 128.
generator_dim (tuple or list of ints) – Size of the output samples for each one of the Residuals. A Residual Layer will be created for each one of the values provided. Defaults to (256, 256).
discriminator_dim (tuple or list of ints) – Size of the output samples for each one of the Discriminator Layers. A Linear Layer will be created for each one of the values provided. Defaults to (256, 256).
generator_lr (float) – Learning rate for the generator. Defaults to 2e-4.
generator_decay (float) – Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
discriminator_lr (float) – Learning rate for the discriminator. Defaults to 2e-4.
discriminator_decay (float) – Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
batch_size (int) – Number of data samples to process in each step. Defaults to 500.
discriminator_steps (int) – Number of discriminator updates to do for each generator update. From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper default is 5. Default used is 1 to match original CTGAN implementation.
log_frequency (bool) – Whether to use log frequency of categorical levels in conditional sampling. Defaults to True.
verbose (bool) – Whether to print progress results during training. Defaults to True.
epochs (int) – Number of training epochs. Defaults to 50.
pac (int) – Number of samples to group together when applying the discriminator. Defaults to 10.
cuda (bool or str) – If True, use CUDA. If a str, use the indicated device. If False, do not use CUDA at all. Defaults to False.
experiment_id (str, optional (default='trial_simulation.tabular.ctgan')) – The name of the current experiment. Determines the name of the saved model checkpoint.
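To illustrate what log_frequency changes, the sketch below compares sampling probabilities for one categorical column computed from raw counts and from log(count + 1), which is how the reference CTGAN implementation weights categories. That pytrial does exactly this internally is an assumption; the helper name category_probs and the counts are hypothetical.

```python
import math

def category_probs(counts, log_frequency=True):
    """Return sampling probabilities for the levels of one categorical column.

    Hypothetical helper: with log_frequency=True, each count is replaced by
    log(count + 1) before normalizing, which flattens imbalanced columns.
    """
    weights = [math.log(c + 1) if log_frequency else float(c) for c in counts]
    total = sum(weights)
    return [w / total for w in weights]

# Made-up counts for a column with one dominant and one rare level.
counts = [990, 10]
raw = category_probs(counts, log_frequency=False)
logged = category_probs(counts, log_frequency=True)
print(raw)
print(logged)
```

With these counts the rare level's sampling probability rises from 0.01 to roughly 0.26, so it is seen far more often as a condition during training.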
Notes
[1] Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019). Modeling tabular data using conditional GAN. Advances in Neural Information Processing Systems, 32.
- fit(train_data)[source]
Train CTGAN model to simulate tabular patient data.
- Parameters
train_data (TabularPatientBase) – The training data.
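A minimal usage sketch, assuming pytrial is installed and train_data is an already constructed TabularPatientBase. The helpers build_ctgan_kwargs and train_ctgan are hypothetical; only the import path, the constructor arguments, and fit come from the reference above. The divisibility check mirrors the reference CTGAN code, where the discriminator groups pac samples together; that pytrial enforces the same constraint is an assumption.

```python
def build_ctgan_kwargs(**overrides):
    """Merge overrides into a few defaults copied from the class signature."""
    kwargs = {"batch_size": 500, "pac": 10, "epochs": 50, "cuda": False}
    kwargs.update(overrides)
    # Assumption: the discriminator packs `pac` samples together, so a batch
    # should split evenly into groups of that size.
    if kwargs["batch_size"] % kwargs["pac"] != 0:
        raise ValueError("batch_size must be divisible by pac")
    return kwargs

def train_ctgan(train_data, **overrides):
    """Fit a CTGAN on a TabularPatientBase dataset and return the model."""
    # Imported lazily so the sketch can be read without pytrial installed.
    from pytrial.tasks.trial_simulation.tabular.ct_gan import CTGAN
    model = CTGAN(**build_ctgan_kwargs(**overrides))
    model.fit(train_data)
    return model
```

A call such as train_ctgan(train_data, epochs=5) would then train a short run with the remaining defaults unchanged.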
- load_model(checkpoint=None)[source]
Load the learned CTGAN model from the disk.
- Parameters
checkpoint (str or None) – If a directory, the only checkpoint file (.model) in it will be loaded. If a filepath, the model will be loaded from this file. If None, the model will be loaded from self.checkout_dir.
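The three cases for checkpoint can be sketched as a small path-resolution helper. This is an illustration of the documented behavior, not pytrial's actual code: resolve_checkpoint and its default_dir argument (standing in for self.checkout_dir) are hypothetical, and raising an error when a directory does not hold exactly one .model file is an assumption.

```python
import os

def resolve_checkpoint(checkpoint, default_dir):
    """Return the path of the checkpoint file to load (hypothetical helper)."""
    # None falls back to the default directory.
    target = default_dir if checkpoint is None else checkpoint
    if os.path.isdir(target):
        # A directory is expected to hold a single .model file.
        candidates = [f for f in sorted(os.listdir(target)) if f.endswith(".model")]
        if len(candidates) != 1:
            raise FileNotFoundError(
                f"expected exactly one .model file in {target!r}, "
                f"found {len(candidates)}"
            )
        return os.path.join(target, candidates[0])
    # Anything else is treated as a direct path to the checkpoint file.
    return target
```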