trial_simulation.tabular.TVAE
- class pytrial.tasks.trial_simulation.tabular.tvae.TVAE(embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), l2scale=1e-05, batch_size=500, epochs=50, loss_factor=2, cuda=False, experiment_id='trial_simulation.tabular.tvae')
Bases: pytrial.tasks.trial_simulation.tabular.base.TabularSimulationBase
Implements the TVAE (tabular variational autoencoder) model for tabular patient data simulation [1].
- Parameters
embedding_dim (int) – Size of the latent embedding, i.e. the random sample passed to the decoder. Defaults to 128.
compress_dims (tuple or list[int]) – Size of each hidden layer in the encoder. Defaults to (128, 128).
decompress_dims (tuple or list[int]) – Size of each hidden layer in the decoder. Defaults to (128, 128).
l2scale (float) – Weight of the L2 regularization term. Defaults to 1e-5.
batch_size (int) – Number of data samples to process in each step. Defaults to 500.
epochs (int) – Number of training epochs. Defaults to 50.
loss_factor (int) – Multiplier for the reconstruction error. Defaults to 2.
cuda (bool or str) – If True, use CUDA. If a str, use the indicated device. If False, do not use CUDA at all. Defaults to False.
experiment_id (str, optional) – The name of the current experiment; it determines the name of the saved model checkpoint. Defaults to 'trial_simulation.tabular.tvae'.
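The three accepted forms of the cuda argument can be illustrated with a small sketch. It assumes a CTGAN-style fallback to the CPU when no GPU is present; `resolve_device` and its `cuda_available` flag are hypothetical names, not part of PyTrial (in real PyTorch code the flag would typically come from `torch.cuda.is_available()`).

```python
def resolve_device(cuda, cuda_available=True):
    """Map a `cuda` argument to a device name (hypothetical helper).

    True  -> "cuda"
    str   -> the given device name, e.g. "cuda:1"
    False -> "cpu"
    Assumption: falls back to "cpu" whenever no GPU is available.
    """
    if not cuda or not cuda_available:
        return "cpu"
    if isinstance(cuda, str):
        return cuda  # caller named a specific device
    return "cuda"  # plain True: use the default CUDA device


print(resolve_device(True))      # cuda
print(resolve_device("cuda:1"))  # cuda:1
print(resolve_device(False))     # cpu
```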
Notes
- [1] Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019). Modeling tabular data using conditional GAN. Advances in Neural Information Processing Systems, 32.
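TVAE, introduced in [1], is a variational autoencoder: the encoder maps a row to a mean and log-variance of size embedding_dim, a latent vector is drawn with the reparameterization trick, and the decoder reconstructs the row. The plain-Python sketch below shows that standard VAE objective, assuming the reconstruction error for a row has already been computed and that loss_factor scales only the reconstruction term, as its parameter description suggests. All function names are hypothetical; this is not PyTrial's implementation.

```python
import math
import random


def reparameterize(mu, logvar, rng):
    # z = mu + sigma * eps with eps ~ N(0, 1), one entry per latent dimension
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_divergence(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def tvae_loss(recon_error, mu, logvar, loss_factor=2.0):
    # loss_factor multiplies the reconstruction error only, not the KL term
    return loss_factor * recon_error + kl_divergence(mu, logvar)


rng = random.Random(0)
embedding_dim = 4
mu = [0.1, -0.2, 0.0, 0.3]
logvar = [0.0, -0.5, 0.2, 0.0]
z = reparameterize(mu, logvar, rng)  # latent sample fed to the decoder
print(len(z) == embedding_dim)
print(tvae_loss(recon_error=1.5, mu=mu, logvar=logvar))
```

Raising loss_factor above 1 shifts the balance toward faithful reconstruction of each row at the expense of how closely the latent codes match the standard normal prior used for sampling.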
- fit(train_data)
Train the TVAE model on tabular patient data so that it can simulate synthetic patient records.
- Parameters
train_data (TabularPatientBase) – The training data for the TVAE model.
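How batch_size and epochs translate into work during fit can be sketched as follows. The sketch assumes rows are split into consecutive mini-batches, the final partial batch is kept, and there is one optimizer step per batch; the shuffling a real training loop would normally apply is omitted. `minibatches` and `count_training_steps` are hypothetical helpers, not PyTrial functions.

```python
import math


def minibatches(rows, batch_size):
    # Yield consecutive slices of at most batch_size rows; the last may be smaller
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


def count_training_steps(n_rows, batch_size=500, epochs=50):
    # One optimizer step per mini-batch, repeated for every epoch
    return epochs * math.ceil(n_rows / batch_size)


rows = list(range(1200))
sizes = [len(batch) for batch in minibatches(rows, 500)]
print(sizes)                       # [500, 500, 200]
print(count_training_steps(1200))  # 150
```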
- load_model(checkpoint=None)
Load the learned TVAE model from disk.
- Parameters
checkpoint (str or None) –
The path to the checkpoint. Defaults to None.
If a directory, the single checkpoint file with the .model extension in that directory is loaded.
If a file path, the model is loaded from that file.
If None, the model is loaded from self.checkout_dir.
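The three checkpoint cases can be sketched as a path-resolution step. This assumes that a None checkpoint falls back to the checkout directory and is then handled like any other directory, and that a directory must contain exactly one .model file; `resolve_checkpoint` is a hypothetical helper, not PyTrial's code.

```python
import os
import tempfile


def resolve_checkpoint(checkpoint, checkout_dir):
    """Return the checkpoint file to load (hypothetical helper)."""
    if checkpoint is None:
        checkpoint = checkout_dir  # fall back to the experiment's checkpoint directory
    if os.path.isdir(checkpoint):
        candidates = [f for f in os.listdir(checkpoint) if f.endswith(".model")]
        if not candidates:
            raise FileNotFoundError(f"no .model file found in {checkpoint}")
        if len(candidates) > 1:
            raise ValueError(f"expected one .model file in {checkpoint}, found {len(candidates)}")
        return os.path.join(checkpoint, candidates[0])
    return checkpoint  # already a file path


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "tvae.model")
    open(path, "wb").close()  # create an empty placeholder checkpoint
    print(resolve_checkpoint(None, d) == path)  # True
    print(resolve_checkpoint(path, d) == path)  # True
```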