trial_simulation.sequence.SynTEG

class pytrial.tasks.trial_simulation.sequence.synteg.SynTEG(vocab_size, order, max_visit=20, max_code_per_visit=20, emb_size=64, hidden_dim=32, condition_dim=32, n_head=4, n_rnn_layer=2, z_dim=64, g_dims=[32, 32, 64, 64], d_dims=[64, 32, 32], learning_rate=0.001, batch_size=64, epochs=20, num_worker=0, device='cpu', experiment_id='trial_simulation.sequence.synteg')

Bases: pytrial.tasks.trial_simulation.sequence.base.SequenceSimulationBase

Implements a GAN-based model for simulating longitudinal patient records [1].
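To make the roles of z_dim, condition_dim and g_dims more concrete, the toy NumPy sketch below shows a conditional generator: a noise vector and a condition vector (standing in for the encoded medical history) are concatenated and mapped through an MLP to one probability per medical code. This is not SynTEG's actual network. The ReLU hidden layers, the sigmoid output, the random weights, and the choice of sum(vocab_size) output units are assumptions for illustration, and `make_generator`/`generate` are hypothetical names.

```python
import numpy as np

def make_generator(z_dim, condition_dim, g_dims, out_dim, seed=0):
    """Create random (W, b) pairs for a toy conditional MLP generator (hypothetical)."""
    rng = np.random.default_rng(seed)
    sizes = [z_dim + condition_dim] + list(g_dims) + [out_dim]
    return [(rng.normal(0.0, 0.1, (i, o)), np.zeros(o)) for i, o in zip(sizes[:-1], sizes[1:])]

def generate(params, z, condition):
    """Map noise z and a history condition to one probability per code."""
    h = np.concatenate([z, condition], axis=-1)  # condition the generator on history
    for W, b in params[:-1]:
        h = np.maximum(h @ W + b, 0.0)  # assumed ReLU hidden layers sized by g_dims
    W, b = params[-1]
    return 1.0 / (1.0 + np.exp(-(h @ W + b)))  # assumed sigmoid output per code

vocab_size = [10, 5, 8]  # hypothetical diag/prod/med vocabulary sizes
params = make_generator(z_dim=64, condition_dim=32, g_dims=[32, 32, 64, 64], out_dim=sum(vocab_size))
z = np.random.default_rng(1).normal(size=(4, 64))  # noise for 4 samples
cond = np.zeros((4, 32))  # placeholder history encodings
probs = generate(params, z, cond)
print(probs.shape)  # (4, 23)
```

In a GAN, the discriminator (sized by d_dims) would be trained to tell such generated visits from real ones; that part is omitted here.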

Parameters
  • vocab_size (list[int]) – A list of vocabulary sizes, one per event type, e.g., diagnosis, procedure, and medication.

  • order (list[str]) – The order of event types within each visit, e.g., ['diag', 'prod', 'med']. A visit is then [diag_events, prod_events, med_events], where each entry is a list of codes.

  • max_visit (int) – Maximum number of visits.

  • max_code_per_visit (int) – Maximum number of medical codes in a single visit.

  • emb_size (int) – Embedding size for encoding input event codes.

  • hidden_dim (int) – Size of the intermediate hidden dimension for the RNN and feed-forward layers.

  • condition_dim (int) – Size of the intermediate dimension used to encode the medical history that conditions the GAN.

  • n_head (int) – Number of attention heads.

  • n_rnn_layer (int) – Number of RNN layers for encoding historical events.

  • z_dim (int) – Dimension of the noise vector passed into the GAN generator.

  • g_dims (list) – List of ints giving the intermediate layer sizes of the GAN generator.

  • d_dims (list) – List of ints giving the intermediate layer sizes of the GAN discriminator.

  • learning_rate (float) – Learning rate for gradient-based optimization; torch.optim.Adam is used by default.

  • batch_size (int) – Batch size for mini-batch training.

  • epochs (int) – Number of training epochs, i.e., full passes over the training data.

  • num_worker (int) – Number of worker processes used for data loading during training.

  • device (str) – Device on which to run the model, e.g., 'cpu' or 'cuda:0'.

  • experiment_id (str) – Identifier for this experiment run.
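The sketch below shows one plausible way vocab_size, order, max_visit and max_code_per_visit fit together: each visit holds one code list per event type in `order`, and each event type's codes are shifted into their own slice of a visit vector of length sum(vocab_size). The helper `encode_patient` is hypothetical, and both the offset scheme and the truncation policy (keep the first visits and the first codes) are assumptions, not necessarily what SynTEG does internally.

```python
def encode_patient(visits, vocab_size, max_visit=20, max_code_per_visit=20):
    """Encode one patient's visits as multi-hot rows (hypothetical helper).

    visits: list of visits; each visit is a list of per-event-type code lists
    in the same order as `order`, e.g. [diag_codes, prod_codes, med_codes].
    """
    offsets = [sum(vocab_size[:i]) for i in range(len(vocab_size))]
    n_event = sum(vocab_size)
    rows = []
    for visit in visits[:max_visit]:  # keep at most max_visit visits
        row = [0] * n_event
        # Shift each event type's codes into its own slice of the row.
        codes = [off + c for off, event_codes in zip(offsets, visit) for c in event_codes]
        for idx in codes[:max_code_per_visit]:  # keep at most max_code_per_visit codes
            row[idx] = 1
        rows.append(row)
    return rows

vocab_size = [4, 3, 5]  # toy sizes for order = ['diag', 'prod', 'med']
patient = [
    [[0, 2], [1], [4]],  # visit 1: diag, prod, med codes
    [[3], [], [0, 1]],   # visit 2
]
encoded = encode_patient(patient, vocab_size)
print([[i for i, bit in enumerate(row) if bit] for row in encoded])
# [[0, 2, 5, 11], [3, 7, 8]]
```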

Notes

[1] Zhang, Ziqi, et al. SynTEG: a framework for temporal structured electronic health data simulation. Journal of the American Medical Informatics Association 28.3 (March 2021).

fit(train_data)

Train the model on sequential patient records.

Parameters

train_data (SequencePatientBase) – A SequencePatientBase object containing the patient records, where ‘v’ holds each patient's visit sequence of different event types.
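Before calling fit, it can help to check that ‘v’ follows the visit layout described for the order and vocab_size parameters. A minimal sketch, assuming ‘v’ is a list of patients, each a list of visits, each visit a list of integer code lists aligned with `order`, with codes in the range [0, vocab_size[i]); `check_visits` is a hypothetical helper, not part of pytrial.

```python
def check_visits(v, order, vocab_size):
    """Raise ValueError if `v` does not match the assumed visit layout (hypothetical)."""
    for p, patient in enumerate(v):
        for t, visit in enumerate(patient):
            # Each visit must have exactly one code list per event type.
            if len(visit) != len(order):
                raise ValueError(
                    f"patient {p}, visit {t}: expected {len(order)} event lists, got {len(visit)}")
            # Each code must be a valid index into its event type's vocabulary.
            for name, size, codes in zip(order, vocab_size, visit):
                bad = [c for c in codes if not 0 <= c < size]
                if bad:
                    raise ValueError(f"patient {p}, visit {t}: {name} codes out of range: {bad}")

v = [[[[0, 2], [1], [4]], [[3], [], [0, 1]]]]  # one patient with two visits
check_visits(v, ['diag', 'prod', 'med'], [4, 3, 5])  # passes silently
```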

load_model(checkpoint)

Load the model from the given checkpoint directory or file.

Parameters

checkpoint (str) – Path to the pretrained model:
  • If a directory, the only checkpoint file *.pth.tar in it is loaded.
  • If a file path, the model is loaded from that file.
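The directory-or-file rule for checkpoint can be sketched with the standard library as follows. `resolve_checkpoint` and the file name 'model.pth.tar' are hypothetical; raising an error when a directory holds zero or several *.pth.tar files is an assumption based on the phrase "the only checkpoint file".

```python
import glob
import os
import tempfile

def resolve_checkpoint(checkpoint):
    """Return the checkpoint file to load (hypothetical helper)."""
    if os.path.isdir(checkpoint):
        matches = glob.glob(os.path.join(checkpoint, '*.pth.tar'))
        # Assumption: a directory must contain exactly one checkpoint file.
        if len(matches) != 1:
            raise FileNotFoundError(
                f"expected exactly one *.pth.tar in {checkpoint}, found {len(matches)}")
        return matches[0]
    if os.path.isfile(checkpoint):
        return checkpoint  # a file path is used as-is
    raise FileNotFoundError(checkpoint)

# Usage: a temporary directory holding a single checkpoint file.
with tempfile.TemporaryDirectory() as ckpt_dir:
    ckpt_file = os.path.join(ckpt_dir, 'model.pth.tar')
    open(ckpt_file, 'w').close()
    from_dir = resolve_checkpoint(ckpt_dir)    # directory -> its only *.pth.tar
    from_file = resolve_checkpoint(ckpt_file)  # file path -> used as-is
    print(from_dir == from_file == ckpt_file)  # True
```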

predict(n, return_tensor=True)

Generate synthetic records.

Parameters
  • n (int) – Total number of synthetic samples to generate.

  • return_tensor (bool) – If True, return the generated records as a tensor of shape (n, n_visit, n_event), convenient for downstream predictive modeling. If False, return the records in SequencePatient format.
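When return_tensor=True, a common next step is turning the tensor back into per-visit code lists. The sketch below assumes, without confirmation from the library, that n_event equals sum(vocab_size) with event types laid out in `order`, that entries are 0/1 values (threshold them first if they are probabilities), and that all-zero visit rows are padding. `decode_records` is a hypothetical helper.

```python
import numpy as np

def decode_records(x, vocab_size):
    """Turn an (n, n_visit, n_event) multi-hot array into nested code lists (hypothetical)."""
    bounds = np.cumsum([0] + list(vocab_size))  # slice boundaries per event type
    records = []
    for patient in x:
        visits = []
        for row in patient:
            if not row.any():  # assumed padding: skip all-zero visit rows
                continue
            # One list of local code indices per event type, in `order`.
            visits.append([np.flatnonzero(row[lo:hi]).tolist()
                           for lo, hi in zip(bounds[:-1], bounds[1:])])
        records.append(visits)
    return records

x = np.zeros((1, 3, 12), dtype=int)  # 1 patient, 3 visit slots, 4 + 3 + 5 codes
x[0, 0, [0, 2, 5, 11]] = 1
x[0, 1, [3, 7, 8]] = 1
print(decode_records(x, [4, 3, 5]))
# [[[[0, 2], [1], [4]], [[3], [], [0, 1]]]]
```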

save_model(output_dir)

Save the learned simulation model to disk.

Parameters

output_dir (str or None) – The directory in which to save the learned model. If None, the model is saved to self.checkout_dir.