trial_simulation.tabular.MedGAN

class pytrial.tasks.trial_simulation.tabular.med_gan.MedGAN(embedding_dim=128, random_dim=128, generator_dims=(128, 128), discriminator_dims=(256, 128, 1), compress_dims=(), decompress_dims=(), bn_decay=0.99, l2scale=0.001, pretrain_epoch=200, batch_size=1000, epochs=2000, device='cpu', experiment_id='trial_simulation.tabular.medgan', verbose=False)[source]

Bases: pytrial.tasks.trial_simulation.tabular.base.TabularSimulationBase

Implements the MedGAN model for patient-level tabular data generation [1].

Parameters
  • embedding_dim (int, default 128) – Dimension of embedding layer.

  • random_dim (int, default 128) – Dimension of random noise.

  • generator_dims (tuple, default (128, 128)) – Dimensions of the generator layers.

  • discriminator_dims (tuple, default (256, 128, 1)) – Dimensions of the discriminator layers.

  • compress_dims (tuple, default ()) – Dimensions of the layers that compress the input to the embedding (data dimension -> embedding_dim).

  • decompress_dims (tuple, default ()) – Dimensions of the layers that decompress the embedding back to the input (embedding_dim -> data dimension).

  • bn_decay (float, default 0.99) – Decay rate of batch normalization.

  • l2scale (float, default 0.001) – L2 regularization scale.

  • pretrain_epoch (int, default 200) – Number of pretrain epochs.

  • batch_size (int, default 1000) – Batch size for training.

  • epochs (int, default 2000) – Number of epochs for training.

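  • device (str, default 'cpu') – Device on which the model runs.

  • experiment_id (str, default 'trial_simulation.tabular.medgan') – Experiment id for logging.

  • verbose (bool, default False) – Whether to print training information.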
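As a rough illustration of the constructor, the sketch below builds a keyword-argument set for a short smoke-test run. The keyword names come from the signature above; the values are illustrative assumptions, not recommended settings. The import is guarded so the snippet also runs where PyTrial is not installed.

```python
# Keyword names are taken from the MedGAN signature above; the values are
# illustrative choices for a quick smoke test, not recommended settings.
medgan_kwargs = dict(
    embedding_dim=128,
    random_dim=128,
    generator_dims=(128, 128),
    discriminator_dims=(256, 128, 1),
    pretrain_epoch=10,   # documented default: 200
    batch_size=256,      # documented default: 1000
    epochs=50,           # signature default: 2000
    device="cpu",
    verbose=True,
)

try:
    from pytrial.tasks.trial_simulation.tabular.med_gan import MedGAN
except ImportError:
    MedGAN = None  # PyTrial is not installed; only the config is built.

# Construct the model only when the library is available.
model = MedGAN(**medgan_kwargs) if MedGAN is not None else None
```

Arguments left out of the dictionary (compress_dims, decompress_dims, bn_decay, l2scale, experiment_id) keep the defaults listed in the signature.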

Notes

[1] Choi, E., Biswal, S., Malin, B., Duke, J., Stewart, W. F., & Sun, J. (2017, November). Generating multi-label discrete patient records using generative adversarial networks. In Machine Learning for Healthcare Conference (pp. 286-305). PMLR.

fit(train_data)[source]

Train the MedGAN model to generate synthetic tabular patient data.

Parameters

train_data (TabularPatientBase) – Training data.

load_model(checkpoint)[source]

Load model from checkpoint.

Parameters

checkpoint (str) – Path to the checkpoint. If a directory is given, the latest checkpoint in that directory is loaded. If a file path is given, that checkpoint is loaded. If None, the default directory self.checkpoint_dir is used.
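The three cases described for checkpoint (directory, file path, None) can be pictured with the sketch below. It is not PyTrial's implementation: resolve_checkpoint is a hypothetical helper, the ".pth" suffix is an assumption, and "latest" is assumed to mean the most recently modified file.

```python
import os
import tempfile


def resolve_checkpoint(checkpoint, default_dir, suffix=".pth"):
    """Hypothetical helper: pick the checkpoint file that would be loaded."""
    # None falls back to the default directory (self.checkpoint_dir in the docs).
    path = default_dir if checkpoint is None else checkpoint
    if os.path.isdir(path):
        candidates = [
            os.path.join(path, name)
            for name in os.listdir(path)
            if name.endswith(suffix)
        ]
        if not candidates:
            raise FileNotFoundError(f"no '{suffix}' checkpoint found in {path}")
        # Assumption: "latest" means the most recent modification time.
        return max(candidates, key=os.path.getmtime)
    # A file path is returned unchanged.
    return path


# Small demonstration with two empty files and explicit modification times.
with tempfile.TemporaryDirectory() as tmp:
    for age, name in enumerate(["epoch_1.pth", "epoch_2.pth"]):
        file_path = os.path.join(tmp, name)
        open(file_path, "w").close()
        os.utime(file_path, (1_000 + age, 1_000 + age))
    latest_name = os.path.basename(resolve_checkpoint(tmp, default_dir=tmp))
    default_name = os.path.basename(resolve_checkpoint(None, default_dir=tmp))
    explicit_name = os.path.basename(
        resolve_checkpoint(os.path.join(tmp, "epoch_1.pth"), default_dir=tmp)
    )
```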

predict(n)[source]

Generate synthetic tabular patient data.

Parameters

n (int) – Number of samples to generate.

Returns

data – Generated synthetic data.

Return type

np.ndarray

save_model(output_dir)[source]

Save model to checkpoint.

Parameters

output_dir (str) – Output directory. If None, the model is saved to the default directory self.checkpoint_dir.
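Taken together, the methods on this page suggest a fit, predict, save workflow. The function below is a minimal sketch of it, assuming PyTrial is installed and that train_data is an already prepared TabularPatientBase; train_and_sample is a hypothetical wrapper name, not part of the library.

```python
def train_and_sample(train_data, n_samples, output_dir=None, **medgan_kwargs):
    """Fit MedGAN, draw n_samples synthetic rows, and save a checkpoint."""
    # Imported lazily so this module still loads where PyTrial is absent.
    from pytrial.tasks.trial_simulation.tabular.med_gan import MedGAN

    model = MedGAN(**medgan_kwargs)
    model.fit(train_data)                 # train_data: TabularPatientBase
    synthetic = model.predict(n_samples)  # documented return type: np.ndarray
    model.save_model(output_dir)          # None -> self.checkpoint_dir
    return model, synthetic
```

A later session could create a new MedGAN instance and call load_model with the same directory to restore the saved weights before calling predict again.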