trial_simulation.tabular.MedGAN
- class pytrial.tasks.trial_simulation.tabular.med_gan.MedGAN(embedding_dim=128, random_dim=128, generator_dims=(128, 128), discriminator_dims=(256, 128, 1), compress_dims=(), decompress_dims=(), bn_decay=0.99, l2scale=0.001, pretrain_epoch=200, batch_size=1000, epochs=2000, device='cpu', experiment_id='trial_simulation.tabular.medgan', verbose=False)
Bases: pytrial.tasks.trial_simulation.tabular.base.TabularSimulationBase
Implements the MedGAN model for patient-level tabular data generation [1].
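As a rough illustration of how a MedGAN-style model produces one record (following the design described in [1]), the pure-Python sketch below draws noise, passes it through generator layers with shortcut connections to get an embedding, and then uses the decoder of the pretrained autoencoder to map that embedding back to data space, where a sigmoid gives one probability per binary feature. This is not pytrial's implementation: the weights are random and untrained, the helper names (`random_layer`, `dense`, `sample_record`) are hypothetical, it assumes `random_dim` equals `embedding_dim` (as in the defaults), and thresholding at 0.5 is just one simple way to obtain a binary record.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]
Layer = Tuple[List[List[float]], Vector]  # (weight rows, bias)


def random_layer(n_in: int, n_out: int, rng: random.Random) -> Layer:
    """Hypothetical helper: small random weights stand in for trained ones."""
    weights = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def dense(x: Vector, layer: Layer) -> Vector:
    """Affine map W x + b on plain Python lists."""
    weights, bias = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def sample_record(random_dim: int, data_dim: int, n_gen_layers: int,
                  rng: random.Random) -> List[int]:
    # 1. Draw the noise vector z ~ N(0, I) of size random_dim.
    h = [rng.gauss(0.0, 1.0) for _ in range(random_dim)]
    # 2. Generator: shortcut layers h <- h + ReLU(W h + b). The width stays
    #    random_dim, which this sketch assumes equals embedding_dim.
    for _ in range(n_gen_layers):
        out = dense(h, random_layer(random_dim, random_dim, rng))
        h = [hi + max(0.0, oi) for hi, oi in zip(h, out)]
    # 3. Decoder (pretrained as part of the autoencoder in MedGAN) maps the
    #    embedding to data space; a sigmoid gives one probability per feature.
    logits = dense(h, random_layer(random_dim, data_dim, rng))
    probs = [1.0 / (1.0 + math.exp(-t)) for t in logits]
    # 4. Threshold at 0.5 to obtain a binary (multi-hot) record.
    return [1 if p >= 0.5 else 0 for p in probs]


rng = random.Random(0)
print(sample_record(random_dim=8, data_dim=5, n_gen_layers=2, rng=rng))
```

The key design point this mirrors is that the generator never outputs records directly; it works in the autoencoder's embedding space and relies on the decoder to produce discrete-valued features.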
- Parameters
embedding_dim (int, default 128) – Dimension of the embedding layer.
random_dim (int, default 128) – Dimension of the random noise vector.
generator_dims (tuple, default (128, 128)) – Sizes of the generator layers.
discriminator_dims (tuple, default (256, 128, 1)) – Sizes of the discriminator layers; the final entry (1) is the single output score.
compress_dims (tuple, default ()) – Sizes of the compression (encoder) layers that map the data dimension to embedding_dim (datadim -> embedding_dim).
decompress_dims (tuple, default ()) – Sizes of the decompression (decoder) layers that map embedding_dim back to the data dimension (embedding_dim -> datadim).
bn_decay (float, default 0.99) – Decay rate of batch normalization.
l2scale (float, default 0.001) – L2 regularization scale.
pretrain_epoch (int, default 200) – Number of epochs used to pretrain the autoencoder before adversarial training.
batch_size (int, default 1000) – Batch size for training.
epochs (int, default 2000) – Number of epochs for training.
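device (str, default 'cpu') – Device used for training.
experiment_id (str, default 'trial_simulation.tabular.medgan') – Experiment id for logging.
verbose (bool, default False) – Whether to print training information.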
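To make the compress_dims / decompress_dims arrows concrete, the sketch below lists the (in, out) sizes of the linear layers in each direction of the autoencoder. It is a hypothetical helper, not part of pytrial: `data_dim` stands for the number of input columns after encoding (the "datadim" above), activation functions are omitted, and the exact layer stacking inside pytrial is an assumption.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, int]  # (in_features, out_features) of one linear layer


def autoencoder_layer_shapes(data_dim: int, embedding_dim: int,
                             compress_dims: Sequence[int] = (),
                             decompress_dims: Sequence[int] = ()
                             ) -> Tuple[List[Shape], List[Shape]]:
    """Chain layer widths: data_dim -> compress_dims -> embedding_dim for the
    encoder, and embedding_dim -> decompress_dims -> data_dim for the decoder."""
    encoder: List[Shape] = []
    dim = data_dim
    for width in list(compress_dims) + [embedding_dim]:
        encoder.append((dim, width))
        dim = width
    decoder: List[Shape] = []
    dim = embedding_dim
    for width in list(decompress_dims) + [data_dim]:
        decoder.append((dim, width))
        dim = width
    return encoder, decoder


# Defaults (empty tuples): a single linear layer in each direction.
print(autoencoder_layer_shapes(1000, 128))
# → ([(1000, 128)], [(128, 1000)])
print(autoencoder_layer_shapes(1000, 128, compress_dims=(512,), decompress_dims=(512,)))
# → ([(1000, 512), (512, 128)], [(128, 512), (512, 1000)])
```

With the default empty tuples, the encoder and decoder each reduce to one linear map between the data dimension and embedding_dim; each extra entry adds one hidden layer on that side.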
Notes
- [1] Choi, E., Biswal, S., Malin, B., Duke, J., Stewart, W. F., & Sun, J. (2017, November). Generating multi-label discrete patient records using generative adversarial networks. In Machine Learning for Healthcare Conference (pp. 286-305). PMLR.
- fit(train_data)
Train the MedGAN model on the given data so that it can generate synthetic tabular patient data.
- Parameters
train_data (TabularPatientBase) – Training data.
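Going by the pretrain_epoch and epochs parameters, training presumably runs in two phases: autoencoder pretraining, then adversarial training. The sketch below estimates how many minibatch iterations each phase takes for a given number of patients. It assumes one epoch is one pass over the data in minibatches of batch_size; the `training_steps` helper and its `drop_last` flag are hypothetical and only model the two common ways of handling a final incomplete batch.

```python
import math


def training_steps(n_samples: int, batch_size: int = 1000,
                   pretrain_epoch: int = 200, epochs: int = 2000,
                   drop_last: bool = False) -> dict:
    """Count minibatch iterations per phase, assuming one epoch is one pass
    over the data (hypothetical helper, not part of pytrial)."""
    if drop_last:
        steps_per_epoch = n_samples // batch_size  # incomplete batch skipped
    else:
        steps_per_epoch = math.ceil(n_samples / batch_size)  # kept
    return {
        "steps_per_epoch": steps_per_epoch,
        "pretrain_steps": pretrain_epoch * steps_per_epoch,
        "gan_steps": epochs * steps_per_epoch,
    }


print(training_steps(2500))
# → {'steps_per_epoch': 3, 'pretrain_steps': 600, 'gan_steps': 6000}
print(training_steps(800, drop_last=True))
# → {'steps_per_epoch': 0, 'pretrain_steps': 0, 'gan_steps': 0}
```

The second call shows why batch_size matters for small cohorts: if an implementation drops the last incomplete batch (worth checking rather than assuming), a dataset with fewer rows than batch_size would receive no updates at all, so lowering batch_size is a sensible first step.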
- load_model(checkpoint)
Load a model from a checkpoint.
- Parameters
checkpoint (str or None) – Path to a checkpoint. If a directory is given, the latest checkpoint in that directory is loaded; if a file path is given, that checkpoint file is loaded; if None, the checkpoint is loaded from the default directory self.checkpoint_dir.
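The three cases for the checkpoint argument can be sketched as a small path-resolution function. This is a hypothetical helper, not pytrial code: it assumes "latest" means the most recently modified file and treats every regular file in the directory as a checkpoint, whereas the real implementation may filter by file name or extension.

```python
import os
from typing import Optional


def resolve_checkpoint(checkpoint: Optional[str], default_dir: str) -> str:
    """Return the checkpoint file to load (hypothetical sketch)."""
    # None: fall back to the default checkpoint directory.
    if checkpoint is None:
        checkpoint = default_dir
    # Directory: pick the most recently modified regular file (assumption).
    if os.path.isdir(checkpoint):
        files = [os.path.join(checkpoint, name) for name in os.listdir(checkpoint)]
        files = [path for path in files if os.path.isfile(path)]
        if not files:
            raise FileNotFoundError(f"no checkpoint found in {checkpoint}")
        return max(files, key=os.path.getmtime)
    # File path: load exactly that file.
    if os.path.isfile(checkpoint):
        return checkpoint
    raise FileNotFoundError(f"checkpoint not found: {checkpoint}")
```

For example, given a directory holding two checkpoint files, passing the directory (or passing None when that directory is the default) returns the newer file, while passing either file path returns it unchanged.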