Source code for pytrial.tasks.indiv_outcome.sequence.rnn

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import numpy as np

from pytrial.data.patient_data import SeqPatientCollator
from pytrial.utils.check import (
    check_checkpoint_file, check_model_dir, check_model_config_file, make_dir_if_not_exist
)
from .base import SequenceIndivBase, InputEventEmbedding, RNNModel
from ..data import SequencePatient
from ..losses import XentLoss, BinaryXentLoss, MSELoss, MultilabelBinaryXentLoss
from ..trainer import IndivSeqTrainer


class BuildModel(nn.Module):
    def __init__(self,
        rnn_type,
        emb_size,
        bidirectional,
        vocab_size,
        orders,
        n_rnn_layer,
        output_dim,
        **kwargs,
        ) -> None:
        super().__init__()
        if not isinstance(vocab_size, list): vocab_size = [vocab_size]
        self.rnn = RNNModel(
            rnn_type=rnn_type,
            emb_size=emb_size,
            num_layer=n_rnn_layer,
            bidirectional=bidirectional,
            )
        self.emb_size = emb_size
        self.vocab_size = vocab_size
        self.orders = orders
        self.total_vocab_size = sum(vocab_size)
        # linear embedding of multi-hot visit vectors: total_vocab_size -> emb_size
        self.embedding_matrix = nn.Linear(self.total_vocab_size, emb_size, bias=False)
        if bidirectional:
            self.predictor = nn.Linear(2*emb_size, output_dim)
        else:
            self.predictor = nn.Linear(emb_size, output_dim)

    def forward(self, inputs):
        v = inputs['v']
        v_len = inputs['v_lengths']
        visit_emb = self.embedding_matrix(v) # bs, num_visit, emb_size
        # pack the padded visit sequences so the RNN skips padding steps
        packed_input = pack_padded_sequence(visit_emb, v_len, batch_first=True, enforce_sorted=False)
        packed_output = self.rnn(packed_input)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        # forward direction: hidden state at each patient's last valid visit
        out_forward = output[range(len(output)), [l - 1 for l in v_len], :self.emb_size]
        # reverse direction: hidden state at the first visit (empty slice if unidirectional)
        out_reverse = output[:, 0, self.emb_size:]
        out_combined = torch.cat((out_forward, out_reverse), 1) # bs, emb_size or 2*emb_size (bidirectional)
        return self.predictor(out_combined)
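
# Illustrative note (not part of the model): each visit is represented as a
# multi-hot vector over the concatenated vocabularies. For example, assuming
# vocab_size=[3, 2], a visit with diagnosis code 1 and medication code 0 is
# encoded as the length-5 vector [0, 1, 0, 1, 0], which `embedding_matrix`
# then projects down to `emb_size` dimensions.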

class RNN(SequenceIndivBase):
    '''
    Implement an RNN based model for longitudinal patient records predictive modeling.

    Parameters
    ----------
    vocab_size: list[int]
        A list of vocabulary sizes for the different types of events,
        e.g., for diagnosis, procedure, medication.

    orders: list[str]
        A list of event types fixing the order in which input events are processed.
        Should have the same length as `vocab_size`.

    output_dim: int
        Output dimension of the model.

        - If binary classification, output_dim=1;
        - If multiclass/multilabel classification, output_dim=n_class;
        - If regression, output_dim=1.

    mode: str
        Prediction target in ['binary','multiclass','multilabel','regression'].

    max_visit: int
        The maximum number of visits kept for input event codes.

    emb_size: int
        Embedding size for encoding input event codes.

    n_rnn_layer: int
        Number of RNN layers for encoding historical events.

    rnn_type: str
        Pick the RNN type from ['rnn','lstm','gru'].

    bidirectional: bool
        If True, historical events are encoded in a bi-directional manner.

    learning_rate: float
        Learning rate for optimization. torch.optim.Adam is used by default.

    weight_decay: float
        Regularization strength for the l2 norm; must be a positive float.
        Smaller values specify weaker regularization.

    batch_size: int
        Batch size used during SGD optimization.

    epochs: int
        Maximum number of training epochs.

    evaluation_steps: int
        Number of steps between evaluations on the validation set
        or reports of the training loss.

    num_worker: int
        Number of workers used for data loading during training.

    device: str
        The model device.

    experiment_id: str
        The prefix used when saving checkpoints during training.
    '''
    def __init__(self,
        vocab_size,
        orders,
        mode,
        output_dim=None,
        max_visit=20,
        emb_size=64,
        n_rnn_layer=2,
        rnn_type='lstm',
        bidirectional=False,
        learning_rate=1e-4,
        weight_decay=1e-4,
        batch_size=64,
        epochs=10,
        evaluation_steps=100,
        num_worker=0,
        device='cuda:0',
        experiment_id='test',
        ):
        super().__init__(experiment_id, mode=mode, output_dim=output_dim)
        self.config = {
            'mode':self.mode,
            'vocab_size':vocab_size,
            'max_visit':max_visit,
            'emb_size':emb_size,
            'n_rnn_layer':n_rnn_layer,
            'rnn_type':rnn_type,
            'bidirectional':bidirectional,
            'output_dim':self.output_dim,
            'device':device,
            'learning_rate':learning_rate,
            'batch_size':batch_size,
            'weight_decay':weight_decay,
            'epochs':epochs,
            'num_worker':num_worker,
            'orders':orders,
            'evaluation_steps':evaluation_steps,
        }
        self.config['total_vocab_size'] = sum(vocab_size)
        self.device = device
        self._build_model()
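
    # A minimal usage sketch (illustrative; the vocabulary sizes and event
    # type names below are made up, not taken from pytrial):
    #
    #   model = RNN(
    #       vocab_size=[1000, 500],   # e.g., diagnosis and medication codes
    #       orders=['diag', 'med'],   # one entry per vocabulary, fixing the order
    #       mode='binary',            # predict a binary patient outcome
    #       output_dim=1,
    #       device='cpu',
    #   )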

    def fit(self, train_data, valid_data=None):
        '''
        Train the model with sequential patient records.

        Parameters
        ----------
        train_data: SequencePatientBase
            A `SequencePatientBase` that contains patient records, where 'v' corresponds
            to the visit sequence of different events and 'y' corresponds to the labels.

        valid_data: SequencePatientBase
            A `SequencePatientBase` that contains patient records used for early stopping.
        '''
        self._input_data_check(train_data)
        if valid_data is not None:
            self._input_data_check(valid_data)
        self._fit_model(train_data, valid_data=valid_data)
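
    # Sketch of a training call (assumes `train_seq` and `val_seq` are
    # hypothetical, pre-built `SequencePatient` datasets with 'v' and 'y'):
    #
    #   model.fit(train_data=train_seq, valid_data=val_seq)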

    def predict(self, test_data):
        '''
        Predict patient outcomes using longitudinal trial patient sequences.

        Parameters
        ----------
        test_data: SequencePatientBase
            A `SequencePatientBase` that contains patient records, where 'v' corresponds
            to the visit sequence of different events.
        '''
        test_dataloader = self.get_test_dataloader(test_data)
        outputs = self._predict_on_dataloader(test_dataloader)
        return outputs['pred']
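
    # Sketch of inference (assumes a hypothetical `test_seq` dataset). The
    # returned numpy array has shape (n_patients, output_dim) and holds
    # sigmoid/softmax scores for classification modes or raw values for
    # regression:
    #
    #   preds = model.predict(test_seq)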

    def load_model(self, checkpoint):
        '''
        Load a pretrained model from disk.

        Parameters
        ----------
        checkpoint: str
            The input directory that stores the trained pytorch model and its configuration.
        '''
        checkpoint_filename = check_checkpoint_file(checkpoint)
        config_filename = check_model_config_file(checkpoint)
        state_dict = torch.load(checkpoint_filename)
        if config_filename is not None:
            config = self._load_config(config_filename)
            self.config.update(config)
        self.model.load_state_dict(state_dict['model'])

    def save_model(self, output_dir):
        '''
        Save the trained model to disk.

        Parameters
        ----------
        output_dir: str
            The output directory that stores the trained pytorch model and its configuration.
        '''
        if output_dir is not None:
            make_dir_if_not_exist(output_dir)
        else:
            output_dir = self.checkout_dir
        self._save_config(self.config, output_dir=output_dir)
        self._save_checkpoint({'model':self.model.state_dict()}, output_dir=output_dir)
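
    # Sketch of a checkpoint round trip (assumes './ckpt' is a writable
    # directory; the path is arbitrary):
    #
    #   model.save_model('./ckpt')
    #   model2 = RNN(vocab_size=[1000, 500], orders=['diag', 'med'], mode='binary')
    #   model2.load_model('./ckpt')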

    def get_test_dataloader(self, test_data):
        dataloader = DataLoader(test_data,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'],
            pin_memory=True,
            shuffle=False,
            collate_fn=SeqPatientCollator(
                config={
                    'visit_mode':test_data.metadata['visit']['mode'],
                    'label_mode':test_data.metadata['label']['mode'],
                }
            ),
        )
        return dataloader

    def get_train_dataloader(self, train_data):
        dataloader = DataLoader(train_data,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'],
            pin_memory=False,
            shuffle=True,
            collate_fn=SeqPatientCollator(
                config={
                    'visit_mode':train_data.metadata['visit']['mode'],
                    'label_mode':train_data.metadata['label']['mode'],
                }
            ),
        )
        return dataloader

    def _build_model(self):
        self.model = BuildModel(
            rnn_type=self.config['rnn_type'],
            emb_size=self.config['emb_size'],
            max_visit=self.config['max_visit'],
            n_rnn_layer=self.config['n_rnn_layer'],
            bidirectional=self.config['bidirectional'],
            vocab_size=self.config['vocab_size'],
            orders=self.config['orders'],
            output_dim=self.config['output_dim'],
        )
        self.model.to(self.device)

    @torch.no_grad()
    def _predict_on_dataloader(self, test_dataloader):
        pred_list, label_list = [], []
        for batch in test_dataloader:
            inputs = self._prepare_input(batch)
            logits = self.model(inputs)
            pred_list.append(logits)
            if 'y' in batch:
                label_list.append(batch.pop('y'))
        pred = torch.cat(pred_list, dim=0)
        if self.config['mode'] in ['binary','multilabel']:
            pred = torch.sigmoid(pred)
        if self.config['mode'] == 'multiclass':
            pred = torch.softmax(pred, dim=1)
        pred = pred.cpu().numpy()
        label = torch.cat(label_list) if len(label_list) > 0 else None
        return {'pred':pred, 'label':label}

    def _prepare_input(self, data):
        '''
        Prepare inputs for sequential patient record predictive models.

        Parameters
        ----------
        data: dict[list]
            A batch of patient records.
        '''
        visits = data['v']
        if 'x' in data:
            feature = data['x']
            if not isinstance(feature, torch.Tensor):
                feature = torch.tensor(feature)
            feature = feature.to(self.device)
        else:
            feature = None

        inputs = {
            'v':{},
            'v_lengths':[],
            'x':feature, # baseline feature
        }

        # number of visits kept per patient, truncated at max_visit
        v_lengths = [
            len(visits[self.config['orders'][0]][idx][:self.config['max_visit']])
            for idx in range(len(visits[self.config['orders'][0]]))
        ]
        inputs['v_lengths'] = v_lengths

        # padded multi-hot visit tensor: bs, max_num_visit, total_vocab_size
        v = torch.zeros(len(v_lengths), max(v_lengths), self.config['total_vocab_size'])
        for idx in range(len(v_lengths)):
            v[idx, :v_lengths[idx]] = torch.tensor(
                self._translate_dense_visits_to_sparse(
                    {k: visits[k][idx][:self.config['max_visit']] for k in visits}
                )
            )
        v = v.to(self.device)
        inputs['v'] = v

        if 'y' in data: # target labels
            target = data['y']
            if not isinstance(target, torch.Tensor):
                target = torch.tensor(target)
            target = target.to(self.device)
            inputs['y'] = target

        return inputs

    def _translate_dense_visits_to_sparse(self, visits):
        total_vocab_size = sum(self.config['vocab_size'])
        num_visits = len(visits[self.config['orders'][0]])
        outputs = np.zeros((num_visits, total_vocab_size))

        for i, o in enumerate(self.config['orders']):
            for j in range(num_visits):
                raw = visits[o][j]
                if isinstance(raw, torch.Tensor):
                    raw = raw.detach().cpu().numpy()
                if i > 0:
                    # offset codes of the i-th event type by the total size
                    # of all preceding vocabularies
                    voc_size = sum(self.config['vocab_size'][:i])
                    if isinstance(raw, list):
                        raw = [r + voc_size for r in raw]
                    else:
                        raw += voc_size
                outputs[j, raw] = 1

        return outputs

    def _fit_model(self, train_data, valid_data=None):
        test_metric_dict = {
            'binary': 'auc',
            'multiclass': 'acc',
            'regression': 'mse',
            'multilabel': 'f1', # take average of F1 scores
        }
        train_dataloader = self.get_train_dataloader(train_data)
        loss_models = self._build_loss_model()
        train_objectives = [(train_dataloader, loss_model) for loss_model in loss_models]
        if self.config['mode'] == 'regression':
            less_is_better = True
        else:
            less_is_better = False
        trainer = IndivSeqTrainer(model=self,
            train_objectives=train_objectives,
            test_data=valid_data,
            test_metric=test_metric_dict[self.config['mode']],
            less_is_better=less_is_better,
        )
        trainer.train(**self.config)

    def _build_loss_model(self):
        mode = self.config['mode']
        if mode == 'binary':
            return [BinaryXentLoss(self.model)]
        if mode == 'multiclass':
            return [XentLoss(self.model)]
        if mode == 'multilabel':
            return [MultilabelBinaryXentLoss(self.model)]
        if mode == 'regression':
            return [MSELoss(self.model)]
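

if __name__ == '__main__':
    # Smoke test for BuildModel alone (a minimal sketch, assuming a working
    # pytrial installation; all sizes below are arbitrary).
    vocab_size = [10, 8]
    net = BuildModel(
        rnn_type='lstm',
        emb_size=16,
        bidirectional=False,
        vocab_size=vocab_size,
        orders=['diag', 'med'],
        n_rnn_layer=1,
        output_dim=1,
    )
    # batch of 4 patients, up to 5 visits each, multi-hot encoded visits
    v = torch.randint(0, 2, (4, 5, sum(vocab_size))).float()
    v_lengths = [5, 4, 3, 2]  # true number of visits per patient
    logits = net({'v': v, 'v_lengths': v_lengths})
    print(logits.shape)  # expected: torch.Size([4, 1])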