import abc
import pdb
import os
import json
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
from pytrial.utils.check import check_model_dir
from pytrial.utils.check import make_dir_if_not_exist
from pytrial.data.patient_data import SequencePatientBase
from pytrial.data.patient_data import SeqPatientCollator
from ..losses import XentLoss, BinaryXentLoss, MSELoss, MultilabelBinaryXentLoss
class SequenceIndivBase(abc.ABC):
    '''Abstract class for all individual outcome predictions
    based on sequential patient data.

    The helper methods of this class read ``self.model``, ``self.config`` and
    ``self.device``; none of them is assigned here, so a concrete subclass
    has to set them before calling those helpers.

    Parameters
    ----------
    experiment_id: str, optional (default = 'test')
        The name of current experiment.

    mode: str
        The prediction mode, one of ``_mode_list`` (case-insensitive).

    output_dim: int, optional (default = None)
        The output dimension. May be omitted for `binary` and `regression`,
        in which case it is set to 1.
    '''
    _mode_list = ['binary', 'multiclass', 'multilabel', 'regression']
    training = False

    @abc.abstractmethod
    def __init__(self, experiment_id='test', mode=None, output_dim=None):
        check_model_dir(experiment_id)
        experiment_root = os.path.join('./experiments_records', experiment_id)
        self.checkout_dir = os.path.join(experiment_root, 'checkpoints')
        self.result_dir = os.path.join(experiment_root, 'results')
        make_dir_if_not_exist(self.checkout_dir)
        make_dir_if_not_exist(self.result_dir)
        self._check_mode_and_output_dim(mode, output_dim)

    @abc.abstractmethod
    def fit(self, train_data, valid_data):
        '''
        Fit function needs to be implemented after subclass.

        Parameters
        ----------
        train_data: Any
            Training data.

        valid_data: Any
            Validation data.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self, test_data):
        '''
        Prediction function needs to be implemented after subclass.

        Parameters
        ----------
        test_data: Any
            Testing data.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def load_model(self, checkpoint):
        '''
        Load the pretrained model from disk, needs to be implemented after subclass.

        Parameters
        ----------
        checkpoint: str
            The path to the checkpoint file.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def save_model(self, output_dir):
        '''
        Save the model to disk, needs to be implemented after subclass.

        Parameters
        ----------
        output_dir: str
            The path to the output directory.
        '''
        raise NotImplementedError

    def train(self, mode=True):
        '''
        Switch the model to the `training` mode. Works the same as `model.train()` in pytorch.

        Parameters
        ----------
        mode: bool, optional (default = True)
            If True, switch to the `training` mode; if False, switch to the
            `validation` mode.

        Returns
        -------
        self
        '''
        self.training = mode
        # Forward `mode` so the wrapper flag and the wrapped module never disagree.
        self.model.train(mode)
        return self

    def eval(self, mode=False):
        '''
        Switch the model to the `validation` mode. Works the same as `model.eval()` in pytorch.

        Parameters
        ----------
        mode: bool, optional (default = False)
            If False, switch to the `validation` mode.

        Returns
        -------
        self
        '''
        self.training = mode
        # `train(False)` is what `eval()` does; passing `mode` keeps both flags in sync.
        self.model.train(mode)
        return self

    def get_test_dataloader(self, test_data):
        '''
        Build a non-shuffled dataloader (with pinned memory) over `test_data`.

        Parameters
        ----------
        test_data: Any
            Dataset exposing `metadata['visit']['mode']` and `metadata['label']['mode']`.
        '''
        return self._build_dataloader(test_data, shuffle=False, pin_memory=True)

    def get_train_dataloader(self, train_data):
        '''
        Build a shuffled dataloader (without pinned memory) over `train_data`.

        Parameters
        ----------
        train_data: Any
            Dataset exposing `metadata['visit']['mode']` and `metadata['label']['mode']`.
        '''
        return self._build_dataloader(train_data, shuffle=True, pin_memory=False)

    def _build_dataloader(self, data, shuffle, pin_memory):
        '''
        Shared dataloader construction; batch size and worker count are read
        from `self.config['batch_size']` and `self.config['num_worker']`.
        '''
        collator = SeqPatientCollator(
            config={
                'visit_mode': data.metadata['visit']['mode'],
                'label_mode': data.metadata['label']['mode'],
            }
        )
        return DataLoader(
            data,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'],
            pin_memory=pin_memory,
            shuffle=shuffle,
            collate_fn=collator,
        )

    def _input_data_check(self, inputs):
        '''Assert that `inputs` is a `SequencePatientBase` instance.'''
        assert isinstance(inputs, SequencePatientBase), 'Wrong input type.'

    def _check_mode_and_output_dim(self, mode, output_dim):
        '''
        Validate `mode` / `output_dim` and store them on the instance.

        Parameters
        ----------
        mode: str
            Prediction mode; lower-cased before the check against `_mode_list`.

        output_dim: int or None
            Output dimension; defaults to 1 for `binary` and `regression`.

        Raises
        ------
        ValueError
            If `mode` is None, or `output_dim` is None for a mode that needs it.
        AssertionError
            If `mode` is not one of the supported modes.
        '''
        if mode is None:
            # The constructor default is None; fail with a clear message
            # instead of an AttributeError raised by `None.lower()`.
            raise ValueError(f'`mode` should be given, one of {self._mode_list}.')
        mode = mode.lower()
        assert mode in self._mode_list, f'Input mode `{mode}` does not belong to the supported mode list {self._mode_list}.'
        if output_dim is None:
            if mode not in ['binary', 'regression']:
                raise ValueError('`output_dim` should be given when `mode` is not `binary` or `regression`.')
            output_dim = 1
        self.mode = mode
        self.output_dim = output_dim

    def _save_config(self, config, output_dir):
        '''
        Write `config` as indented JSON to `<output_dir>/config.json`.

        Parameters
        ----------
        config: dict
            JSON-serializable configuration.

        output_dir: str or None
            Target directory; `self.checkout_dir` is used when None.
        '''
        if output_dir is None:
            output_dir = self.checkout_dir
        temp_path = os.path.join(output_dir, 'config.json')
        if os.path.exists(temp_path):
            os.remove(temp_path)
        with open(temp_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(config, indent=4))

    def _load_config(self, checkpoint=None):
        '''
        Load model config from the given file.

        Parameters
        ----------
        checkpoint: str, optional (default = None)
            The given filepath (e.g., ./checkpoint/config.json)
            to load the model config. When None, `config.json` under
            `self.checkout_dir` is loaded.

        Returns
        -------
        dict
            The parsed configuration.
        '''
        if checkpoint is None:
            # Fall back to the default checkpoint directory of this experiment.
            temp_path = os.path.join(self.checkout_dir, 'config.json')
        else:
            temp_path = checkpoint
        assert os.path.exists(temp_path), 'Cannot find `config.json` at {}'.format(temp_path)
        # Read with the same encoding `_save_config` writes with.
        with open(temp_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        return config

    def _save_checkpoint(self,
        state_dict,
        epoch_id=0,
        is_best=False,
        output_dir=None,
        filename='checkpoint.pth.tar'
        ):
        '''
        Save `state_dict` with `torch.save`.

        The file name prefix is `latest.` when `epoch_id < 1`, otherwise
        `best.` when `is_best` is set, otherwise `<epoch_id>.`.

        Parameters
        ----------
        state_dict: Any
            Object passed to `torch.save`.

        epoch_id: int, optional (default = 0)
            Epoch index used to pick the file name prefix.

        is_best: bool, optional (default = False)
            Whether this checkpoint is the best one so far.

        output_dir: str, optional (default = None)
            Target directory; `self.checkout_dir` is used when None.

        filename: str, optional (default = 'checkpoint.pth.tar')
            Base file name.
        '''
        if output_dir is None:
            output_dir = self.checkout_dir
        if epoch_id < 1:
            prefix = 'latest'
        elif is_best:
            prefix = 'best'
        else:
            prefix = str(epoch_id)
        # Every variant goes to `output_dir`, including per-epoch files.
        filepath = os.path.join(output_dir, prefix + '.' + filename)
        torch.save(state_dict, filepath)

    @torch.no_grad()
    def _predict_on_dataloader(self, test_dataloader):
        '''
        Run `self.model` over every batch and collect the predictions.

        Sigmoid is applied for `binary` / `multilabel`, softmax over dim 1
        for `multiclass`; other modes return the raw model outputs.

        Parameters
        ----------
        test_dataloader: iterable of dict
            Batches accepted by `_prepare_input`.

        Returns
        -------
        dict
            `{'pred': numpy.ndarray, 'label': concatenated labels or None}`;
            `label` is None when no batch carries a `'y'` entry.
        '''
        pred_list, label_list = [], []
        for batch in test_dataloader:
            inputs = self._prepare_input(batch)
            logits = self.model(inputs)
            pred_list.append(logits)
            if 'y' in batch:
                label_list.append(batch.pop('y'))
        pred = torch.cat(pred_list, dim=0)
        if self.config['mode'] in ['binary', 'multilabel']:
            pred = torch.sigmoid(pred)
        if self.config['mode'] == 'multiclass':
            pred = torch.softmax(pred, dim=1)
        pred = pred.cpu().numpy()
        label = torch.cat(label_list) if len(label_list) > 0 else None
        return {'pred': pred, 'label': label}

    def _build_loss_model(self):
        '''
        Wrap `self.model` in the loss matching `self.config['mode']`.

        Returns
        -------
        list or None
            A one-element list holding the loss model; None is returned
            implicitly when the mode matches none of the four supported ones.
        '''
        mode = self.config['mode']
        if mode == 'binary':
            return [BinaryXentLoss(self.model)]
        if mode == 'multiclass':
            return [XentLoss(self.model)]
        if mode == 'multilabel':
            return [MultilabelBinaryXentLoss(self.model)]
        if mode == 'regression':
            return [MSELoss(self.model)]

    def _prepare_input(self, data):
        '''
        Prepare inputs for sequential patient record predictive models.

        Parameters
        ----------
        data: dict[list]
            A batch of patient records with keys `'v'` (visits per order),
            `'x'` (baseline features) and optionally `'y'` (labels).

        Returns
        -------
        dict
            `'v'`: zero-padded multi-hot tensor of shape
            (batch, longest visit sequence, `config['total_vocab_size']`);
            `'v_lengths'`: number of visits kept per patient;
            `'x'`: feature tensor on `self.device`;
            `'y'`: label tensor on `self.device` (only when given).
        '''
        visits = data['v']
        feature = data['x']
        if not isinstance(feature, torch.Tensor):
            feature = torch.tensor(feature)
        feature = feature.to(self.device)
        inputs = {
            'v': {},
            'v_lengths': [],
            'x': feature,  # baseline feature
        }

        # Slicing with `None` keeps everything, so one expression covers both
        # the truncated and the untruncated case.
        max_visit = self.config['max_visit']
        first_order = visits[self.config['orders'][0]]
        v_lengths = [len(patient_visits[:max_visit]) for patient_visits in first_order]
        inputs['v_lengths'] = v_lengths

        # `default=0` keeps an empty batch from raising inside `max`.
        v = torch.zeros(len(v_lengths), max(v_lengths, default=0), self.config['total_vocab_size'])
        for idx, length in enumerate(v_lengths):
            patient = {k: visits[k][idx][:max_visit] for k in visits}
            # NOTE(review): assumes `total_vocab_size` equals sum(config['vocab_size']),
            # which is the width `_translate_dense_visits_to_sparse` produces.
            v[idx, :length] = torch.tensor(self._translate_dense_visits_to_sparse(patient))
        v = v.to(self.device)
        inputs['v'] = v

        if 'y' in data:  # target labels
            target = data['y']
            if not isinstance(target, torch.Tensor):
                target = torch.tensor(target)
            target = target.to(self.device)
            inputs['y'] = target
        return inputs

    def _translate_dense_visits_to_sparse(self, visits):
        '''
        Turn the code lists of one patient into a multi-hot matrix.

        Codes of the i-th order are shifted by the summed vocabulary sizes of
        all preceding orders, so every order owns a disjoint column range.

        Parameters
        ----------
        visits: dict
            Maps each order name to a per-visit sequence of code indices
            (list, tuple, tensor, or anything usable as a numpy index).

        Returns
        -------
        numpy.ndarray
            Array of shape (num_visits, sum(config['vocab_size'])) with ones
            at the shifted code positions. The input is never modified.
        '''
        vocab_sizes = self.config['vocab_size']
        orders = self.config['orders']
        num_visits = len(visits[orders[0]])
        outputs = np.zeros((num_visits, sum(vocab_sizes)))

        offset = 0  # summed vocabulary size of the orders already handled
        for i, o in enumerate(orders):
            for j in range(num_visits):
                raw = visits[o][j]
                if isinstance(raw, torch.Tensor):
                    raw = raw.detach().cpu().numpy()
                if offset:
                    if isinstance(raw, (list, tuple)):
                        raw = [r + offset for r in raw]
                    else:
                        # Out-of-place add: the numpy view shares memory with a
                        # CPU tensor, and `+=` would corrupt the caller's data.
                        raw = raw + offset
                outputs[j, raw] = 1
            offset += vocab_sizes[i]
        return outputs
class InputEventEmbedding(nn.Module):
    '''
    One embedding table per event order.

    Parameters
    ----------
    orders: list[str]
        Names of the event orders; each becomes a key of `self.embeddings`.

    vocab_size: list[int]
        Vocabulary size of each order, aligned with `orders` by position.

    emb_size: int
        Embedding dimension shared by all tables.

    padding_idx: int
        Padding index passed to every `nn.Embedding`.
    '''
    def __init__(self, orders, vocab_size, emb_size, padding_idx) -> None:
        super().__init__()
        # Tables are created in the order given by `orders`.
        tables = {
            name: nn.Embedding(vocab_size[position], embedding_dim=emb_size, padding_idx=padding_idx)
            for position, name in enumerate(orders)
        }
        self.embeddings = nn.ModuleDict(tables)

    def forward(self, inputs):
        '''
        Embed every entry of `inputs['v']` with the table of the same key and
        concatenate the results along dim 1, following the dict's own order.
        '''
        per_order = [self.embeddings[name](codes) for name, codes in inputs['v'].items()]
        return torch.cat(per_order, 1)
class RNNModel(nn.Module):
    '''
    Thin wrapper around a batch-first recurrent network whose hidden size
    equals its input size.

    Parameters
    ----------
    rnn_type: str
        Key of `RNN_TYPE`: 'rnn', 'lstm' or 'gru'.

    emb_size: int
        Used as both `input_size` and `hidden_size`.

    num_layer: int
        Number of stacked recurrent layers.

    bidirectional: bool
        Whether the network runs in both directions.
    '''
    RNN_TYPE = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU}

    def __init__(self, rnn_type, emb_size, num_layer, bidirectional) -> None:
        super().__init__()
        rnn_cls = self.RNN_TYPE[rnn_type]
        rnn_kwargs = dict(
            input_size=emb_size,
            hidden_size=emb_size,
            num_layers=num_layer,
            bidirectional=bidirectional,
            batch_first=True,
        )
        self.model = rnn_cls(**rnn_kwargs)
        self.bidirectional = bidirectional

    def forward(self, x):
        '''Return only the first element of the wrapped network's result.'''
        result = self.model(x)
        return result[0]