class pytrial.tasks.trial_simulation.sequence.rnn_gan.RNNGAN(vocab_size, order, max_visit=20, emb_size=64, n_rnn_layer=2, rnn_type='lstm', bidirectional=False, padding_idx=None, learning_rate=0.0001, weight_decay=0.0001, batch_size=64, epochs=10, num_worker=0, device='cuda:0', experiment_id='trial_simulation.sequence.rnn_gan')

Bases: pytrial.tasks.trial_simulation.sequence.base.SequenceSimulationBase

Implements an RNN-based GAN model for simulating longitudinal patient records. The GAN component was proposed by Choi et al. [1].
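Here is a minimal NumPy sketch of an RNN generator that emits multi-label visits, to make the idea concrete. It assumes a vanilla tanh RNN conditioned on the same noise vector at every step. The weight names and sizes are hypothetical. RNNGAN's actual generator (LSTM/GRU layers and embeddings) and its discriminator are not reproduced here.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-x))


def generate_visits(noise, W_in, W_hh, W_out, max_visit):
    """Unroll a vanilla RNN for `max_visit` steps, emitting one
    multi-label visit (a probability per event code) at each step."""
    batch = noise.shape[0]
    hidden = np.zeros((batch, W_hh.shape[0]))
    visits = []
    for _ in range(max_visit):
        # The same noise vector conditions every step; the hidden state carries history.
        hidden = np.tanh(noise @ W_in + hidden @ W_hh)
        visits.append(sigmoid(hidden @ W_out))  # independent Bernoulli probability per code
    return np.stack(visits, axis=1)  # shape: (batch, max_visit, n_event)


rng = np.random.default_rng(0)
noise_dim, hidden_dim, n_event, max_visit = 8, 16, 12, 5
W_in = rng.normal(scale=0.1, size=(noise_dim, hidden_dim))
W_hh = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
W_out = rng.normal(scale=0.1, size=(hidden_dim, n_event))

probs = generate_visits(rng.normal(size=(3, noise_dim)), W_in, W_hh, W_out, max_visit)
records = (probs > 0.5).astype(int)  # binarize into multi-hot visits
print(records.shape)  # (3, 5, 12)
```

In a full GAN, these probabilities (or their binarized form) would be passed to a discriminator that learns to separate them from real visit sequences.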

  • vocab_size (list[int]) – A list of vocabulary sizes, one per event type, e.g., for diagnosis, procedure, and medication.

  • order (list[str]) – The order of event types within each visit, e.g., ['diag', 'prod', 'med']. Then visit = [diag_events, prod_events, med_events], where each element is a list of codes.

  • max_visit (int) – The maximum number of visits for input event codes.

  • emb_size (int) – Embedding size for encoding input event codes.

  • n_rnn_layer (int) – Number of RNN layers for encoding historical events.

  • rnn_type (str) – The RNN cell type: one of 'rnn', 'lstm', or 'gru'.

  • bidirectional (bool) – If True, historical events are encoded bidirectionally.

  • padding_idx (int, default=None) – The padding index for the input event embedding. If None, no padding index is specified.

  • learning_rate (float) – Learning rate for gradient-based optimization. torch.optim.Adam is used by default.

  • weight_decay (float) – Regularization strength of the L2 penalty; must be a positive float. Smaller values specify weaker regularization.

  • batch_size (int) – Mini-batch size used during SGD-style optimization.

  • epochs (int) – Maximum number of training epochs (full passes over the training data).

  • num_worker (int) – Number of worker processes used for data loading during training.

  • device (str) – The device on which the model runs, e.g., 'cuda:0' (the default) or 'cpu'.
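The constructor arguments `order` and `vocab_size` describe how each visit is laid out. The sketch below assumes one plausible tensor layout. Each visit becomes a multi-hot vector of length sum(vocab_size), with the blocks for the event types concatenated in `order`. Sequences are then truncated or zero-padded to `max_visit`. The helper `encode_patient` is hypothetical and is not part of the PyTrial API.

```python
import numpy as np


def encode_patient(visits, order, vocab_size, max_visit):
    """Encode one patient's visits as a (max_visit, sum(vocab_size)) multi-hot array.

    visits: list of visits; each visit holds one code list per event type,
            in the same order as `order` (e.g., [diag_codes, prod_codes, med_codes]).
    """
    if len(order) != len(vocab_size):
        raise ValueError("order and vocab_size must have the same length")
    # Start offset of each event type's block inside the concatenated code vector.
    offsets = np.concatenate([[0], np.cumsum(vocab_size)[:-1]])
    out = np.zeros((max_visit, int(sum(vocab_size))), dtype=np.int64)
    for t, visit in enumerate(visits[:max_visit]):  # truncate to max_visit; the rest stays zero
        for k, codes in enumerate(visit):
            for code in codes:
                if not 0 <= code < vocab_size[k]:
                    raise ValueError(f"code {code} out of range for '{order[k]}'")
                out[t, offsets[k] + code] = 1
    return out


order = ["diag", "prod", "med"]
vocab_size = [5, 3, 4]
patient = [
    [[0, 2], [1], [3]],  # visit 1: two diagnoses, one procedure, one medication
    [[4], [], [0, 1]],   # visit 2: no procedures
]
x = encode_patient(patient, order, vocab_size, max_visit=4)
print(x.shape)  # (4, 12)
```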



[1] Choi, E., et al. (2017, November). Generating multi-label discrete patient records using generative adversarial networks. In ML4HC (pp. 286-305). PMLR.


Train the model on sequential patient records.


train_data (SequencePatientBase) – A SequencePatientBase object containing patient records, where 'v' holds each patient's visit sequence of events of different types.
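Training a GAN alternates discriminator and generator updates. As an illustration only, the sketch below computes the standard GAN losses from discriminator output probabilities, using the common non-saturating form for the generator. Whether RNNGAN optimizes exactly this objective is an assumption. `gan_losses` is a hypothetical helper.

```python
import numpy as np


def gan_losses(d_real, d_fake, eps=1e-12):
    """Standard GAN losses from discriminator probabilities.

    d_real: D(x) on real records; d_fake: D(G(z)) on synthetic records.
    Returns (discriminator_loss, generator_loss), each averaged over the batch.
    """
    d_real = np.clip(d_real, eps, 1 - eps)  # avoid log(0)
    d_fake = np.clip(d_fake, eps, 1 - eps)
    # Discriminator: push D(x) toward 1 and D(G(z)) toward 0.
    d_loss = -np.mean(np.log(d_real)) - np.mean(np.log(1 - d_fake))
    # Generator (non-saturating form): push D(G(z)) toward 1.
    g_loss = -np.mean(np.log(d_fake))
    return d_loss, g_loss


d_loss, g_loss = gan_losses(np.array([0.5, 0.5]), np.array([0.5, 0.5]))
print(f"{d_loss:.4f} {g_loss:.4f}")  # 1.3863 0.6931
```

When the discriminator cannot tell real from synthetic records (every output is 0.5), the losses equal 2 ln 2 and ln 2, respectively.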


Load the model and the pre-encoded trial embeddings from the given checkpoint.


checkpoint (str) – Path to the pretrained model checkpoint.

  • If a directory, the single checkpoint file (*.pth.tar) inside it is loaded.

  • If a file path, the model is loaded from that file.
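The checkpoint lookup rules above can be sketched with the standard library. `resolve_checkpoint` is a hypothetical helper. Treating zero or several `*.pth.tar` files in a directory as an error is an assumption, based on the phrase "the only checkpoint file".

```python
import glob
import os
import shutil
import tempfile


def resolve_checkpoint(checkpoint: str) -> str:
    """Return the checkpoint file to load: a directory must contain exactly one
    *.pth.tar file; a file path is returned unchanged."""
    if os.path.isdir(checkpoint):
        matches = glob.glob(os.path.join(checkpoint, "*.pth.tar"))
        if len(matches) != 1:
            raise FileNotFoundError(
                f"expected exactly one *.pth.tar in {checkpoint}, found {len(matches)}"
            )
        return matches[0]
    if os.path.isfile(checkpoint):
        return checkpoint
    raise FileNotFoundError(checkpoint)


# Usage: a directory holding one checkpoint resolves to that file.
tmp_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(tmp_dir, "model.pth.tar")
open(ckpt_path, "wb").close()  # empty placeholder file
from_dir = resolve_checkpoint(tmp_dir)
from_file = resolve_checkpoint(ckpt_path)
shutil.rmtree(tmp_dir)  # clean up the temporary directory
```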

predict(test_data, n=None, n_per_sample=None, return_tensor=True)

Generate synthetic records based on the input real patient sequence data.

  • test_data (SequencePatientBase) – A SequencePatientBase object containing patient records, where 'v' holds each patient's visit sequence of events of different types.

  • n (int) – Total number of samples to generate.

  • n_per_sample (int) – Number of samples to generate from each individual.

  • return_tensor (bool) – If True, return the generated records as a tensor of shape (n, n_visit, n_event), which is convenient for downstream predictive modeling. If False, return the records in SequencePatient format.
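With return_tensor=True, the output is an (n, n_visit, n_event) tensor. The hypothetical `decode_record` below turns one such record back into per-visit code lists. It assumes the event axis concatenates the per-type vocabularies in `order` and holds probabilities or 0/1 indicators; neither assumption is stated in this reference. Visits with no codes are treated as padding and dropped. If the output is a torch tensor, convert it first with `.detach().cpu().numpy()`.

```python
import numpy as np


def decode_record(record, order, vocab_size, threshold=0.5):
    """Split one (n_visit, n_event) array back into per-visit code lists.

    Returns a list of visits; each visit holds one code list per event type,
    in the same order as `order`.
    """
    bounds = np.cumsum([0] + list(vocab_size))  # block boundaries per event type
    visits = []
    for row in np.asarray(record):
        visit = []
        for k in range(len(order)):
            block = row[bounds[k]:bounds[k + 1]]
            visit.append(np.flatnonzero(block > threshold).tolist())
        if any(visit):  # drop all-empty (padding) visits
            visits.append(visit)
    return visits


order = ["diag", "prod", "med"]
vocab_size = [5, 3, 4]
synthetic = np.zeros((2, 3, 12))   # (n, n_visit, n_event)
synthetic[0, 0, [1, 5, 10]] = 0.9  # diag 1, prod 0, med 2
synthetic[0, 1, [3]] = 0.7         # diag 3 only
records = [decode_record(r, order, vocab_size) for r in synthetic]
print(records[0])  # [[[1], [0], [2]], [[3], [], []]]
print(records[1])  # []
```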


Save the learned simulation model to disk.


output_dir (str or None) – The directory in which to save the learned model. If None, the model is saved to self.checkout_dir.
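Below is a minimal sketch of the `output_dir` fallback. It assumes a single file named `checkpoint.pth.tar`, a hypothetical name chosen to match the `*.pth.tar` pattern that loading looks for. `pickle` stands in for `torch.save` so the example runs without PyTorch. `save_state` is hypothetical and is not the library's implementation.

```python
import os
import pickle
import shutil
import tempfile


def save_state(state, output_dir, checkout_dir, filename="checkpoint.pth.tar"):
    """Write `state` into output_dir, or into checkout_dir when output_dir is None.
    Returns the path of the written file."""
    target_dir = checkout_dir if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)
    path = os.path.join(target_dir, filename)
    with open(path, "wb") as f:
        pickle.dump(state, f)  # stand-in for torch.save(model.state_dict(), path)
    return path


tmp = tempfile.mkdtemp()
default_path = save_state({"epoch": 10}, None, checkout_dir=tmp)  # falls back to checkout_dir
custom_path = save_state({"epoch": 10}, os.path.join(tmp, "custom"), checkout_dir=tmp)
with open(default_path, "rb") as f:
    restored = pickle.load(f)
shutil.rmtree(tmp)  # clean up the temporary directory
```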