import abc
import os
import json
import pdb
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from pytrial.data.patient_data import TabularPatientBase
from pytrial.utils.check import check_model_dir
from pytrial.utils.check import make_dir_if_not_exist
from pytrial.utils.parallel import batch_to_device
from ..losses import BinaryXentLoss, XentLoss, MSELoss, MultilabelBinaryXentLoss
from ..trainer import IndivTabTrainer
[docs]class TabularIndivBase(abc.ABC):
'''Abstract class for all individual outcome predictions
based on tabular patient data.
Parameters
----------
experiment_id: str, optional (default = 'test')
The name of current experiment.
'''
@abc.abstractmethod
def __init__(self, experiment_id='test'):
check_model_dir(experiment_id)
self.checkout_dir = os.path.join('./experiments_records', experiment_id,
'checkpoints')
self.result_dir = os.path.join('./experiments_records', experiment_id,
'results')
make_dir_if_not_exist(self.checkout_dir)
make_dir_if_not_exist(self.result_dir)
[docs] @abc.abstractmethod
def fit(self, train_data, valid_data):
'''
Fit function needs to be implemented after subclass.
Parameters
----------
train_data: Any
Training data.
valid_data: Any
Validation data.
'''
raise NotImplementedError
[docs] @abc.abstractmethod
def predict(self, test_data):
'''
Prediction function needs to be implemented after subclass.
Parameters
----------
test_data: Any
Testing data.
'''
raise NotImplementedError
[docs] @abc.abstractmethod
def load_model(self, checkpoint):
'''
Load the pretrained model from disk, needs to be implemented after subclass.
Parameters
----------
checkpoint: str
The path to the checkpoint file.
'''
raise NotImplementedError
[docs] @abc.abstractmethod
def save_model(self, output_dir):
'''
Save the model to disk, needs to be implemented after subclass.
Parameters
----------
output_dir: str
The path to the output directory.
'''
raise NotImplementedError
[docs] def train(self, mode=True):
'''Set the model in training mode. Work similar to `model.train()` in PyTorch.
Parameters
----------
mode: bool, optional (default = True)
Whether to set the model in training mode.
``False`` means the model is in evaluation mode.
``True`` means the model is in training mode.
'''
self.training = mode
self.model.train()
return self
[docs] def eval(self, mode=False):
'''Set the model in evaluation mode. Work similar to `model.eval()` in PyTorch.
Parameters
----------
mode: bool, optional (default = False)
Whether to set the model in evaluation mode.
``False`` means the model is in evaluation mode.
``True`` means the model is in training mode.
'''
self.training = mode
self.model.eval()
return self
def get_train_dataloader(self, inputs):
dataset = self._build_dataset(inputs)
dataloader = DataLoader(dataset,
batch_size=self.config['batch_size'],
shuffle=True,
num_workers=self.config['num_worker'],
pin_memory=True,
)
return dataloader
def _build_dataset(self, inputs):
df = inputs['x']
if 'y' in inputs: label = inputs['y']
else: label = None
return IndivTabDataset(df, label)
def _build_loss_model(self):
mode = self.config['mode']
if mode == 'multiclass':
return XentLoss(self.model)
elif mode == 'binary':
return BinaryXentLoss(self.model)
elif mode == 'regression':
return MSELoss(self.model)
elif mode == 'multilabe':
return MultilabelBinaryXentLoss(self.model)
else:
raise ValueError(f'Do not recognize mode `{mode}`, please correct.')
def _fit_model(self, train_data, valid_data=None):
train_dataloader = self.get_train_dataloader(train_data)
loss_model = self._build_loss_model()
train_objectives = [(train_dataloader, loss_model)]
mode = self.config['mode']
test_metric_dict = {
'binary': 'auc',
'multiclass': 'acc',
'regression': 'mse',
'multilabel': 'f1', # take average of F1 scores
}
if mode == 'regression':
less_is_better = True
else:
less_is_better = False
trainer = IndivTabTrainer(
model=self,
train_objectives=train_objectives,
test_data=valid_data,
test_metric=test_metric_dict[mode],
less_is_better=less_is_better,
)
trainer.train(
**self.config
)
def _parse_input_data(self, inputs):
if isinstance(inputs, dict):
if isinstance(inputs['x'], TabularPatientBase):
dataset = inputs['x']
x_feat = dataset.df
y = inputs['y']
else:
x_feat = inputs['x']
y = inputs['y']
if isinstance(y, pd.DataFrame) or isinstance(y, pd.Series):
y = y.values
if isinstance(inputs, pd.DataFrame) or isinstance(inputs, torch.Tensor):
x_feat, y = inputs, None
if isinstance(inputs, TabularPatientBase):
x_feat, y = inputs.df, None
return x_feat, y
def _predict_on_dataloader(self, dataloader):
pred_list, label_list = [], []
for batch in dataloader:
x_feat = batch['x']
x_feat = x_feat.to(self.device)
if 'y' in batch: label_list.append(batch.pop('y'))
pred = self.model(x_feat)
pred_list.append(pred)
pred = torch.cat(pred_list)
label = torch.cat(label_list) if len(label_list) > 0 else None
return {'pred':pred, 'label':label}
def _save_config(self, config, output_dir=None):
if output_dir is None:
output_dir = self.checkout_dir
temp_path = os.path.join(output_dir, 'config.json')
if os.path.exists(temp_path):
os.remove(temp_path)
with open(temp_path, 'w', encoding='utf-8') as f:
f.write(
json.dumps(config, indent=4)
)
def _load_config(self, checkpoint=None):
'''
Load model config from the given directory.
Parameters
----------
checkpoint: str
The given filepath (e.g., ./checkpoint/config.json)
to load the model config.
'''
if checkpoint is None:
temp_path = os.path.join(checkpoint, 'config.json')
else:
temp_path = checkpoint
assert os.path.exists(temp_path), 'Cannot find `config.json` under {}'.format(self.checkout_dir)
with open(temp_path, 'r') as f:
config = json.load(f)
return config
def _save_checkpoint(self, state,
epoch_id=0,
is_best=False,
output_dir=None,
filename='checkpoint.pth.tar'):
if output_dir is None:
output_dir = self.checkout_dir
if epoch_id < 1:
filepath = os.path.join(output_dir, 'latest.' + filename)
elif is_best:
filepath = os.path.join(output_dir, 'best.' + filename)
else:
filepath = os.path.join(self.checkout_dir,
str(epoch_id) + '.' + filename)
torch.save(state, filepath)
def _input_data_check(self, inputs):
'''
Check the training / testing data fits the formats.
Target to (1) check if inputs valid,
if not, give tips about the data problem.
Parameters
----------
inputs: {
'x': TabularPatientBase or pd.DataFrame,
'y': pd.Series or np.ndarray
}
'x' contain all patient features; 'y' contain labels
for each row.
'''
if isinstance(inputs, dict):
assert 'x' in inputs, 'No input patient data found in inputs.'
assert isinstance(inputs['x'], pd.DataFrame) or isinstance(inputs['x'], TabularPatientBase), 'Get unaccepted input data format, expect `pd.DataFrame` or `TabularPatientBase`, get {} instead.'.format(type(inputs['x']))
if 'y' in inputs:
assert isinstance(inputs['y'], pd.Series) or isinstance(inputs['y'], np.ndarray)
assert not pd.isnull(inputs['y']).any(), 'Find NaN in input targets, please check.'
if isinstance(inputs['x'], pd.DataFrame):
assert not inputs['x'].isnull().values.any(), 'Find NaN in input dataframe, please check your input, or try to pass `TabularPatientBase` as inputs.'
if isinstance(inputs['x'], TabularPatientBase):
assert not inputs['x'].df.isnull().values.any(), 'Find NaN in input dataset, please check your input, or try to pass `TabularPatientBase` as inputs.'
class IndivTabDataset(Dataset):
def __init__(self, df, label=None):
self.df = df
self.label = label
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
if self.label is not None: return {'x':self.df.iloc[idx].values, 'y': self.label[idx]}
else: return {'x':self.df.iloc[idx].values}