trial_simulation.tabular.TVAE

class pytrial.tasks.trial_simulation.tabular.tvae.TVAE(embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), l2scale=1e-05, batch_size=500, epochs=50, loss_factor=2, cuda=False, experiment_id='trial_simulation.tabular.tvae')

Bases: pytrial.tasks.trial_simulation.tabular.base.TabularSimulationBase

Implements the TVAE model for tabular patient data simulation [1].

Parameters
  • embedding_dim (int) – Size of the latent vector that is sampled and passed to the decoder. Defaults to 128.

  • compress_dims (tuple or list[int]) – Size of each hidden layer in the encoder. Defaults to (128, 128).

  • decompress_dims (tuple or list[int]) – Size of each hidden layer in the decoder. Defaults to (128, 128).

  • l2scale (float) – L2 regularization term. Defaults to 1e-5.

  • batch_size (int) – Number of data samples to process in each step. Defaults to 500.

  • epochs (int) – Number of training epochs. Defaults to 50.

  • loss_factor (int) – Multiplier for the reconstruction error. Defaults to 2.

  • cuda (bool or str) – Device selection. Defaults to False.

    • If True, use CUDA.

    • If a str, use the indicated device.

    • If False, do not use CUDA at all.

  • experiment_id (str, optional) – The name of the current experiment, which determines the name of the saved model checkpoint. Defaults to 'trial_simulation.tabular.tvae'.
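As a rough illustration of how these parameters fit together, the sketch below collects the defaults from the class signature into a dictionary and overrides a few of them. The names TVAE_DEFAULTS, tvae_kwargs, and build_tvae are hypothetical helpers introduced only for this example, and the device string "cuda:0" is an assumed value for the str form of cuda. Only the import path and the parameter names come from this page.

```python
# Defaults copied from the TVAE class signature documented above.
TVAE_DEFAULTS = {
    "embedding_dim": 128,
    "compress_dims": (128, 128),
    "decompress_dims": (128, 128),
    "l2scale": 1e-5,
    "batch_size": 500,
    "epochs": 50,
    "loss_factor": 2,
    "cuda": False,
    "experiment_id": "trial_simulation.tabular.tvae",
}


def tvae_kwargs(**overrides):
    """Hypothetical helper: merge overrides into the documented defaults."""
    unknown = sorted(set(overrides) - set(TVAE_DEFAULTS))
    if unknown:
        raise TypeError(f"unexpected TVAE argument(s): {unknown}")
    return {**TVAE_DEFAULTS, **overrides}


def build_tvae(**overrides):
    """Hypothetical helper: construct the model (requires PyTrial)."""
    # Imported lazily so tvae_kwargs can be used without PyTrial installed.
    from pytrial.tasks.trial_simulation.tabular.tvae import TVAE
    return TVAE(**tvae_kwargs(**overrides))


# A shorter training run on an assumed GPU device; other values keep their defaults.
quick_run = tvae_kwargs(epochs=10, batch_size=100, cuda="cuda:0")
```

Rejecting unknown names early is a design choice of the sketch, so that a misspelled argument fails before any training starts.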

Notes

[1] Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019). Modeling tabular data using conditional GAN. Advances in Neural Information Processing Systems, 32.

fit(train_data)

Train the TVAE model on tabular patient data so that it can simulate new patient records.

Parameters

train_data (TabularPatientBase) – The training data for the TVAE model.

load_model(checkpoint=None)

Load the learned TVAE model from the disk.

Parameters

checkpoint (str or None) –

The path to the checkpoint. Defaults to None.

  • If a directory, the single .model checkpoint file in it is loaded.

  • If a file path, the model is loaded from that file.

  • If None, the model is loaded from self.checkout_dir.
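The three cases above can be sketched as a small path-resolution function. This is a minimal illustration, not PyTrial's implementation: resolve_checkpoint and its default_dir argument are hypothetical, with default_dir standing in for self.checkout_dir, and the sketch assumes a checkpoint is any file whose name ends in ".model".

```python
import os


def resolve_checkpoint(checkpoint=None, default_dir="./checkpoints"):
    """Hypothetical helper mirroring the three documented cases."""
    if checkpoint is None:
        # Stand-in for self.checkout_dir in the real class.
        checkpoint = default_dir
    if os.path.isdir(checkpoint):
        # Assumption: a checkpoint is any file whose name ends in ".model".
        found = sorted(f for f in os.listdir(checkpoint) if f.endswith(".model"))
        if len(found) != 1:
            raise FileNotFoundError(
                f"expected exactly one .model file in {checkpoint!r}, "
                f"found {len(found)}"
            )
        return os.path.join(checkpoint, found[0])
    # Anything else is treated as a path to the checkpoint file itself.
    return checkpoint
```

Raising an error when a directory holds zero or several .model files is a choice made for this sketch; the page does not say how the library handles that situation.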

predict(n=200)

Simulate a new tabular dataset containing n records.

Parameters

n (int) – The number of new records to simulate. Defaults to 200.

Returns

ypred – The new tabular data simulated by the model.

Return type

TabularPatientBase

save_model(output_dir=None)

Save the learned TVAE model to the disk.

Parameters

output_dir (str or None) – The directory in which to save the learned model. If None, the model is saved to self.checkout_dir.
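Taken together, the methods on this page suggest a fit, save, predict sequence. The sketch below shows that order against any object exposing the same method names. simulate_patients and StubSimulator are hypothetical names introduced for illustration, and the stub returns plain dictionaries instead of a TabularPatientBase. With PyTrial installed, a TVAE instance and a real TabularPatientBase would take the place of the stub and the toy training data.

```python
def simulate_patients(model, train_data, n=200, output_dir=None):
    """Fit the model, checkpoint it, then sample n synthetic records."""
    model.fit(train_data)
    # output_dir=None leaves the choice of directory to the model.
    model.save_model(output_dir)
    return model.predict(n=n)


class StubSimulator:
    """Hypothetical stand-in exposing the same method names as TVAE."""

    def __init__(self):
        self.calls = []

    def fit(self, train_data):
        self.calls.append("fit")

    def save_model(self, output_dir=None):
        self.calls.append("save_model")

    def load_model(self, checkpoint=None):
        self.calls.append("load_model")

    def predict(self, n=200):
        self.calls.append("predict")
        # Toy records; the real method returns a TabularPatientBase.
        return [{"patient_id": i} for i in range(n)]


stub = StubSimulator()
synthetic = simulate_patients(stub, train_data=[{"age": 61}], n=3)
```

Because the function relies only on method names, the same call should work unchanged with any simulator that follows this interface.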