Source code for pytrial.tasks.trial_simulation.sequence.rnn_gan

from collections import defaultdict
import pdb
import time

import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from pytrial.data.patient_data import SequencePatientBase, SeqPatientCollator
from pytrial.utils.check import (
    check_checkpoint_file, check_model_dir, check_model_config_file, make_dir_if_not_exist
)
from .base import SequenceSimulationBase
from .base import InputEventEmbedding, GAN, RNN
from ..losses import GeneratorLoss, DiscriminatorLoss, DiscriminatorLossGP
from ..trainer import SeqSimGANTrainer
from ..data import SequencePatient

class BuildModel(nn.Module):
    def __init__(self,
        rnn_type,
        emb_size,
        bidirectional,
        vocab_size,
        orders,
        n_rnn_layer,
        padding_idx,
        **kwargs,
        ) -> None:
        super().__init__()
        if not isinstance(vocab_size, list): vocab_size = [vocab_size]

        # generator dim size
        # emb_size*2, gendim[0], gendim[2]: default
        # emb_size*4, gendim[0], gendim[1]: directional
        if bidirectional:
            gan_input_emb_size = 2*emb_size
            gen_dims=[gan_input_emb_size, 2*emb_size]
        else:
            gan_input_emb_size = emb_size
            gen_dims=[2*emb_size, 2*emb_size]

        # discriminator dim size
        dis_dims=[2*emb_size]

        if bidirectional:
            gen_dims = [g*2 for g in gen_dims]

        self.gan_module = GAN(
            emb_size=gan_input_emb_size,
            total_vocab_size=sum(vocab_size),
            gen_dims=gen_dims,
            dis_dims=dis_dims,
        )

        self.rnn_module = RNN(
            rnn_type=rnn_type,
            emb_size=emb_size,
            num_layer=n_rnn_layer,
            bidirectional=bidirectional,
        )

        self.vocab_size = vocab_size
        self.orders = orders

        self.embeddings = InputEventEmbedding(orders=orders, vocab_size=vocab_size, emb_size=emb_size, padding_idx=padding_idx)

    def forward(self, inputs, n=1):
        '''
        inputs is a dict
        {
            'v': {'eventA':[],'eventB':[],...,},
            'x': tensor(),
            'y': {'eventA:[], 'eventB':[],...,} # optional
        }
        '''
        embs = self.embeddings(inputs) # [num_visit, num_event, emb_size]
        embs = torch.sum(embs, 1) # [num_visit, emb_size]
        embs = self.rnn_module(embs) # [num_visit, emb_size] or [num_visit, emb_size*2] (bidirectional)
        last_visit_emb = embs[-1] # [emb_size]

        if len(last_visit_emb.shape) == 1:
            last_visit_emb = last_visit_emb.unsqueeze(0)

        if n > 1:
            # generate more than one visits
            last_visit_emb = last_visit_emb.expand(n, -1)

        # infer generator
        z_random = torch.randn(last_visit_emb.size()).to(last_visit_emb.device)
        x_fake = self.gan_module.infer_generator(z_random, last_visit_emb) # 1, vocab_size
        
        # infer discriminator for real records
        if 'y' in inputs:
            # infer discriminator for fake records
            y_fake = self.gan_module.infer_discriminator(x_fake)
            target = self._create_multilabel_target(inputs)
            target = target.to(x_fake.device)
            y_real = self.gan_module.infer_discriminator(target.float())
            return {'x_fake':x_fake, 'y_real':y_real, 'y_fake':y_fake, 'y':target}
        else:
            return {'x_fake':x_fake}
    
    def infer_discriminator(self, x):
        return self.gan_module.infer_discriminator(x)

    def _create_multilabel_target(self, inputs):
        # for one patient one visit only `inputs`
        target = inputs['y']
        target_mo_list = []
        for i,o in enumerate(self.orders):
            target_mo = torch.zeros(self.vocab_size[i], dtype=torch.long)
            target_mo[target[o]] = 1
            target_mo_list.append(target_mo)
        targets = torch.cat(target_mo_list)
        if len(targets.shape) == 1: targets = targets.unsqueeze(0)
        return targets


[docs]class RNNGAN(SequenceSimulationBase): ''' Implement an RNN based GAN model for longitudinal patient records simulation. The GAN part was proposed by Choi et al. [1]_. Parameters ---------- vocab_size: list[int] A list of vocabulary size for different types of events, e.g., for diagnosis, procedure, medication. order: list[str] The order of event types in each visits, e.g., ``['diag', 'prod', 'med']``. Visit = [diag_events, prod_events, med_events], each event is a list of codes. max_visit: int The maximum number of visits for input event codes. emb_size: int Embedding size for encoding input event codes. n_rnn_layer: int Number of RNN layers for encoding historical events. rnn_type: str Pick RNN types in ['rnn','lstm','gru'] bidirectional: bool If True, it encodes historical events in bi-directional manner. padding_idx: int(default=None) Set the padding index for input events embedding. If set None, then no padding index will be specified. learning_rate: float Learning rate for optimization based on SGD. Use torch.optim.Adam by default. weigth_decay: float Regularization strength for l2 norm; must be a positive float. Smaller values specify weaker regularization. batch_size: int Batch size when doing SGD optimization. epochs: int Maximum number of iterations taken for the solvers to converge. num_worker: int Number of workers used to do dataloading during training. device: str The model device. Notes ----- .. [1] Choi, E., et al. (2017, November). Generating multi-label discrete patient records using generative adversarial networks. In ML4HC (pp. 286-305). PMLR. ''' def __init__(self, vocab_size, order, max_visit=20, emb_size=64, n_rnn_layer=2, rnn_type='lstm', bidirectional=False, padding_idx=None, learning_rate=1e-4, weight_decay=1e-4, batch_size=64, epochs=10, num_worker=0, device='cuda:0', experiment_id='trial_simulation.sequence.rnn_gan', ): super().__init__(experiment_id) self.config = { 'vocab_size':vocab_size, 'max_visit':max_visit, 'emb_size':emb_size, 'n_rnn_layer':n_rnn_layer, 'rnn_type':rnn_type, 'bidirectional':bidirectional, 'padding_idx':padding_idx, 'device':device, 'learning_rate':learning_rate, 'batch_size':batch_size, 'weight_decay':weight_decay, 'epochs':epochs, 'num_worker':num_worker, 'orders':order, } self.config['total_vocab_size'] = sum(vocab_size) self.device = device self._build_model()
[docs] def fit(self, train_data): ''' Train model with sequential patient records. Parameters ---------- train_data: SequencePatientBase A `SequencePatientBase` contains patient records where 'v' corresponds to visit sequence of different events. ''' self._input_data_check(train_data) self._fit_model(train_data)
[docs] def predict(self, test_data, n=None, n_per_sample=None, return_tensor=True): ''' Generate synthetic records based on input real patient seq data. Parameters ---------- test_data: SequencePatientBase A `SequencePatientBase` contains patient records where 'v' corresponds to visit sequence of different events. n: int How many samples in total will be generated. n_per_sample: int How many samples generated based on each indivudals. return_tensor: bool If `True`, return output generated records in tensor format (n, n_visit, n_event), good for later predictive modeling. If `False, return records in `SequencePatient` format. ''' if n is not None: assert isinstance(n, int), 'Input `n` should be integer.' if n_per_sample is not None: assert isinstance(n_per_sample, int), 'Input `n_per_sample` should be integer.' n, n_per_sample = self._compute_n_per_sample(len(test_data), n, n_per_sample) test_dataloader = self.get_test_dataloader(test_data) outputs = self._predict_on_dataloader(test_dataloader, n, n_per_sample) if not return_tensor: outputs = self._translate_sparse_visits_to_dense(outputs) else: # pad all to same shape outputs = self._pad_multiple_tensor_visits(outputs) return outputs
[docs] def save_model(self, output_dir): ''' Save the learned simulation model to the disk. Parameters ---------- output_dir: str or None The dir to save the learned model. If set None, will save model to `self.checkout_dir`. ''' if output_dir is not None: make_dir_if_not_exist(output_dir) else: output_dir = self.checkout_dir self._save_config(self.config, output_dir=output_dir) self._save_checkpoint({'model':self.model}, output_dir=output_dir)
[docs] def load_model(self, checkpoint): ''' Load model and the pre-encoded trial embeddings from the given checkpoint dir. Parameters ---------- checkpoint: str The input dir that stores the pretrained model. - If a directory, the only checkpoint file `*.pth.tar` will be loaded. - If a filepath, will load from this file. ''' checkpoint_filename = check_checkpoint_file(checkpoint) config_filename = check_model_config_file(checkpoint) state_dict = torch.load(checkpoint_filename) if config_filename is not None: config = self._load_config(config_filename) self.config.update(config) self.model = state_dict['model']
def get_train_dataloader(self, train_data): dataloader = DataLoader(train_data, batch_size=self.config['batch_size'], num_workers=self.config['num_worker'], pin_memory=True, shuffle=True, collate_fn=SeqPatientCollator( config={ 'visit_mode':train_data.metadata['visit']['mode'], 'label_mode':train_data.metadata['label']['mode'], } ), ) return dataloader def get_test_dataloader(self, train_data): dataloader = DataLoader(train_data, batch_size=self.config['batch_size'], num_workers=self.config['num_worker'], pin_memory=False, shuffle=False, collate_fn=SeqPatientCollator( config={ 'visit_mode':train_data.metadata['visit']['mode'], 'label_mode':train_data.metadata['label']['mode'], } ), ) return dataloader def _build_model(self): self.model = BuildModel( rnn_type=self.config['rnn_type'], emb_size=self.config['emb_size'], max_visit=self.config['max_visit'], n_rnn_layer=self.config['n_rnn_layer'], bidirectional=self.config['bidirectional'], vocab_size=self.config['vocab_size'], orders=self.config['orders'], padding_idx=self.config['padding_idx'], ) self.model.to(self.device) def _build_loss_model(self): # update discriminator two times v.s. generator one time return [DiscriminatorLoss(self.model), GeneratorLoss(self.model)] def _fit_model(self, train_data): train_dataloader = self.get_train_dataloader(train_data) loss_models = self._build_loss_model() train_objectives = [(train_dataloader, loss_model) for loss_model in loss_models] trainer = SeqSimGANTrainer( model=self, train_objectives=train_objectives ) trainer.train(**self.config) @torch.no_grad() def _predict_on_dataloader(self, test_dataloader, n, n_per_sample): data_iterator = iter(test_dataloader) total_number = 0 fake_visit_list = [] while total_number < n: try: data = next(data_iterator) except: data_iterator = iter(test_dataloader) data = next(data_iterator) for idx, _ in enumerate(data['x']): num_visit = self._get_num_visit(data, idx) if num_visit < 2: # deal with more than one visit only continue fake_visits = [] for vdx in range(1, num_visit): inputs = self._prepare_input(data, idx, vdx) if 'y' in inputs: inputs.pop('y') x_fake = self.model(inputs, n=n_per_sample)['x_fake'] fake_visits.append(x_fake.cpu().numpy()) fake_visits = np.stack(fake_visits, 1) # n_sample, n_visit, total_vocab_size fake_visits[fake_visits>0.5]=1 fake_visits[fake_visits<=0.5]=0 outputs = self._prepare_input(data, idx, 1) first_visit = self._translate_dense_visits_to_sparse(outputs['v']) first_visit = np.tile(first_visit[None],(len(fake_visits),1,1)) fake_visits = np.concatenate([first_visit, fake_visits], 1) fake_visit_list.append(fake_visits) # add one synthetic record total_number += len(fake_visits) if total_number >= n: break return fake_visit_list def _translate_sparse_visits_to_dense(self, visits): def _map_func(x): res = np.where(x > 0)[0].tolist() return [0] if len(res) == 0 else res # pad if nothing happened outputs = defaultdict(list) for batchv in visits: voc_offset = 0 for i, o in enumerate(self.config['orders']): voc_size = self.config['vocab_size'][i] visit = batchv[...,voc_offset:voc_offset+voc_size] # 10, 7, 5 for visit_ in visit: res = list(map(_map_func, visit_)) outputs[o].append(res) voc_offset += voc_size n_total = len(outputs[o]) sample_list = [] for i in range(n_total): sample = [] for numv in range(len(outputs[o][i])): visit = [] for o in self.config['orders']: visit.append(outputs[o][i][numv]) sample.append(visit) sample_list.append(sample) # create seqpatient data return SequencePatient( data={'v':sample_list}, metadata={ 'visit':{'mode':'dense','order':self.config['orders']}, } ) def _translate_dense_visits_to_sparse(self, visits): total_vocab_size = sum(self.config['vocab_size']) num_visits = len(visits[self.config['orders'][0]]) outputs = np.zeros((num_visits, total_vocab_size)) for i, o in enumerate(self.config['orders']): for j in range(num_visits): raw = visits[o][j] if isinstance(raw, torch.Tensor): raw = raw.detach().cpu().numpy() if i > 0: voc_size = sum(self.config['vocab_size'][:i-1]) if isinstance(raw, list): raw = [r + voc_size for r in raw] else: raw += voc_size outputs[j, raw] = 1 return outputs def _pad_multiple_tensor_visits(self, visits): new_list = [] for v in visits: new_list.extend([torch.tensor(x).squeeze(0) for x in np.array_split(v, len(v))]) return pad_sequence(new_list, batch_first=True) def _input_data_check(self, inputs): assert isinstance(inputs, SequencePatientBase), f'`trial_simulation.sequence` models require input training data in `SequencePatientBase`, find {type(inputs)} instead.'