trial_simulation.tabular.TabularSimulationBase

class pytrial.tasks.trial_simulation.tabular.base.TabularSimulationBase(experiment_id='trial_simulation.tabular')

Bases: abc.ABC

Abstract base class for all simulations built on tabular patient data.

Parameters

experiment_id (str, optional (default = 'trial_simulation.tabular')) – The name of the current experiment.
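Because the class derives from abc.ABC, a subclass must override all four abstract methods documented below before it can be instantiated. The sketch below illustrates this with a local stand-in, `_StandInBase`, that mirrors only the documented signatures so it runs without pytrial installed; it is not the pytrial source, and storing `experiment_id` as an attribute is an assumption.

```python
import abc


class _StandInBase(abc.ABC):
    """Hypothetical stand-in mirroring the documented abstract interface."""

    def __init__(self, experiment_id="trial_simulation.tabular"):
        self.experiment_id = experiment_id  # assumption: kept as an attribute

    @abc.abstractmethod
    def fit(self, train_data): ...

    @abc.abstractmethod
    def load_model(self, checkpoint): ...

    @abc.abstractmethod
    def predict(self, number_of_predictions): ...

    @abc.abstractmethod
    def save_model(self, output_dir): ...


class IncompleteSimulator(_StandInBase):
    # Only fit() is overridden, so this class is still abstract.
    def fit(self, train_data):
        pass


try:
    IncompleteSimulator()
    instantiated = True
except TypeError:
    instantiated = False  # Python refuses to instantiate an abstract class


class NoOpSimulator(IncompleteSimulator):
    # Overriding the remaining three methods makes the class concrete.
    def load_model(self, checkpoint):
        pass

    def predict(self, number_of_predictions):
        return []

    def save_model(self, output_dir):
        pass


sim = NoOpSimulator(experiment_id="demo")
print(instantiated)  # False
```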

abstract fit(train_data)

Fit the model to the given training data. Must be implemented by subclasses.

Parameters

train_data (Any) – The training data.
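As a minimal illustration of what fit might do, the hypothetical `IndependentColumnSampler` below treats `train_data` as a list of row dictionaries (an assumption; the documented type is Any) and records the observed values of each column. A real simulator would learn a joint model; this toy deliberately ignores correlations between columns.

```python
from collections import defaultdict


class IndependentColumnSampler:
    """Hypothetical toy simulator; in practice it would subclass TabularSimulationBase."""

    def __init__(self):
        self.column_values = {}

    def fit(self, train_data):
        # Assumption: train_data is a list of dicts, one dict per patient row.
        values = defaultdict(list)
        for row in train_data:
            for column, value in row.items():
                values[column].append(value)
        self.column_values = dict(values)


rows = [
    {"age": 54, "arm": "treatment"},
    {"age": 61, "arm": "placebo"},
]
model = IndependentColumnSampler()
model.fit(rows)
print(model.column_values)  # {'age': [54, 61], 'arm': ['treatment', 'placebo']}
```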

abstract load_model(checkpoint)

Load the model from the given checkpoint. Must be implemented by subclasses.

Parameters

checkpoint (str) – The path to the checkpoint (e.g., ./checkpoint/model.pth).
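The checkpoint argument can reasonably be given either as a directory or as a file path like the example above. One way a subclass could accept both is sketched below; the helper class, the default filename `model.pkl`, and the use of pickle are assumptions for illustration, not pytrial's checkpoint format.

```python
import os
import pickle
import tempfile


class PickleCheckpointMixin:
    """Hypothetical helper showing one way to implement load_model."""

    checkpoint_filename = "model.pkl"  # assumed default filename

    def load_model(self, checkpoint):
        # Accept a directory (append the default filename) or a direct file path.
        if os.path.isdir(checkpoint):
            checkpoint = os.path.join(checkpoint, self.checkpoint_filename)
        with open(checkpoint, "rb") as f:
            # Restore the pickled attribute dictionary onto this instance.
            self.__dict__.update(pickle.load(f))
        return self


with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "model.pkl"), "wb") as f:
        pickle.dump({"column_values": {"age": [54, 61]}}, f)
    by_dir = PickleCheckpointMixin().load_model(tmp)
    by_file = PickleCheckpointMixin().load_model(os.path.join(tmp, "model.pkl"))

print(by_dir.column_values)  # {'age': [54, 61]}
```

Note that unpickling data from an untrusted source can execute arbitrary code, so a pickle-based checkpoint should only be loaded from trusted locations.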

abstract predict(number_of_predictions)

Generate synthetic data after the model has been trained. Must be implemented by subclasses.

Parameters

number_of_predictions (Any) – The number of synthetic samples to generate.
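A toy predict that resamples each column independently from values seen during training could look like the sketch below. The hypothetical `FittedColumnSampler` stands in for a model whose fit has already run; its `column_values` attribute, the list-of-dicts output format, and the `seed` constructor argument are all assumptions for illustration.

```python
import random


class FittedColumnSampler:
    """Hypothetical toy simulator whose fitted state is passed in directly."""

    def __init__(self, column_values, seed=None):
        # column_values maps column name -> list of values observed in training.
        self.column_values = column_values
        self.rng = random.Random(seed)

    def predict(self, number_of_predictions):
        if not self.column_values:
            raise RuntimeError("call fit() before predict()")
        # Each column is drawn independently, so cross-column correlations are lost.
        return [
            {col: self.rng.choice(vals) for col, vals in self.column_values.items()}
            for _ in range(number_of_predictions)
        ]


sampler = FittedColumnSampler(
    {"age": [54, 61, 47], "arm": ["treatment", "placebo"]}, seed=0
)
synthetic = sampler.predict(5)
print(len(synthetic))  # 5
```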

abstract save_model(output_dir)

Save the model to the given directory. Must be implemented by subclasses.

Parameters

output_dir (str) – The directory in which to save the model.
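A save_model sketch under similar assumptions (pickle, a default filename `model.pkl`) is shown below: the hypothetical `PickleSaver` creates the directory if needed and writes its instance attributes. Returning the written path is a convenience choice for this sketch; the documented method does not specify a return value.

```python
import os
import pickle
import tempfile


class PickleSaver:
    """Hypothetical helper showing one way to implement save_model."""

    checkpoint_filename = "model.pkl"  # assumed default filename (class attribute)

    def __init__(self, column_values):
        self.column_values = column_values

    def save_model(self, output_dir):
        # Create the directory if it does not exist, then pickle instance attributes.
        os.makedirs(output_dir, exist_ok=True)
        path = os.path.join(output_dir, self.checkpoint_filename)
        with open(path, "wb") as f:
            pickle.dump(self.__dict__, f)
        return path


with tempfile.TemporaryDirectory() as tmp:
    path = PickleSaver({"age": [54, 61]}).save_model(os.path.join(tmp, "ckpt"))
    with open(path, "rb") as f:
        saved = pickle.load(f)

print(saved)  # {'column_values': {'age': [54, 61]}}
```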