trial_simulation.sequence.EVA
- class pytrial.tasks.trial_simulation.sequence.eva.EVA(vocab_size, order, max_visit=20, emb_size=64, latent_dim=32, n_rnn_layer=2, learning_rate=0.001, batch_size=64, epochs=20, num_worker=0, device='cpu', experiment_id='trial_simulation.sequence.eva')[source]
Bases:
pytrial.tasks.trial_simulation.sequence.base.SequenceSimulationBase
Implements a VAE-based model for simulating longitudinal patient records [1].
- Parameters
vocab_size (list[int]) – A list of vocabulary size for different types of events, e.g., for diagnosis, procedure, medication.
order (list[str]) – The order of event types in each visit, e.g., ['diag', 'prod', 'med']. Visit = [diag_events, prod_events, med_events], where each event is a list of codes.
max_visit (int) – Maximum number of visits.
emb_size (int) – Embedding size for encoding input event codes.
latent_dim (int) – Size of the final latent dimension between the encoder and decoder.
n_rnn_layer (int) – Number of RNN layers for encoding historical events.
learning_rate (float) – Learning rate for gradient-based optimization. torch.optim.Adam is used by default.
batch_size (int) – Batch size used during optimization.
epochs (int) – Maximum number of training epochs.
num_worker (int) – Number of workers used for data loading during training.
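The relationship between order, vocab_size, and the nested visit structure can be illustrated with a small sketch. It does not use PyTrial. The toy data, the helper name check_patient, and the assumption that codes are integer indices in the range 0 to vocab_size - 1 are all illustrative and are not taken from the library. How EVA itself handles records longer than max_visit is not stated here, so the check below simply flags them.

```python
# Hypothetical toy data; code values and sizes are illustrative only.
order = ["diag", "prod", "med"]
vocab_size = [5, 3, 4]  # one vocabulary size per event type, aligned with `order`

# One patient is a list of visits; one visit holds one list of codes per event type.
patient = [
    [[0, 2], [1], [3]],  # visit 1: diag codes, prod codes, med codes
    [[4], [], [0, 1]],   # visit 2: no procedure recorded
]


def check_patient(patient, order, vocab_size, max_visit=20):
    """Return True if every visit follows `order` and every code fits its vocabulary.

    Assumes codes are integer indices in [0, vocab_size[i]); this is an
    assumption made for the sketch, not documented library behavior.
    """
    if len(patient) > max_visit:
        return False
    for visit in patient:
        # Each visit needs exactly one code list per event type.
        if len(visit) != len(order):
            return False
        for codes, size in zip(visit, vocab_size):
            if any(code < 0 or code >= size for code in codes):
                return False
    return True


print(check_patient(patient, order, vocab_size))
```

A check of this kind, run before training, makes a mismatch between vocab_size and the data easier to spot.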
Notes
- [1] Biswal, S., et al. (2020, December). EVA: Generating Longitudinal Electronic Health Records Using Conditional Variational Autoencoders.
- fit(train_data)[source]
Train model with sequential patient records.
- Parameters
train_data (SequencePatientBase) – A SequencePatientBase object containing patient records, where ‘v’ corresponds to the visit sequence of different events.
- load_model(checkpoint)[source]
Load model and the pre-encoded trial embeddings from the given checkpoint dir.
- Parameters
checkpoint (str) –
The input path that stores the pretrained model.
If a directory, the only checkpoint file *.pth.tar in it will be loaded.
If a file path, the model will be loaded from this file.
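The directory-or-file rule described above can be sketched in plain Python. The helper resolve_checkpoint is hypothetical and is not part of PyTrial. Raising an error when a directory holds zero or several *.pth.tar files is an assumption of this sketch, since the behavior in that case is not documented here.

```python
from pathlib import Path


def resolve_checkpoint(checkpoint):
    """Return the checkpoint file that a path refers to.

    Hypothetical helper mirroring the documented rule: a directory is
    expected to contain exactly one *.pth.tar file, and a file path is
    used as given.
    """
    path = Path(checkpoint)
    if path.is_dir():
        candidates = sorted(path.glob("*.pth.tar"))
        if len(candidates) != 1:
            # Assumption: treat a missing or ambiguous checkpoint as an error.
            raise FileNotFoundError(
                f"expected exactly one *.pth.tar file in {path}, found {len(candidates)}"
            )
        return candidates[0]
    return path
```

Resolving the path up front makes it explicit which file will be loaded when a directory is passed.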
- predict(n, return_tensor=False)[source]
Generate synthetic records.
- Parameters
n (int) – How many samples in total will be generated.
return_tensor (bool) –
If True, return the generated records in tensor format with shape (n, n_visit, n_event), which is convenient for later predictive modeling.
If False, return the records in SequencePatient format.
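To make the (n, n_visit, n_event) layout concrete, the sketch below converts nested visit records into a multi-hot nested list of that shape. It rests on assumptions not confirmed by this page: that n_event equals the sum of all vocabulary sizes, that event types are concatenated in the given order, and that short patients are zero-padded to max_visit. The function to_multi_hot is hypothetical and does not reproduce the library's own conversion.

```python
def to_multi_hot(records, vocab_size, max_visit):
    """Encode nested records as a list with shape (n, max_visit, sum(vocab_size)).

    Assumes each code is an index into its event type's vocabulary and that
    event types are laid out side by side along the last axis.
    """
    # Starting column of each event type along the last axis.
    offsets = [sum(vocab_size[:i]) for i in range(len(vocab_size))]
    n_event = sum(vocab_size)
    out = []
    for patient in records:
        # Zero-padded rows, one per visit slot.
        rows = [[0] * n_event for _ in range(max_visit)]
        for v, visit in enumerate(patient[:max_visit]):
            for offset, codes in zip(offsets, visit):
                for code in codes:
                    rows[v][offset + code] = 1
        out.append(rows)
    return out


# Toy input: one patient with a single visit, using order ['diag', 'prod', 'med'].
records = [[[[0, 2], [1], [3]]]]
encoded = to_multi_hot(records, vocab_size=[5, 3, 4], max_visit=2)
print(len(encoded), len(encoded[0]), len(encoded[0][0]))
```

The resulting nested list could be wrapped with torch.tensor for downstream models, provided the layout assumptions above match the data.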