import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn import LSTMCell
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import numpy as np
from pytrial.utils.check import (
check_checkpoint_file, check_model_dir, check_model_config_file, make_dir_if_not_exist
)
from .base import SequenceIndivBase
from ..trainer import IndivSeqTrainer
class RaimExtract(nn.Module):
def __init__(self, input_size, window_size, hidden_size):
super(RaimExtract, self).__init__()
self.input_size = input_size
self.window_size = window_size
self.hidden_size = hidden_size
self.w_h_alpha = nn.Linear(self.hidden_size, self.window_size)
self.w_a_alpha = nn.Linear(self.input_size, 1)
self.w_h_beta = nn.Linear(self.hidden_size, self.input_size)
self.w_a_beta = nn.Linear(self.window_size, 1)
self.activate_func = nn.Tanh()
self.weights_func = nn.Softmax(dim = -1)
def forward(self, input_data, h_t_1):
# shape of input_data : <n_batch, window_size, input_size>
# shape of h_t_1 : <n_batch, hidden_size>
# shape of _alpha_weights : <n_batch, window_size>
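        # Two-way attention: alpha attends over the window_size time steps and
        # beta over the input features; their outer product re-weights the
        # window before it is pooled into a single feature vector.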
        _alpha_weights = self.activate_func(
            self.w_h_alpha(h_t_1)
            + self.w_a_alpha(input_data.reshape(-1, self.input_size)).squeeze(-1).reshape(-1, self.window_size)
        )
        _alpha_weights = self.weights_func(_alpha_weights)
        # shape of _beta_weights : <n_batch, input_size>
        _beta_weights = self.activate_func(
            self.w_h_beta(h_t_1)
            + self.w_a_beta(input_data.permute(0, 2, 1).reshape(-1, self.window_size)).squeeze(-1).reshape(-1, self.input_size)
        )
        _beta_weights = self.weights_func(_beta_weights)
# shape of _alpha_weights_v : <n_batch, window_size, 1>
_alpha_weights_v = _alpha_weights.unsqueeze(-1)
# shape of _beta_weights_v : <n_batch, 1, input_size>
_beta_weights_v = _beta_weights.unsqueeze(1)
# shape of weight_A : <n_batch, window_size, input_size>
weight_A = _alpha_weights_v * _beta_weights_v
# shape of output_v : <n_batch, input_size>
output_v = torch.sum(weight_A * input_data, dim = 1)
return output_v
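# A minimal shape check for RaimExtract (hypothetical numbers; kept as a comment
# so importing this module stays side-effect free):
#
#   extractor = RaimExtract(input_size=4, window_size=3, hidden_size=16)
#   x = torch.randn(2, 3, 4)   # <n_batch, window_size, input_size>
#   h = torch.zeros(2, 16)     # <n_batch, hidden_size>
#   out = extractor(x, h)      # -> <n_batch, input_size>, here (2, 4)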
class BuildModel(nn.Module):
def __init__(self,
input_size = None,
window_size = 3,
hidden_size = 16,
output_size = 8,
label_size = 1,
device = None):
super(BuildModel, self).__init__()
        assert isinstance(input_size, int), 'input_size must be an int'
self.input_size = input_size
self.window_size = window_size
self.hidden_size = hidden_size
self.output_size = output_size
self.label_size = label_size
self.rnn_raimf = RaimExtract(input_size, window_size, hidden_size)
self.rnn_unit = LSTMCell(input_size, hidden_size)
self.predict_func = nn.Linear(self.output_size, self.label_size)
self.pad = nn.ConstantPad1d((self.window_size-1, 0), 0.)
self.device = device
def forward(self, input_data):
"""
Parameters
----------
input_data = {
'X': shape (batchsize, n_timestep, n_featdim)
'M': shape (batchsize, n_timestep), the first visit to the current visit are all 1, the later are 0, i.e., [1,1,1,0,0].
'cur_M': shape (batchsize, n_timestep), only the current visit is 1 others are 0, i.e.,[0,0,1,0,0]
'T': shape (batchsize, n_timestep), the time interval between visits, i.e., for timestamp [0, 1, 2, 3, 4], T = [0, 1, 1, 1, 1].
}
Return
----------
all_output, shape (batchsize, n_timestep, n_labels)
predict output of each time step
cur_output, shape (batchsize, n_labels)
predict output of last time step
"""
X = input_data['v']
if 'v_lengths' in input_data:
v_len = input_data['v_lengths']
mask, cur_mask = self._create_mask_from_visit_length(v_len=v_len)
mask = mask.to(X.device)
cur_mask = cur_mask.to(X.device)
else:
mask = cur_mask = None
        batchsize, n_timestep, n_f = X.shape
# shape of X : <batchsize, n_timestep, n_featdim>
# shape of pad_X : <batchsize, n_timestep + window_size - 1, n_featdim>
pad_X = self.pad(X.permute(0,2,1)).permute(0,2,1)
        h0 = torch.zeros(batchsize, self.hidden_size, device=self.device)
        c0 = torch.zeros(batchsize, self.hidden_size, device=self.device)
outs = []
hn = h0
cn = c0
for i in range(self.window_size-1, n_timestep+self.window_size-1):
cur_x = pad_X[ : , i - self.window_size + 1 : i + 1, : ]
z_value = self.rnn_raimf(cur_x, hn)
hn, cn = self.rnn_unit(z_value, (hn, cn))
outs.append(hn)
        outputs = torch.stack(outs, dim=1)
n_batchsize, n_timestep, n_featdim = outputs.shape
        all_output = self.predict_func(
            outputs.reshape(n_batchsize * n_timestep, n_featdim)
        ).reshape(n_batchsize, n_timestep, self.label_size)  # <n_batchsize, n_timestep, label_size>
if mask is not None and cur_mask is not None:
all_output *= mask.unsqueeze(-1)
cur_output = (all_output * cur_mask.unsqueeze(-1)).sum(dim=1)
else:
cur_output = all_output.sum(1)
# TODO: all_output is the per visit label prediction, currently we only support patient-level prediction.
return cur_output
def _create_mask_from_visit_length(self, v_len):
mask = torch.zeros(len(v_len),max(v_len))
cur_mask = torch.zeros(len(v_len),max(v_len))
for i in range(len(v_len)):
mask[i, :v_len[i]] = 1
cur_mask[i, v_len[i]-1] = 1
return mask, cur_mask
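# A minimal end-to-end sketch of BuildModel (hypothetical numbers): two patients,
# up to five visits each, four features, one label per patient. Note that
# output_size must equal hidden_size, since predict_func consumes the LSTM
# hidden states directly.
#
#   net = BuildModel(input_size=4, window_size=3, hidden_size=16,
#                    output_size=16, label_size=1, device='cpu')
#   batch = {'v': torch.randn(2, 5, 4), 'v_lengths': [5, 3]}
#   logits = net(batch)        # -> <batchsize, label_size>, here (2, 1)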
class RAIM(SequenceIndivBase):
'''
    Implement RAIM for predictive modeling of longitudinal patient records [1]_.
Parameters
----------
vocab_size: list[int]
A list of vocabulary size for different types of events, e.g., for diagnosis, procedure, medication.
orders: list[str]
        A list of orders when treating input events. Must have the same length as `vocab_size`.
mode: str
        Prediction target in ['binary','multiclass','multilabel','regression'].
output_dim: int
Output dimension of the model.
- If binary classification, output_dim=1;
- If multiclass/multilabel classification, output_dim=n_class;
- If regression, output_dim=1.
max_visit: int
The maximum number of visits for input event codes.
window_size: int
The window size in the recurrent part of RAIM.
hidden_size: int
Embedding size for encoding input event codes.
hidden_output_size: int
The output size of the intermediate layers. Not the output_dim.
dropout: float
Dropout rate in the intermediate layers.
learning_rate: float
        Learning rate for optimization; torch.optim.Adam is used by default.
weight_decay: float
Regularization strength for l2 norm; must be a positive float.
Smaller values specify weaker regularization.
batch_size: int
Batch size when doing SGD optimization.
epochs: int
        Maximum number of training epochs.
num_worker: int
        Number of workers used for data loading during training.
device: str
The model device.
experiment_id: str
The prefix when saving the checkpoints during the training.
Notes
-----
    .. [1] Xu, Y., Biswal, S., Deshpande, S. R., Maher, K. O., & Sun, J. (2018, July).
        RAIM: Recurrent attentive and intensive model of multimodal patient monitoring data.
        In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (pp. 2565-2573).
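
    Examples
    --------
    A minimal usage sketch. The construction of `train_data` and `test_data`
    (SequencePatientBase-style datasets) is hypothetical here and assumed to
    follow the pytrial data conventions:

    >>> model = RAIM(vocab_size=[100, 50], orders=['diag', 'med'],
    ...              mode='binary', output_dim=1, device='cpu')
    >>> model.fit(train_data, valid_data)
    >>> preds = model.predict(test_data)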
'''
def __init__(self,
vocab_size,
orders,
mode=None,
output_dim=None,
max_visit=None,
window_size=3,
hidden_size=8,
hidden_output_size=8,
dropout=0.5,
learning_rate=1e-4,
weight_decay=1e-4,
batch_size=64,
epochs=10,
num_worker=0,
device='cuda:0',
experiment_id='test',
):
super().__init__(experiment_id, mode, output_dim)
self.config = {
'mode':self.mode,
'vocab_size':vocab_size,
'max_visit':max_visit,
'window_size':window_size,
'hidden_size':hidden_size,
'hidden_output_size':hidden_output_size,
'output_dim':self.output_dim,
'dropout':dropout,
'device':device,
'learning_rate':learning_rate,
'batch_size':batch_size,
'weight_decay':weight_decay,
'epochs':epochs,
'num_worker':num_worker,
'orders':orders,
}
self.config['total_vocab_size'] = sum(vocab_size)
self.device = device
self._build_model()
    def load_model(self, checkpoint):
'''
Load pretrained model from the disk.
Parameters
----------
checkpoint: str
The input directory that stores the trained pytorch model and configuration.
'''
checkpoint_filename = check_checkpoint_file(checkpoint)
config_filename = check_model_config_file(checkpoint)
        state_dict = torch.load(checkpoint_filename, map_location=self.device)
if config_filename is not None:
config = self._load_config(config_filename)
self.config.update(config)
self.model.load_state_dict(state_dict['model'])
    def save_model(self, output_dir):
'''
Save the pretrained model to the disk.
Parameters
----------
output_dir: str
The output directory that stores the trained pytorch model and configuration.
'''
if output_dir is not None:
make_dir_if_not_exist(output_dir)
else:
output_dir = self.checkout_dir
self._save_config(self.config, output_dir=output_dir)
self._save_checkpoint({'model':self.model.state_dict()}, output_dir=output_dir)
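    # A hypothetical save/load round trip (paths and arguments are illustrative):
    #
    #   model.save_model('./checkpoints/raim')
    #   model2 = RAIM(vocab_size=[100, 50], orders=['diag', 'med'],
    #                 mode='binary', output_dim=1)
    #   model2.load_model('./checkpoints/raim')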
    def fit(self, train_data, valid_data=None):
'''
Train model with sequential patient records.
Parameters
----------
train_data: SequencePatientBase
A `SequencePatientBase` contains patient records where 'v' corresponds to
visit sequence of different events; 'y' corresponds to labels.
valid_data: SequencePatientBase
            A `SequencePatientBase` contains patient records used for early stopping
            of the model.
'''
self._input_data_check(train_data)
if valid_data is not None: self._input_data_check(valid_data)
self._fit_model(train_data, valid_data=valid_data)
    def predict(self, test_data):
'''
Predict patient outcomes using longitudinal trial patient sequences.
Parameters
----------
test_data: SequencePatient
A `SequencePatient` contains patient records where 'v' corresponds to
            visit sequence of different events.

        Returns
        -------
        pred:
            The predicted outcomes for the input patients.
'''
test_dataloader = self.get_test_dataloader(test_data)
outputs = self._predict_on_dataloader(test_dataloader)
return outputs['pred']
def _build_model(self):
self.model = BuildModel(
input_size=self.config['total_vocab_size'],
window_size=self.config['window_size'],
hidden_size=self.config['hidden_size'],
output_size=self.config['hidden_output_size'],
label_size=self.config['output_dim'],
device=self.config['device'],
)
self.model.to(self.device)
def _fit_model(self, train_data, valid_data=None):
test_metric_dict = {
'binary': 'auc',
'multiclass': 'acc',
'regression': 'mse',
'multilabel': 'f1', # take average of F1 scores
}
train_dataloader = self.get_train_dataloader(train_data)
loss_models = self._build_loss_model()
if self.config['mode'] == 'regression':
less_is_better = True
else:
less_is_better = False
train_objectives = [(train_dataloader, loss_model) for loss_model in loss_models]
trainer = IndivSeqTrainer(model=self,
train_objectives=train_objectives,
test_data=valid_data,
test_metric=test_metric_dict[self.config['mode']],
less_is_better=less_is_better,
)
trainer.train(**self.config)