Source code for pytrial.tasks.indiv_outcome.sequence.dipole

import pdb

import torch
from torch import nn

from pytrial.utils.check import (
    check_checkpoint_file, check_model_config_file, make_dir_if_not_exist
)
from .base import SequenceIndivBase
from ..trainer import IndivSeqTrainer

class LocationAttention(nn.Module):

    def __init__(self, hidden_size, device):
        super(LocationAttention, self).__init__()
        self.hidden_size = hidden_size
        self.attention_value_ori_func = nn.Linear(self.hidden_size, 1)
        self.device = device

    def forward(self, input_data):
        # shape of input_data: <n_batch, n_seq, hidden_size>         
        n_batch, n_seq, hidden_size = input_data.shape
        # shape of reshape_feat: <n_batch*n_seq, hidden_size>       
        reshape_feat = input_data.reshape(n_batch*n_seq, hidden_size)
        # shape of attention_value_ori: <n_batch*n_seq, 1>       
        attention_value_ori = torch.exp(self.attention_value_ori_func(reshape_feat))
        # shape of attention_value_format: <n_batch, 1, n_seq>       
        attention_value_format = attention_value_ori.reshape(n_batch, n_seq).unsqueeze(1)        
        # shape of ensemble flag format: <1, n_seq, n_seq> 
        # if n_seq = 3, ensemble_flag_format can get below flag data
        #  [[[ 0  0  0 
        #      1  0  0
        #      1  1  0 ]]]
        ensemble_flag_format = torch.triu(torch.ones([n_seq, n_seq]), diagonal = 1).permute(1, 0).unsqueeze(0).to(self.device)
        # shape of accumulate_attention_value: <n_batch, n_seq, 1>
        accumulate_attention_value = torch.sum(attention_value_format * ensemble_flag_format, -1).unsqueeze(-1) + 1e-9
        # shape of each_attention_value: <n_batch, n_seq, n_seq>
        each_attention_value = attention_value_format * ensemble_flag_format
        # shape of attention_weight_format: <n_batch, n_seq, n_seq>
        attention_weight_format = each_attention_value/accumulate_attention_value
        # shape of _extend_attention_weight_format: <n_batch, n_seq, n_seq, 1>
        _extend_attention_weight_format = attention_weight_format.unsqueeze(-1)
        # shape of _extend_input_data: <n_batch, 1, n_seq, hidden_size>
        _extend_input_data = input_data.unsqueeze(1)
        # shape of _weighted_input_data: <n_batch, n_seq, n_seq, hidden_size>
        _weighted_input_data = _extend_attention_weight_format * _extend_input_data
        # shape of weighted_output: <n_batch, n_seq, hidden_size>
        weighted_output = torch.sum(_weighted_input_data, 2)
        return weighted_output
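
# Illustrative sanity check (not part of the library): LocationAttention scores each hidden
# state with a learned linear layer and normalizes the exponentiated scores over strictly
# earlier time steps, so each output step is a convex combination of its past. The names
# below (toy_attn, toy_input) are hypothetical.
#
#   toy_attn = LocationAttention(hidden_size=4, device='cpu')
#   toy_input = torch.randn(2, 3, 4)      # <n_batch=2, n_seq=3, hidden_size=4>
#   toy_output = toy_attn(toy_input)      # <2, 3, 4>; the row for step 0 has no past, so it is all zeros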

class GeneralAttention(nn.Module):

    def __init__(self, hidden_size, device):
        super(GeneralAttention, self).__init__()
        self.hidden_size = hidden_size
        self.correlated_value_ori_func = nn.Linear(self.hidden_size, self.hidden_size)
        self.device = device

    def forward(self, input_data):
        # shape of input_data: <n_batch, n_seq, hidden_size>         
        n_batch, n_seq, hidden_size = input_data.shape
        # shape of reshape_feat: <n_batch*n_seq, hidden_size>       
        reshape_feat = input_data.reshape(n_batch*n_seq, hidden_size)
        # shape of correlated_value_ori: <n_batch, n_seq, hidden_size>       
        correlated_value_ori = self.correlated_value_ori_func(reshape_feat).reshape(n_batch, n_seq, hidden_size)
        # shape of _extend_correlated_value_ori: <n_batch, n_seq, 1, hidden_size>   
        _extend_correlated_value_ori = correlated_value_ori.unsqueeze(-2)
        # shape of _extend_input_data: <n_batch, 1, n_seq, hidden_size>   
        _extend_input_data = input_data.unsqueeze(1)
        # shape of _correlat_value: <n_batch, n_seq, n_seq, hidden_size>
        _correlat_value = _extend_correlated_value_ori * _extend_input_data
        # shape of attention_value_format: <n_batch, n_seq, n_seq>       
        attention_value_format = torch.exp(torch.sum(_correlat_value, dim = -1))
        # shape of ensemble flag format: <1, n_seq, n_seq> 
        # if n_seq = 3, ensemble_flag_format can get below flag data
        #  [[[ 0  0  0 
        #      1  0  0
        #      1  1  0 ]]]
        ensemble_flag_format = torch.triu(torch.ones([n_seq, n_seq]), diagonal = 1).permute(1, 0).unsqueeze(0).to(self.device)
        # shape of accumulate_attention_value: <n_batch, n_seq, 1>
        accumulate_attention_value = torch.sum(attention_value_format * ensemble_flag_format, -1).unsqueeze(-1) + 1e-10
        # shape of each_attention_value: <n_batch, n_seq, n_seq>
        each_attention_value = attention_value_format * ensemble_flag_format
        # shape of attention_weight_format: <n_batch, n_seq, n_seq>
        attention_weight_format = each_attention_value/accumulate_attention_value
        # shape of _extend_attention_weight_format: <n_batch, n_seq, n_seq, 1>
        _extend_attention_weight_format = attention_weight_format.unsqueeze(-1)
        # shape of _extend_input_data: <n_batch, 1, n_seq, hidden_size>
        _extend_input_data = input_data.unsqueeze(1)
        # shape of _weighted_input_data: <n_batch, n_seq, n_seq, hidden_size>
        _weighted_input_data = _extend_attention_weight_format * _extend_input_data
        # shape of weighted_output: <n_batch, n_seq, hidden_size>
        weighted_output = torch.sum(_weighted_input_data, 2)
        return weighted_output
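
# Note (illustrative, not part of the library): the pairwise score above is the "general"
# (multiplicative) attention score s[b, i, j] = (W h_i + b) . h_j. The broadcast-and-sum over
# the hidden dimension is equivalent to a single einsum, which can be easier to shape-check:
#
#   scores = torch.einsum('bih,bjh->bij', self.correlated_value_ori_func(input_data), input_data)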

class ConcatenationAttention(nn.Module):
    def __init__(self, hidden_size, attention_dim = 16, device = None):
        super(ConcatenationAttention, self).__init__()
        self.hidden_size = hidden_size
        self.attention_dim = attention_dim        
        self.attention_map_func = nn.Linear(2 * self.hidden_size, self.attention_dim)
        self.activate_func = nn.Tanh()
        self.correlated_value_ori_func = nn.Linear(self.attention_dim, 1)
        self.device = device
				
    def forward(self, input_data):
        # shape of input_data: <n_batch, n_seq, hidden_size>         
        n_batch, n_seq, hidden_size = input_data.shape
        # shape of _extend_input_data_f: <n_batch, n_seq, 1, hidden_size>
        _extend_input_data_f = input_data.unsqueeze(-2)
        # shape of _repeat_extend_input_data_f: <n_batch, n_seq, n_seq, hidden_size>
        _repeat_extend_input_data_f = _extend_input_data_f.repeat(1,1,n_seq,1)
        # shape of _extend_input_data_b: <n_batch, 1, n_seq, hidden_size>
        _extend_input_data_b = input_data.unsqueeze(1)
        # shape of _repeat_extend_input_data_b: <n_batch, n_seq, n_seq, hidden_size>
        _repeat_extend_input_data_b = _extend_input_data_b.repeat(1,n_seq,1,1)
        # shape of _concate_value: <n_batch, n_seq, n_seq, 2 * hidden_size>           
        _concate_value = torch.cat([_repeat_extend_input_data_f, _repeat_extend_input_data_b], dim = -1)        
        # shape of _correlat_value: <n_batch, n_seq, n_seq> 
        _correlat_value = self.activate_func(self.attention_map_func(_concate_value.reshape(-1, 2 * hidden_size)))
        _correlat_value = self.correlated_value_ori_func(_correlat_value).reshape(n_batch, n_seq, n_seq)
        # shape of attention_value_format: <n_batch, n_seq, n_seq>       
        attention_value_format = torch.exp(_correlat_value)
        # shape of ensemble flag format: <1, n_seq, n_seq> 
        # if n_seq = 3, ensemble_flag_format can get below flag data
        #  [[[ 0  0  0 
        #      1  0  0
        #      1  1  0 ]]]
        ensemble_flag_format = torch.triu(torch.ones([n_seq, n_seq]), diagonal = 1).permute(1, 0).unsqueeze(0).to(self.device)
        # shape of accumulate_attention_value: <n_batch, n_seq, 1>
        accumulate_attention_value = torch.sum(attention_value_format * ensemble_flag_format, -1).unsqueeze(-1) + 1e-10
        # shape of each_attention_value: <n_batch, n_seq, n_seq>
        each_attention_value = attention_value_format * ensemble_flag_format
        # shape of attention_weight_format: <n_batch, n_seq, n_seq>
        attention_weight_format = each_attention_value/accumulate_attention_value
        # shape of _extend_attention_weight_format: <n_batch, n_seq, n_seq, 1>
        _extend_attention_weight_format = attention_weight_format.unsqueeze(-1)
        # shape of _extend_input_data: <n_batch, 1, n_seq, hidden_size>
        _extend_input_data = input_data.unsqueeze(1)
        # shape of _weighted_input_data: <n_batch, n_seq, n_seq, hidden_size>
        _weighted_input_data = _extend_attention_weight_format * _extend_input_data
        # shape of weighted_output: <n_batch, n_seq, hidden_size>
        weighted_output = torch.sum(_weighted_input_data, 2)
        return weighted_output
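
# Illustrative comparison (not part of the library): all three attention variants consume and
# produce tensors of the same shape, which is why they are interchangeable inside BuildModel.
# The names below are hypothetical.
#
#   x = torch.randn(2, 5, 16)             # <n_batch, n_seq, 2*hidden_size>
#   for attn in (LocationAttention(16, 'cpu'),
#                GeneralAttention(16, 'cpu'),
#                ConcatenationAttention(16, attention_dim=8, device='cpu')):
#       assert attn(x).shape == (2, 5, 16)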

class BuildModel(nn.Module):
    def __init__(self, 
                 input_size = None,
                 embed_size = 16,
                 hidden_size = 8,
                 output_size = 10,
                 bias = True,
                 dropout = 0.5,
                 batch_first = True,
                 label_size = 1,
                 attention_type = 'location_based',
                 attention_dim = 8,
                 device = None):
        super(BuildModel, self).__init__()
        assert input_size is not None and isinstance(input_size, int), 'input_size must be an integer'
 
        self.input_size = input_size        
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.label_size = label_size
				
        self.embed_func = nn.Linear(self.input_size, self.embed_size)
        self.rnn_model = nn.GRU(input_size = embed_size,
                                 hidden_size = hidden_size,
                                 bias = bias,
                                 dropout = dropout,
                                 bidirectional = True,
                                 batch_first = batch_first)
        if attention_type == 'location_based':
            self.attention_func = LocationAttention(2*hidden_size, device)
        elif attention_type == 'general':
            self.attention_func = GeneralAttention(2*hidden_size, device)
        elif attention_type == 'concatenation_based':
            self.attention_func = ConcatenationAttention(2*hidden_size, attention_dim = attention_dim, device = device)
        else:
            raise ValueError("attention_type must be one of ['location_based', 'general', 'concatenation_based']")
        self.output_func = nn.Linear(4 * hidden_size, self.output_size) 
        self.output_activate = nn.Tanh()
        self.predict_func = nn.Linear(self.output_size, self.label_size)
            
    def forward(self, input_data):
        
        """
        
        Parameters
        
        ----------
        input_data = {
                      'X': shape (batchsize, n_timestep, n_featdim)
                      'M': shape (batchsize, n_timestep)
                      'cur_M': shape (batchsize, n_timestep)
                      'T': shape (batchsize, n_timestep)
                     }
        
        Return
        
        ----------
        
        all_output, shape (batchsize, n_timestep, n_labels)
            
            predict output of each time step
            
        cur_output, shape (batchsize, n_labels)
        
            predict output of last time step
        
        """

        X = input_data['v']
        
        if 'v_lengths' in input_data:
            v_len = input_data['v_lengths']
            mask, cur_mask = self._create_mask_from_visit_length(v_len=v_len)
            mask = mask.to(X.device)
            cur_mask = cur_mask.to(X.device)
        else:
            mask = cur_mask = None

        batchsize, n_timestep, n_orifeatdim = X.shape
        _ori_X = X.view(-1, n_orifeatdim)
        _embed_X = self.embed_func(_ori_X)
        _embed_X = _embed_X.reshape(batchsize, n_timestep, self.embed_size)
        _embed_F, _ = self.rnn_model(_embed_X)
        _embed_F_w = self.attention_func(_embed_F)
        _mix_F = torch.cat([_embed_F, _embed_F_w], dim = -1)
        _mix_F_reshape = _mix_F.view(-1, 4 * self.hidden_size)
        outputs = self.output_activate(self.output_func(_mix_F_reshape)).reshape(batchsize, n_timestep, self.output_size)
        n_batchsize, n_timestep, output_size = outputs.shape

        all_output = self.predict_func(outputs.reshape(n_batchsize*n_timestep, output_size)).\
                         reshape(n_batchsize, n_timestep, self.label_size)

        if mask is not None and cur_mask is not None:
            all_output *= mask.unsqueeze(-1)
            cur_output = (all_output * cur_mask.unsqueeze(-1)).sum(dim=1)
        else:
            cur_output = all_output.sum(1)

        # return all_output, cur_output
        # TODO: all_output is the per visit label prediction, currently we only support patient-level prediction.
        return cur_output

    def _create_mask_from_visit_length(self, v_len):
        mask = torch.zeros(len(v_len),max(v_len))
        cur_mask = torch.zeros(len(v_len),max(v_len))

        for i in range(len(v_len)):
            mask[i, :v_len[i]] = 1
            cur_mask[i, v_len[i]-1] = 1
        return mask, cur_mask
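
# Worked example (illustrative): for v_len = [2, 3] the helper above returns
#
#   mask     = [[1, 1, 0],      # valid visits per patient, padded to max(v_len)
#               [1, 1, 1]]
#   cur_mask = [[0, 1, 0],      # one-hot position of each patient's last valid visit
#               [0, 0, 1]]
#
# so `all_output * mask` zeroes padded visits and `(all_output * cur_mask).sum(dim=1)`
# picks the prediction at each patient's last visit.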

class Dipole(SequenceIndivBase):
    '''
    Implement Dipole for longitudinal patient records predictive modeling [1]_.

    Parameters
    ----------
    vocab_size: list[int]
        A list of vocabulary sizes for the different types of events, e.g., diagnosis,
        procedure, medication.

    orders: list[str]
        A list of the orders in which input events are treated. Should have the same
        length as `vocab_size`.

    mode: str
        Prediction target, in ['binary', 'multiclass', 'multilabel', 'regression'].

    output_dim: int
        Output dimension of the model.

        - If binary classification, output_dim=1;
        - If multiclass/multilabel classification, output_dim=n_class;
        - If regression, output_dim=1.

    max_visit: int
        The maximum number of visits for input event codes.

    attention_type: {'general', 'concatenation_based', 'location_based'}
        The attention mechanism used to derive a context vector that captures
        information relevant to predicting the target.

        - 'location_based': Location-based Attention. A location-based attention
          function calculates the weights solely from the hidden state.
        - 'general': General Attention. A simple way to capture the relationship
          between two hidden states.
        - 'concatenation_based': Concatenation-based Attention. Concatenates two hidden
          states, then uses a multi-layer perceptron (MLP) to calculate the context vector.

    attention_dim: int
        The latent dimensionality used for computing attention weights; only used by the
        concatenation-based attention mechanism.

    emb_size: int
        Embedding size for encoding input event codes.

    hidden_size: int, optional (default = 8)
        The number of features of the hidden state h.

    hidden_output_size: int, optional (default = 8)
        The number of mixed features.

    learning_rate: float
        Learning rate for optimization based on SGD. Use torch.optim.Adam by default.

    weight_decay: float
        Regularization strength for the l2 norm; must be a positive float. Smaller
        values specify weaker regularization.

    batch_size: int
        Batch size when doing SGD optimization.

    epochs: int
        Maximum number of iterations taken for the solvers to converge.

    num_worker: int
        Number of workers used for dataloading during training.

    device: str
        The model device.

    Notes
    -----
    .. [1] Ma, F., Chitta, R., Zhou, J., You, Q., Sun, T., & Gao, J. (2017, August).
        Dipole: Diagnosis prediction in healthcare via attention-based bidirectional
        recurrent neural networks. In Proceedings of the 23rd ACM SIGKDD International
        Conference on Knowledge Discovery and Data Mining (pp. 1903-1911).
    '''
    def __init__(self,
        vocab_size,
        orders,
        mode,
        output_dim=None,
        max_visit=None,
        attention_type='location_based',
        attention_dim=8,
        emb_size=16,
        hidden_size=8,
        hidden_output_size=8,
        dropout=0.5,
        learning_rate=1e-4,
        weight_decay=1e-4,
        batch_size=64,
        epochs=10,
        num_worker=0,
        device='cuda:0',
        experiment_id='test',
        ):
        super().__init__(experiment_id, mode=mode, output_dim=output_dim)
        self.config = {
            'mode':self.mode,
            'vocab_size':vocab_size,
            'max_visit':max_visit,
            'attention_type':attention_type,
            'attention_dim':attention_dim,
            'emb_size':emb_size,
            'hidden_size':hidden_size,
            'hidden_output_size':hidden_output_size,
            'output_dim':self.output_dim,
            'dropout':dropout,
            'device':device,
            'learning_rate':learning_rate,
            'batch_size':batch_size,
            'weight_decay':weight_decay,
            'epochs':epochs,
            'num_worker':num_worker,
            'orders':orders,
            }
        self.config['total_vocab_size'] = sum(vocab_size)
        self.device = device
        self._build_model()
    def predict(self, test_data):
        '''
        Predict patient outcomes using longitudinal trial patient sequences.

        Parameters
        ----------
        test_data: SequencePatient
            A `SequencePatient` contains patient records where 'v' corresponds to
            visit sequence of different events.
        '''
        test_dataloader = self.get_test_dataloader(test_data)
        outputs = self._predict_on_dataloader(test_dataloader)
        return outputs['pred']
    def fit(self, train_data, valid_data):
        '''
        Train model with sequential patient records.

        Parameters
        ----------
        train_data: SequencePatientBase
            A `SequencePatientBase` contains patient records where 'v' corresponds to
            visit sequence of different events; 'y' corresponds to labels.

        valid_data: SequencePatientBase
            A `SequencePatientBase` contains patient records used for early stopping
            of the model.
        '''
        self._input_data_check(train_data)
        if valid_data is not None:
            self._input_data_check(valid_data)
        self._fit_model(train_data, valid_data=valid_data)
    def load_model(self, checkpoint):
        '''
        Load pretrained model from the disk.

        Parameters
        ----------
        checkpoint: str
            The input directory that stores the trained pytorch model and configuration.
        '''
        checkpoint_filename = check_checkpoint_file(checkpoint)
        config_filename = check_model_config_file(checkpoint)
        state_dict = torch.load(checkpoint_filename)
        if config_filename is not None:
            config = self._load_config(config_filename)
            self.config.update(config)
        self.model.load_state_dict(state_dict['model'])
    def save_model(self, output_dir):
        '''
        Save the pretrained model to the disk.

        Parameters
        ----------
        output_dir: str
            The output directory that stores the trained pytorch model and configuration.
        '''
        if output_dir is not None:
            make_dir_if_not_exist(output_dir)
        else:
            output_dir = self.checkout_dir
        self._save_config(self.config, output_dir=output_dir)
        self._save_checkpoint({'model':self.model.state_dict()}, output_dir=output_dir)
    def _build_model(self):
        self.model = BuildModel(
            input_size=self.config['total_vocab_size'],
            embed_size=self.config['emb_size'],
            hidden_size=self.config['hidden_size'],
            output_size=self.config['hidden_output_size'],
            dropout=self.config['dropout'],
            label_size=self.config['output_dim'],
            attention_type=self.config['attention_type'],
            attention_dim=self.config['attention_dim'],
            device=self.config['device'],
            )
        self.model.to(self.device)

    def _fit_model(self, train_data, valid_data=None):
        test_metric_dict = {
            'binary': 'auc',
            'multiclass': 'acc',
            'regression': 'mse',
            'multilabel': 'f1', # take average of F1 scores
            }
        train_dataloader = self.get_train_dataloader(train_data)
        loss_models = self._build_loss_model()
        train_objectives = [(train_dataloader, loss_model) for loss_model in loss_models]
        if self.config['mode'] == 'regression':
            less_is_better = True
        else:
            less_is_better = False
        trainer = IndivSeqTrainer(model=self,
            train_objectives=train_objectives,
            test_data=valid_data,
            test_metric=test_metric_dict[self.config['mode']],
            less_is_better=less_is_better,
            )
        trainer.train(**self.config)
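
# Usage sketch (illustrative; assumes a SequencePatient/SequencePatientBase dataset built
# elsewhere — the variables train_data, valid_data, and test_data are hypothetical):
#
#   model = Dipole(
#       vocab_size=[100, 50, 30],             # one vocabulary per event type
#       orders=['diag', 'prod', 'med'],
#       mode='binary',
#       output_dim=1,
#       attention_type='location_based',
#       device='cpu',
#       )
#   model.fit(train_data, valid_data)
#   preds = model.predict(test_data)          # patient-level predictions
#   model.save_model('./checkpoints/dipole')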