Source code for pytrial.tasks.indiv_outcome.sequence.retain

import torch
from torch import nn

from pytrial.utils.check import (
    check_checkpoint_file, check_model_config_file, make_dir_if_not_exist
)
from .base import SequenceIndivBase
from ..trainer import IndivSeqTrainer


class RETAINLayer(nn.Module):
    """The separate callable RETAIN layer.
    Args:
        input_size: the embedding size of the input
        output_size: the embedding size of the output
        num_layers: the number of layers in the RNN
        dropout: dropout rate
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        dropout: float = 0.5,
        ):
        super(RETAINLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(p=self.dropout)

        self.alpha_gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.beta_gru = nn.GRU(input_size, hidden_size, batch_first=True)

        self.alpha_li = nn.Linear(hidden_size, 1)
        self.beta_li = nn.Linear(hidden_size, hidden_size)

    def forward(self, inputs):
        """Using the sum of the embedding as the output of the transformer
        Args:
            x: [batch size, seq len, input_size]
            mask: [batch size, seq len]
        Returns:
            outputs [batch size, seq len, hidden_size]
        """
        x = inputs['x']
        mask = inputs['mask']

        # nn.GRU only applies dropout between stacked layers, so apply it to the inputs explicitly
        x = self.dropout_layer(x)

        g, _ = self.alpha_gru(x)  # (patient, seq_len, hidden_size)
        h, _ = self.beta_gru(x)  # (patient, seq_len, hidden_size)

        # NOTE: padded visits are masked out by adding a large negative number (-1e9)
        # to their attention logits before the softmax; empirically this did not
        # perform better than leaving them unmasked.
        attn_g = self.alpha_li(g)

        if mask is not None:
            attn_g = self._apply_mask_to_attention(attn_g, mask)

        attn_g = torch.softmax(attn_g, dim=1)  # (patient, seq_len, 1)
        attn_h = torch.tanh(self.beta_li(h))  # (patient, seq_len, hidden_size)
        c = attn_g * attn_h * x  # (patient, seq_len, hidden_size)
        # c = torch.sum(c, dim=1)  # (patient, hidden_size)
        return {'x':c, 'mask':mask}
    
    def _apply_mask_to_attention(self, attn_g, mask):
        mask_offset = torch.zeros(mask.shape, device=mask.device)
        mask_offset[mask==0] = -1e9
        mask_offset[mask==1] = 0
        attn_g += mask_offset.unsqueeze(-1)
        return attn_g
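
# --- Illustrative usage sketch for RETAINLayer (not part of the original module). ---
# The shapes below are arbitrary assumptions chosen for demonstration. The layer takes
# a dict with 'x' of shape (batch, seq_len, input_size) and an optional binary 'mask'
# of shape (batch, seq_len), where 1 marks a real visit and 0 marks padding. Note that
# input_size must equal hidden_size here, because the layer multiplies the variable-level
# attention (hidden_size) elementwise with the input x (input_size).
#
#   layer = RETAINLayer(input_size=64, hidden_size=64, dropout=0.5)
#   x = torch.randn(2, 5, 64)           # (batch, seq_len, input_size)
#   mask = torch.ones(2, 5)             # all visits are real in this toy example
#   out = layer({'x': x, 'mask': mask})
#   print(out['x'].shape)               # torch.Size([2, 5, 64])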


class BuildModel(nn.Module):
    def __init__(self,
        emb_size,
        hidden_dim,
        vocab_size,
        n_rnn_layer,
        output_dim,
        **kwargs,
        ) -> None:
        super().__init__()
        if not isinstance(vocab_size, list): vocab_size = [vocab_size]
        self.emb_size = emb_size
        self.vocab_size = vocab_size
        self.total_vocab_size = sum(vocab_size)
        self.embedding_matrix = nn.Linear(self.total_vocab_size, emb_size, bias=False)
        rnn_list = []
        for i in range(n_rnn_layer):
            if i == 0:
                rnn = RETAINLayer(
                        input_size=emb_size,
                        hidden_size=hidden_dim,
                        dropout=0.5,
                    )
            
            else:
                rnn = RETAINLayer(
                    input_size=hidden_dim,
                    hidden_size=hidden_dim,
                    dropout=0.5,
                )
            rnn_list.append(rnn)
        self.rnn = nn.Sequential(*rnn_list)
        self.predictor = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
        v = inputs['v']
        visit_emb = self.embedding_matrix(v) # bs, num_visit, emb_size

        if 'v_lengths' in inputs:
            v_len = inputs['v_lengths']
            mask = self._create_mask_from_visit_length(v_len=v_len)
            mask = mask.to(visit_emb.device)
        else:
            mask = None

        output = self.rnn({'x':visit_emb, 'mask':mask}) # bs, seq_len, hidden_size
        output = output['x']
        output = torch.sum(output, dim=1) # bs, hidden_dim
        return self.predictor(output)

    def _create_mask_from_visit_length(self, v_len):
        # build a binary mask: 1 for real visits, 0 for padded positions
        mask = torch.zeros(len(v_len), max(v_len))
        for i in range(len(v_len)):
            mask[i, :v_len[i]] = 1
        return mask
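
# --- Illustrative usage sketch for BuildModel (not part of the original module). ---
# 'v' is assumed to be a multi-hot encoding of each visit over the concatenated
# vocabularies, as implied by nn.Linear(total_vocab_size, emb_size) being used as the
# embedding; 'v_lengths' gives the number of real (non-padded) visits per patient.
#
#   model = BuildModel(emb_size=64, hidden_dim=64, vocab_size=[100, 50],
#                      n_rnn_layer=2, output_dim=1)
#   v = torch.rand(2, 5, 150)           # (batch, max_visit, sum(vocab_size))
#   logits = model({'v': v, 'v_lengths': [5, 3]})
#   print(logits.shape)                 # torch.Size([2, 1])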


class RETAIN(SequenceIndivBase):
    '''
    Implement RETAIN for predictive modeling of longitudinal patient records [1]_.

    Parameters
    ----------
    vocab_size: list[int]
        A list of vocabulary sizes for the different types of events, e.g.,
        diagnosis, procedure, medication.

    orders: list[str]
        A list of orders used when treating the input events. Should have the
        same length as `vocab_size`.

    output_dim: int
        The dimension of the output.

        - If binary classification, output_dim=1;
        - If multiclass/multilabel classification, output_dim=n_class;
        - If regression, output_dim=1.

    mode: str
        Prediction target, one of ['binary', 'multiclass', 'multilabel', 'regression'].

    max_visit: int
        The maximum number of visits for input event codes.

    emb_size: int
        Embedding size for encoding input event codes.

    n_rnn_layer: int
        Number of RETAIN layers for encoding historical events.

    learning_rate: float
        Learning rate for SGD-based optimization. Uses torch.optim.Adam by default.

    weight_decay: float
        Regularization strength for the l2 norm; must be a positive float.
        Smaller values specify weaker regularization.

    batch_size: int
        Batch size used during SGD optimization.

    epochs: int
        Maximum number of training epochs.

    num_worker: int
        Number of workers used for data loading during training.

    device: str
        The model device.

    Notes
    -----
    .. [1] Choi, E., Bahadori, M. T., Sun, J., Kulas, J., Schuetz, A., & Stewart, W. (2016).
        RETAIN: An interpretable predictive model for healthcare using reverse time
        attention mechanism. Advances in Neural Information Processing Systems, 29.
    '''
    def __init__(self,
        vocab_size,
        orders,
        mode,
        output_dim=None,
        max_visit=None,
        emb_size=64,
        n_rnn_layer=2,
        learning_rate=1e-4,
        weight_decay=1e-4,
        batch_size=64,
        epochs=10,
        num_worker=0,
        device='cuda:0',
        experiment_id='test',
        ):
        super().__init__(experiment_id, mode=mode, output_dim=output_dim)
        self.config = {
            'mode': self.mode,
            'vocab_size': vocab_size,
            'max_visit': max_visit,
            'emb_size': emb_size,
            'hidden_dim': emb_size,
            'n_rnn_layer': n_rnn_layer,
            'output_dim': self.output_dim,
            'device': device,
            'learning_rate': learning_rate,
            'batch_size': batch_size,
            'weight_decay': weight_decay,
            'epochs': epochs,
            'num_worker': num_worker,
            'orders': orders,
            }
        self.config['total_vocab_size'] = sum(vocab_size)
        self.device = device
        self._build_model()

    def fit(self, train_data, valid_data=None):
        '''
        Train the model with sequential patient records.

        Parameters
        ----------
        train_data: SequencePatientBase
            A `SequencePatientBase` containing patient records, where 'v' corresponds
            to the visit sequence of different events and 'y' corresponds to the labels.

        valid_data: SequencePatientBase
            A `SequencePatientBase` containing patient records used for early stopping.
        '''
        self._input_data_check(train_data)
        if valid_data is not None:
            self._input_data_check(valid_data)
        self._fit_model(train_data, valid_data=valid_data)

    def predict(self, test_data):
        '''
        Predict patient outcomes from longitudinal trial patient sequences.

        Parameters
        ----------
        test_data: SequencePatient
            A `SequencePatient` containing patient records, where 'v' corresponds to
            the visit sequence of different events.
        '''
        test_dataloader = self.get_test_dataloader(test_data)
        outputs = self._predict_on_dataloader(test_dataloader)
        return outputs['pred']

    def load_model(self, checkpoint):
        '''
        Load a pretrained model from disk.

        Parameters
        ----------
        checkpoint: str
            The input directory that stores the trained pytorch model and configuration.
        '''
        checkpoint_filename = check_checkpoint_file(checkpoint)
        config_filename = check_model_config_file(checkpoint)
        state_dict = torch.load(checkpoint_filename)
        if config_filename is not None:
            config = self._load_config(config_filename)
            self.config.update(config)
        self.model.load_state_dict(state_dict['model'])

    def save_model(self, output_dir):
        '''
        Save the trained model to disk.

        Parameters
        ----------
        output_dir: str
            The output directory that stores the trained pytorch model and configuration.
        '''
        if output_dir is not None:
            make_dir_if_not_exist(output_dir)
        else:
            output_dir = self.checkout_dir
        self._save_config(self.config, output_dir=output_dir)
        self._save_checkpoint({'model': self.model.state_dict()}, output_dir=output_dir)

    def _build_model(self):
        self.model = BuildModel(
            emb_size=self.config['emb_size'],
            hidden_dim=self.config['hidden_dim'],
            n_rnn_layer=self.config['n_rnn_layer'],
            vocab_size=self.config['vocab_size'],
            output_dim=self.config['output_dim'],
            )
        self.model.to(self.device)

    def _fit_model(self, train_data, valid_data=None):
        test_metric_dict = {
            'binary': 'auc',
            'multiclass': 'acc',
            'regression': 'mse',
            'multilabel': 'f1', # take average of F1 scores
            }
        train_dataloader = self.get_train_dataloader(train_data)
        loss_models = self._build_loss_model()
        train_objectives = [(train_dataloader, loss_model) for loss_model in loss_models]
        if self.config['mode'] == 'regression':
            less_is_better = True
        else:
            less_is_better = False
        trainer = IndivSeqTrainer(
            model=self,
            train_objectives=train_objectives,
            test_data=valid_data,
            test_metric=test_metric_dict[self.config['mode']],
            less_is_better=less_is_better,
            )
        trainer.train(**self.config)
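
# --- End-to-end usage sketch (not part of the original module). ---
# How the `SequencePatientBase`/`SequencePatient` datasets are built is assumed to
# follow the pytrial data utilities and is not shown in this file; only the public
# calls defined above (fit, predict, save_model) are used. Argument values such as
# the vocabulary sizes, event-type names, and output directory are illustrative.
#
#   model = RETAIN(
#       vocab_size=[100, 50],                 # e.g., diagnosis and medication codes
#       orders=['diag', 'med'],               # assumed event-type names
#       mode='binary',
#       output_dim=1,
#       device='cpu',
#       )
#   model.fit(train_data)                     # train_data: SequencePatientBase
#   preds = model.predict(test_data)          # test_data: SequencePatient
#   model.save_model('./checkpoints/retain')  # hypothetical output directory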