trial_simulation.tabular.CTGAN

class pytrial.tasks.trial_simulation.tabular.ct_gan.CTGAN(embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=0.0002, generator_decay=1e-06, discriminator_lr=0.0002, discriminator_decay=1e-06, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=True, epochs=50, pac=10, cuda=False, experiment_id='trial_simulation.tabular.ctgan')

Bases: pytrial.tasks.trial_simulation.tabular.base.TabularSimulationBase

Implements the CTGAN model for patient-level tabular data generation [1].

Parameters
  • embedding_dim (int) – Size of the random sample passed to the Generator. Defaults to 128.

  • generator_dim (tuple or list of ints) – Output size of each Residual layer in the Generator. One Residual layer is created for each value provided. Defaults to (256, 256).

  • discriminator_dim (tuple or list of ints) – Output size of each Linear layer in the Discriminator. One Linear layer is created for each value provided. Defaults to (256, 256).

  • generator_lr (float) – Learning rate for the generator. Defaults to 2e-4.

  • generator_decay (float) – Generator weight decay for the Adam Optimizer. Defaults to 1e-6.

  • discriminator_lr (float) – Learning rate for the discriminator. Defaults to 2e-4.

  • discriminator_decay (float) – Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.

  • batch_size (int) – Number of data samples to process in each step. Defaults to 500.

  • discriminator_steps (int) – Number of discriminator updates to perform for each generator update, following the WGAN paper (https://arxiv.org/abs/1701.07875), which uses 5. Defaults to 1 to match the original CTGAN implementation.

  • log_frequency (bool) – Whether to use the log frequency of categorical levels in conditional sampling. Defaults to True.

  • verbose (bool) – Whether to print training progress. Defaults to True.

  • epochs (int) – Number of training epochs. Defaults to 50.

  • pac (int) – Number of samples to group together when applying the discriminator. Defaults to 10.

  • cuda (bool or str) – If True, use CUDA. If a str, use the indicated device. If False, do not use CUDA at all. Defaults to False.

  • experiment_id (str, optional (default='trial_simulation.tabular.ctgan')) – The name of the current experiment. Determines the name of the saved model checkpoint.
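The sketch below illustrates how embedding_dim, generator_dim, discriminator_dim and pac translate into layer widths. It assumes PyTrial wraps the reference CTGAN architecture, in which each Residual layer concatenates its output with its input (so widths accumulate) and the Discriminator scores pac rows concatenated into one vector. The helper names and the cond_dim argument (width of the conditional vector) are hypothetical and are not part of the PyTrial API.

```python
# Hypothetical helpers: they mirror the reference CTGAN layer sizing and are
# not PyTrial functions. cond_dim is the (assumed) conditional-vector width.

def generator_layer_sizes(embedding_dim, generator_dim, data_dim, cond_dim=0):
    """Return (in_features, out_features) for each Generator layer."""
    sizes = []
    dim = embedding_dim + cond_dim  # noise vector plus conditional vector
    for width in generator_dim:
        # A Residual layer maps dim -> width, then concatenates its input,
        # so the next layer receives dim + width features.
        sizes.append((dim, width))
        dim += width
    sizes.append((dim, data_dim))  # final Linear projects to the data width
    return sizes


def discriminator_layer_sizes(data_dim, discriminator_dim, pac, cond_dim=0):
    """Return (in_features, out_features) for each Discriminator Linear layer."""
    sizes = []
    dim = (data_dim + cond_dim) * pac  # pac rows are concatenated into one input
    for width in discriminator_dim:
        sizes.append((dim, width))
        dim = width
    sizes.append((dim, 1))  # one score per packed group of pac rows
    return sizes


def pack(rows, pac):
    """Concatenate every `pac` consecutive rows into a single row."""
    if len(rows) % pac != 0:
        raise ValueError("batch size must be divisible by pac")
    return [
        [value for row in rows[i:i + pac] for value in row]
        for i in range(0, len(rows), pac)
    ]
```

For example, generator_layer_sizes(128, (256, 256), data_dim=20, cond_dim=5) gives [(133, 256), (389, 256), (645, 20)]. Packing is also why, in the reference implementation, the batch size must be a multiple of pac; the defaults (batch_size=500, pac=10) satisfy this.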

Notes

[1] Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019). Modeling tabular data using conditional GAN. Advances in Neural Information Processing Systems, 32.

fit(train_data)

Train the CTGAN model to simulate tabular patient data.

Parameters

train_data (TabularPatientBase) – The training data.
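A rough sketch of the training schedule and of the log_frequency option, assuming fit follows the reference CTGAN loop: each epoch runs max(n_rows // batch_size, 1) steps, each step performs discriminator_steps discriminator updates followed by one generator update, and with log_frequency=True the sampling weight of each category is proportional to log(count + 1) rather than the raw count. The function names are hypothetical and are not part of the PyTrial API.

```python
import math


def update_counts(n_rows, batch_size=500, epochs=50, discriminator_steps=1):
    """Hypothetical: (generator_updates, discriminator_updates) for one fit."""
    # At least one step per epoch, even when n_rows < batch_size.
    steps_per_epoch = max(n_rows // batch_size, 1)
    generator_updates = epochs * steps_per_epoch
    # Each generator update is preceded by discriminator_steps critic updates.
    discriminator_updates = generator_updates * discriminator_steps
    return generator_updates, discriminator_updates


def category_probabilities(counts, log_frequency=True):
    """Hypothetical: sampling probabilities for one discrete column's levels."""
    # log(count + 1) flattens the distribution so rare levels are sampled more.
    weights = [math.log(c + 1) if log_frequency else float(c) for c in counts]
    total = sum(weights)
    return [w / total for w in weights]
```

With counts [90, 10], raw frequencies give probabilities [0.9, 0.1], while log frequencies raise the minority level to roughly 0.35, which is the intended effect on imbalanced categorical columns.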

load_model(checkpoint=None)

Load a learned CTGAN model from disk.

Parameters

checkpoint (str or None) – If a directory, the only checkpoint file with the .model extension in it is loaded. If a file path, the model is loaded from that file. If None, the model is loaded from self.checkout_dir.
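The lookup rules above can be sketched as a small resolver. This is a hypothetical helper written from the documented behaviour, not PyTrial's code; it assumes that a directory must contain exactly one .model file and that None simply falls back to the default checkpoint directory (self.checkout_dir), after which the directory rule applies.

```python
import os
import tempfile


def resolve_checkpoint(checkpoint, default_dir):
    """Hypothetical: return the checkpoint file that load_model would read."""
    if checkpoint is None:
        # None falls back to the experiment's default checkpoint directory.
        checkpoint = default_dir
    if os.path.isdir(checkpoint):
        # A directory must contain exactly one *.model file.
        candidates = [f for f in os.listdir(checkpoint) if f.endswith(".model")]
        if len(candidates) != 1:
            raise FileNotFoundError(
                f"expected one .model file in {checkpoint!r}, found {len(candidates)}"
            )
        return os.path.join(checkpoint, candidates[0])
    # Anything else is treated as a direct path to a checkpoint file.
    return checkpoint


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as ckpt_dir:
        path = os.path.join(ckpt_dir, "ctgan.model")
        with open(path, "w"):
            pass
        print(resolve_checkpoint(None, ckpt_dir) == path)  # True
```

Raising on zero or several candidates keeps the "only checkpoint file" rule explicit instead of silently picking one.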

predict(n=200)

Simulate new tabular data containing n synthetic records.

Parameters

n (int) – Number of synthetic records to generate. Defaults to 200.

Returns

ypred – New tabular data simulated by the model.

Return type

dataset, of the same type as the input training dataset
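A minimal sketch of how n relates to batch_size during sampling, assuming predict delegates to the reference CTGAN sampler, which generates whole batches of batch_size rows and truncates the result to n. Here generate_batch and make_counter_generator are hypothetical stand-ins for the trained Generator plus the decoding back to table rows.

```python
import itertools


def sample_in_batches(n, batch_size, generate_batch):
    """Draw n rows by generating whole batches and trimming the surplus."""
    # One more batch than n // batch_size is always generated.
    steps = n // batch_size + 1
    rows = []
    for _ in range(steps):
        rows.extend(generate_batch(batch_size))
    return rows[:n]


def make_counter_generator():
    """Hypothetical stand-in for the trained Generator: emits 0, 1, 2, ..."""
    counter = itertools.count()
    return lambda size: [next(counter) for _ in range(size)]


rows = sample_in_batches(200, 500, make_counter_generator())
print(len(rows))  # 200
```

With the defaults (n=200, batch_size=500) a single batch suffices. Because steps is n // batch_size + 1, an extra batch is produced even when n is an exact multiple of batch_size; the surplus rows are discarded.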

save_model(output_dir=None)

Save the learned CTGAN model to disk.

Parameters

output_dir (str or None) – The directory in which to save the learned model. If None, the model is saved to self.checkout_dir.