Source code for pytrial.tasks.trial_simulation.tabular.base

import abc
import os
import json
import pandas as pd

from pytrial.utils.check import check_model_dir
from pytrial.utils.check import make_dir_if_not_exist
from pytrial.data.patient_data import TabularPatientBase


class TabularSimulationBase(abc.ABC):
    '''Abstract class for all tabular simulations based on tabular patient data.

    Parameters
    ----------
    experiment_id: str, optional (default = 'trial_simulation.tabular')
        The name of current experiment. It determines the checkpoint and
        result directories, which are created under
        ``./experiments_records/<experiment_id>/``.
    '''

    @abc.abstractmethod
    def __init__(self, experiment_id='trial_simulation.tabular'):
        check_model_dir(experiment_id)
        # Checkpoints and results live side by side under the experiment folder.
        self.checkout_dir = os.path.join('./experiments_records', experiment_id, 'checkpoints')
        self.result_dir = os.path.join('./experiments_records', experiment_id, 'results')
        make_dir_if_not_exist(self.checkout_dir)
        make_dir_if_not_exist(self.result_dir)

    @abc.abstractmethod
    def fit(self, train_data):
        '''Fit the model to the given training data. Need to be implemented by subclass.

        Parameters
        ----------
        train_data: Any
            The training data.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self, number_of_predictions):
        '''Generate synthetic data after the model trained. Need to be implemented by subclass.

        Parameters
        ----------
        number_of_predictions: Any
            Number of synthetic data to be generated.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def save_model(self, output_dir):
        '''Save the model to the given directory. Need to be implemented by subclass.

        Parameters
        ----------
        output_dir: str
            The given directory to save the model.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def load_model(self, checkpoint):
        '''Load model from the given directory. Need to be implemented by subclass.

        Parameters
        ----------
        checkpoint: str
            The given filepath (e.g., ./checkpoint/model.pth)
        '''
        raise NotImplementedError

    def _input_data_check(self, inputs):
        '''Assert that `inputs` is a ``TabularPatientBase`` or a ``dict``.

        Parameters
        ----------
        inputs: Any
            The object to validate.

        Raises
        ------
        AssertionError
            If `inputs` is neither a ``TabularPatientBase`` nor a ``dict``.
        '''
        assert isinstance(inputs, (TabularPatientBase, dict)), 'Wrong input type.'

    def _save_config(self, config, output_dir=None):
        '''Save model config as ``config.json`` in the given directory.

        Parameters
        ----------
        config: dict
            Hyperparameters and model information. Must be JSON-serializable.

        output_dir: str, optional (default = None)
            The directory to write ``config.json`` into. If None, the
            checkpoint directory ``self.checkout_dir`` is used.
        '''
        if output_dir is None:
            output_dir = self.checkout_dir

        config_path = os.path.join(output_dir, 'config.json')

        # Drop any existing entry first so a fresh file is always created.
        if os.path.exists(config_path):
            os.remove(config_path)

        with open(config_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(config, indent=4))

    def _load_config(self, checkpoint=None):
        '''Load model config from a JSON file.

        Parameters
        ----------
        checkpoint: str, optional (default = None)
            The given filepath (e.g., ./checkpoint/config.json) to load the
            model config. If None, ``config.json`` under
            ``self.checkout_dir`` is loaded.

        Returns
        -------
        config: dict
            The parsed content of the config file.

        Raises
        ------
        AssertionError
            If the config file does not exist.
        '''
        if checkpoint is None:
            # Fall back to the default location written by `_save_config`.
            config_path = os.path.join(self.checkout_dir, 'config.json')
        else:
            config_path = checkpoint

        # Report the path that was actually searched, not always the default one.
        assert os.path.exists(config_path), 'Cannot find `config.json` at {}'.format(config_path)

        # Read with the same encoding `_save_config` writes with.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        return config