# Source code for pytrial.tasks.trial_simulation.sequence.eva

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

import numpy as np
from collections import defaultdict

from pytrial.data.patient_data import SequencePatientBase, SeqPatientCollator
from pytrial.utils.check import (
    check_checkpoint_file, check_model_dir, check_model_config_file, make_dir_if_not_exist
)
from .base import SequenceSimulationBase
from ..losses import MultilabelBinaryXentLossWithKLDivergence
from ..trainer import SeqSimEvaTrainer
from ..data import SequencePatient

class CausalConv1d(nn.Module):
    """1-D convolution whose output at step ``t`` only depends on inputs at steps ``<= t``.

    ``nn.Conv1d`` is built with ``padding=(kernel_size - 1) * dilation`` (zero padding
    on both ends), and ``forward`` drops that many trailing outputs, so only the
    left padding takes effect and the output length equals the input length.

    Parameters
    ----------
    in_channels, out_channels, kernel_size, dilation:
        Forwarded to ``nn.Conv1d``.
    **kwargs:
        Extra keyword arguments forwarded to ``nn.Conv1d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super(CausalConv1d, self).__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.pad, dilation=dilation, **kwargs)

    def forward(self, input):
        out = self.conv(input)
        pad = self.conv.padding[0]
        # ``out[..., :-0]`` would be an empty slice, so a kernel of size 1
        # (no padding) must skip the trim instead of returning zero steps.
        if pad == 0:
            return out
        return out[:, :, :-pad]

def connector(mu, log_var):
    """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterization trick).

    ``log_var`` is a log-variance, so the standard deviation is ``exp(log_var / 2)``.
    The noise is drawn with ``torch.randn_like`` and matches ``log_var``'s shape,
    dtype and device.
    """
    sigma = (0.5 * log_var).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

class Encoder(nn.Module):
    """Bidirectional-LSTM encoder mapping padded multi-hot visits to latent statistics.

    Parameters
    ----------
    emb_size: int
        Size of the visit embedding and of each LSTM direction's hidden state.
    latent_dim: int
        Latent size; ``forward`` returns ``2 * latent_dim`` columns that the caller
        splits into mean and log-variance.
    n_rnn_layer: int
        Number of stacked LSTM layers.
    total_vocab_size: int
        Width of the last dimension of the visit tensor passed to ``forward``.
    """

    def __init__(self, emb_size, latent_dim, n_rnn_layer, total_vocab_size):
        super(Encoder, self).__init__()
        self.hidden_dim = emb_size
        self.embedding_matrix = nn.Linear(total_vocab_size, emb_size, bias=False)
        self.lstm = nn.LSTM(
            input_size=emb_size,
            hidden_size=emb_size,
            num_layers=n_rnn_layer,
            bidirectional=True,
            batch_first=True,
        )
        self.latent_encoder = nn.Linear(2 * emb_size, 2 * latent_dim)

    def forward(self, input, lengths):
        """Encode ``input`` (batch, seq, total_vocab_size) using per-row valid ``lengths``."""
        embedded = self.embedding_matrix(input)
        packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        padded_out, _ = pad_packed_sequence(packed_out, batch_first=True)
        rows = range(len(padded_out))
        last_steps = [n - 1 for n in lengths]
        # Forward direction: its summary is at each row's last valid step.
        fwd_state = padded_out[rows, last_steps, :self.hidden_dim]
        # Reverse direction: it finishes at step 0.
        bwd_state = padded_out[:, 0, self.hidden_dim:]
        return self.latent_encoder(torch.cat((fwd_state, bwd_state), 1))

class Decoder(nn.Module):
    """Decode a latent vector into per-visit code logits.

    A stack of ``ConvTranspose1d(kernel=3, stride=3)`` layers grows a length-1
    sequence, tripling its length at each layer. The result is cut to ``max_visit``
    steps and passed through three dilated causal convolutions, the last of which
    outputs ``total_vocab_size`` channels.

    Parameters
    ----------
    emb_size: int
        Channel width of the intermediate layers.
    latent_dim: int
        Size of the latent vector fed to ``forward``.
    total_vocab_size: int
        Number of output channels (one logit per code).
    max_visit: int
        Maximum number of visit steps kept from the upsampled sequence.
    """

    def __init__(self, emb_size, latent_dim, total_vocab_size, max_visit):
        super(Decoder, self).__init__()
        self.max_visit = max_visit
        # ceil(max_visit ** (1/3)) upsampling layers give length 3 ** n_deconv.
        # NOTE(review): the cut below assumes this is >= max_visit.
        n_deconv = int(np.ceil(np.power(max_visit, 1/3)))
        layers = []
        for i in range(n_deconv):
            in_channels = latent_dim if i == 0 else emb_size
            layers.append(nn.ConvTranspose1d(in_channels, emb_size, 3, stride=3))
        self.deconv_list = nn.ModuleList(layers)
        self.causal_conv1 = CausalConv1d(emb_size, emb_size, 5, dilation=2)
        self.causal_conv2 = CausalConv1d(emb_size, 2*emb_size, 5, dilation=2)
        self.causal_conv3 = CausalConv1d(2*emb_size, total_vocab_size, 5, dilation=2)

    def forward(self, input):
        """Map ``input`` (batch, latent_dim) to logits (batch, steps, total_vocab_size), steps <= max_visit."""
        hidden = input.unsqueeze(2)  # (batch, latent_dim, 1) for the 1-D convolutions
        for layer in self.deconv_list:
            hidden = layer(hidden)
        hidden = hidden[:, :, :self.max_visit]
        for conv in (self.causal_conv1, self.causal_conv2, self.causal_conv3):
            hidden = conv(hidden)
        return hidden.transpose(1, 2)
      
class BuildModel(nn.Module):
    """EVA network: an ``Encoder`` producing latent statistics and a ``Decoder`` producing code logits.

    Parameters
    ----------
    max_visit: int
        Number of visit steps the decoder can produce.
    emb_size: int
        Hidden/embedding size shared by encoder and decoder.
    latent_dim: int
        Size of the latent code.
    n_rnn_layer: int
        Number of LSTM layers in the encoder.
    total_vocab_size: int
        Width of the multi-hot visit vectors.
    device: str or torch.device
        Device onto which prior samples are moved in ``sample``.
    **kwargs:
        Accepted and ignored (e.g. ``orders`` passed by ``EVA._build_model``).
    """

    def __init__(self,
        max_visit,
        emb_size,
        latent_dim,
        n_rnn_layer,
        total_vocab_size,
        device,
        **kwargs,
        ) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        self.device = device
        self.encoder_module = Encoder(
            emb_size=emb_size,
            latent_dim=latent_dim,
            n_rnn_layer=n_rnn_layer,
            total_vocab_size=total_vocab_size,
        )
        self.decoder_module = Decoder(
            emb_size=emb_size,
            latent_dim=latent_dim,
            total_vocab_size=total_vocab_size,
            max_visit=max_visit,
        )

    def _split_stats(self, mean_logvar):
        """Split the encoder output into ``(mu, log_var)`` along the feature axis."""
        return mean_logvar[:, :self.latent_dim], mean_logvar[:, self.latent_dim:]

    def forward(self, inputs):
        """Reconstruct ``inputs['v']``.

        Returns ``(logits, target, kl_loss)``. The logits are truncated to the
        input's sequence length and zeroed past each row's ``inputs['v_lengths']``
        entry. The target is ``inputs['v']``. ``kl_loss`` is the batch-mean KL
        divergence from ``N(0, I)``.
        """
        visits = inputs['v']
        lengths = inputs['v_lengths']
        mu, log_var = self._split_stats(self.encoder_module(visits, lengths))
        kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
        code_logits = self.decoder_module(connector(mu, log_var))
        # Padded visit slots get a constant logit of 0.
        for row, length in enumerate(lengths):
            code_logits[row, length:] = 0
        return code_logits[:, :visits.size(1)], visits, kl_loss

    def sample(self, n_samples):
        """Decode ``n_samples`` prior draws ``z ~ N(0, I)`` and Bernoulli-sample binary records."""
        prior_z = torch.randn((n_samples, self.latent_dim)).to(self.device)
        probs = torch.sigmoid(self.decoder_module(prior_z))
        return torch.bernoulli(probs)

    def sample_from(self, x, input_lengths, n_per_sample):
        """Draw ``n_per_sample`` posterior reconstructions of ``x``.

        The draws are stacked along dim 0: all rows of draw 1, then all rows of
        draw 2, and so on.
        """
        mu, log_var = self._split_stats(self.encoder_module(x, input_lengths))
        probs = torch.cat([
            torch.sigmoid(self.decoder_module(connector(mu, log_var)))
            for _ in range(n_per_sample)
        ])
        return torch.bernoulli(probs)

class EVA(SequenceSimulationBase):
    '''
    Implement a VAE based model for longitudinal patient records simulation [1]_.

    Parameters
    ----------
    vocab_size: list[int]
        A list of vocabulary size for different types of events, e.g., for diagnosis, procedure, medication.

    order: list[str]
        The order of event types in each visits, e.g., ``['diag', 'prod', 'med']``.
        Visit = [diag_events, prod_events, med_events], each event is a list of codes.

    max_visit: int
        Maximum number of visits.

    emb_size: int
        Embedding size for encoding input event codes.

    latent_dim: int
        Size of final latent dimension between the encoder and decoder

    n_rnn_layer: int
        Number of RNN layers for encoding historical events.

    learning_rate: float
        Learning rate for optimization based on SGD. Use torch.optim.Adam by default.

    batch_size: int
        Batch size when doing SGD optimization.

    epochs: int
        Maximum number of iterations taken for the solvers to converge.

    num_worker: int
        Number of workers used to do dataloading during training.

    device: str
        Device the model is moved to, e.g. ``'cpu'`` or ``'cuda:0'``.

    experiment_id: str
        Identifier forwarded to ``SequenceSimulationBase``.

    Notes
    -----
    .. [1] Biswal, S., et al. (2020, December). EVA: Generating Longitudinal Electronic Health Records Using Conditional Variational Autoencoders.
    '''
    def __init__(self,
        vocab_size,
        order,
        max_visit=20,
        emb_size=64,
        latent_dim=32,
        n_rnn_layer=2,
        learning_rate=1e-3,
        batch_size=64,
        epochs=20,
        num_worker=0,
        device='cpu',# 'cuda:0',
        experiment_id='trial_simulation.sequence.eva',
        ):
        super().__init__(experiment_id)
        self.config = {
            'vocab_size':vocab_size,
            'max_visit':max_visit,
            'emb_size':emb_size,
            'latent_dim':latent_dim,
            'n_rnn_layer':n_rnn_layer,
            'device':device,
            'learning_rate':learning_rate,
            'batch_size':batch_size,
            'epochs':epochs,
            'num_worker':num_worker,
            'orders':order,
        }
        self.config['total_vocab_size'] = sum(vocab_size)
        self.device = device
        self._build_model()

    def _build_model(self):
        self.model = BuildModel(
            max_visit=self.config['max_visit'],
            emb_size=self.config['emb_size'],
            latent_dim=self.config['latent_dim'],
            n_rnn_layer=self.config['n_rnn_layer'],
            total_vocab_size=self.config['total_vocab_size'],
            orders=self.config['orders'],
            device=self.device,
        )
        self.model = self.model.to(self.device)

    def fit(self, train_data):
        '''
        Train model with sequential patient records.

        Parameters
        ----------
        train_data: SequencePatientBase
            A `SequencePatientBase` contains patient records where 'v' corresponds to
            visit sequence of different events.
        '''
        self._input_data_check(train_data)
        self._fit_model(train_data)

    def predict(self, n, return_tensor=False):
        '''
        Generate synthetic records

        Parameters
        ----------
        n: int
            How many samples in total will be generated.

        return_tensor: bool
            - If ``True``, return output generated records in tensor format (n, n_visit, n_event), good for later predictive modeling.
            - If ``False``, return records in ``SequencePatient`` format.
        '''
        assert isinstance(n, int), 'Input `n` should be integer.'
        outputs = self._predict(n)
        if not return_tensor:
            outputs = self._translate_sparse_visits_to_dense(outputs)
        return outputs

    def save_model(self, output_dir):
        '''
        Save the learned simulation model to the disk.

        Parameters
        ----------
        output_dir: str or None
            The dir to save the learned model. If set None, will save model to `self.checkout_dir`.
        '''
        if output_dir is not None:
            make_dir_if_not_exist(output_dir)
        else:
            output_dir = self.checkout_dir

        self._save_config(self.config, output_dir=output_dir)
        self._save_checkpoint({
            'encoder': self.model.encoder_module.state_dict(),
            'decoder': self.model.decoder_module.state_dict()
        }, output_dir=output_dir)

    def load_model(self, checkpoint):
        '''
        Load the encoder/decoder weights (and the saved config, if one is found) from the given checkpoint.

        Parameters
        ----------
        checkpoint: str
            The input dir that stores the pretrained model.
            - If a directory, the only checkpoint file `*.pth.tar` will be loaded.
            - If a filepath, will load from this file.
        '''
        checkpoint_filename = check_checkpoint_file(checkpoint)
        config_filename = check_model_config_file(checkpoint)
        state_dict = torch.load(checkpoint_filename, map_location=self.config['device'])
        if config_filename is not None:
            config = self._load_config(config_filename)
            self.config.update(config)
        self.model.encoder_module.load_state_dict(state_dict['encoder'])
        self.model.decoder_module.load_state_dict(state_dict['decoder'])

    def get_train_dataloader(self, train_data):
        dataloader = DataLoader(train_data,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'],
            pin_memory=True,
            shuffle=True,
            collate_fn=SeqPatientCollator(
                config={
                    'visit_mode':train_data.metadata['visit']['mode'],
                    'label_mode':train_data.metadata['label']['mode'],
                }
            ),
        )
        return dataloader

    def get_test_dataloader(self, train_data):
        dataloader = DataLoader(train_data,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'],
            pin_memory=False,
            shuffle=False,
            collate_fn=SeqPatientCollator(
                config={
                    'visit_mode':train_data.metadata['visit']['mode'],
                    'label_mode':train_data.metadata['label']['mode'],
                }
            ),
        )
        return dataloader

    def _build_loss_model(self):
        return MultilabelBinaryXentLossWithKLDivergence(self.model)

    def _fit_model(self, train_data):
        train_dataloader = self.get_train_dataloader(train_data)
        loss_model = self._build_loss_model()
        train_objectives = [(train_dataloader, loss_model)]
        trainer = SeqSimEvaTrainer(
            model=self,
            train_objectives=train_objectives
        )
        trainer.train(**self.config)

    @torch.no_grad()
    def _predict(self, n):
        return self.model.sample(n)

    @torch.no_grad()
    def _predict_on_dataloader(self, test_dataloader, n, n_per_sample):
        '''
        Draw ``n_per_sample`` reconstructions per patient in ``test_dataloader``
        and keep the first ``n`` rows.
        '''
        outputs = []
        for data in test_dataloader:
            # The encoder needs a padded multi-hot tensor plus per-row lengths,
            # which is what ``_prepare_input`` builds. Batches that already carry
            # ``v_lengths`` are used as-is.
            if 'v_lengths' not in data:
                data = self._prepare_input(data)
            outputs.append(self.model.sample_from(data['v'], data['v_lengths'], n_per_sample))
        return torch.cat(outputs)[:n]

    def _translate_sparse_visits_to_dense(self, visits):
        '''
        Convert generated multi-hot records back into a dense ``SequencePatient``.

        Each element of ``visits`` is presumably one patient's tensor of shape
        (n_visit, total_vocab_size). It is sliced into per-event-type column ranges
        in ``self.config['orders']`` order, and each visit row becomes the list of
        its non-zero code indices within that slice.
        '''
        def _map_func(x):
            res = np.where(x > 0)[0].tolist()
            return [0] if len(res) == 0 else res # pad if nothing happened

        orders = self.config['orders']
        outputs = defaultdict(list)
        for batchv in visits:
            voc_offset = 0
            for i, o in enumerate(orders):
                voc_size = self.config['vocab_size'][i]
                visit = batchv[..., voc_offset:voc_offset+voc_size]
                visit = visit.cpu().numpy()
                res = list(map(_map_func, visit))
                outputs[o].append(res)
                voc_offset += voc_size

        # Every event type holds the same patients and visits, so the first order's
        # lists give the counts. This avoids relying on a loop variable leaking out
        # of the loop above, which is unbound when ``visits`` is empty.
        first = outputs[orders[0]]
        sample_list = []
        for i in range(len(first)):
            sample = []
            for numv in range(len(first[i])):
                visit = [outputs[o][i][numv] for o in orders]
                sample.append(visit)
            sample_list.append(sample)

        # create seqpatient data
        return SequencePatient(
            data={'v':sample_list},
            metadata={
                'visit':{'mode':'dense','order':self.config['orders']},
            }
        )

    def _translate_dense_visits_to_sparse(self, visits):
        '''
        Encode one patient's dense visits into a (num_visits, total_vocab_size) multi-hot array.

        ``visits`` maps each event type in ``self.config['orders']`` to a per-visit
        sequence of codes. Codes of the i-th event type are shifted by
        ``sum(vocab_size[:i])``, matching the column layout that
        ``_translate_sparse_visits_to_dense`` decodes.
        '''
        total_vocab_size = sum(self.config['vocab_size'])
        num_visits = len(visits[self.config['orders'][0]])
        outputs = np.zeros((num_visits, total_vocab_size))
        voc_offset = 0
        for i, o in enumerate(self.config['orders']):
            for j in range(num_visits):
                raw = visits[o][j]
                if isinstance(raw, torch.Tensor):
                    raw = raw.detach().cpu().numpy()
                if voc_offset:
                    if isinstance(raw, list):
                        raw = [r + voc_offset for r in raw]
                    else:
                        # Not ``raw += ...``: that would mutate the caller's array
                        # in place and shift the dataset's codes on every call.
                        raw = raw + voc_offset
                outputs[j, raw] = 1
            voc_offset += self.config['vocab_size'][i]
        return outputs

    def _pad_multiple_tensor_visits(self, visits):
        new_list = []
        for v in visits:
            new_list.extend([torch.tensor(x).squeeze(0) for x in np.array_split(v, len(v))])
        return pad_sequence(new_list, batch_first=True)

    def _prepare_input(self, data):
        '''
        Prepare inputs for sequence simulation models.

        Parameters
        ----------
        data: dict[list]
            A batch of patient records.
        '''
        visits = data['v']
        feature = data['x']
        if not isinstance(feature, torch.Tensor): feature = torch.tensor(feature)
        feature = feature.to(self.device)

        inputs = {
            'v':{},
            'v_lengths':[],
            'x':feature, # baseline feature
        }

        v_lengths = [len(visits[self.config['orders'][0]][idx][:self.config['max_visit']]) for idx in range(len(visits[self.config['orders'][0]]))]
        inputs['v_lengths'] = v_lengths

        v = torch.zeros(len(v_lengths), max(v_lengths), self.config['total_vocab_size'])
        for idx in range(len(v_lengths)):
            v[idx,:v_lengths[idx]] = torch.tensor(self._translate_dense_visits_to_sparse({k: visits[k][idx][:self.config['max_visit']] for k in visits}))
        v = v.to(self.device)
        inputs['v'] = v
        return inputs

    def _input_data_check(self, inputs):
        assert isinstance(inputs, SequencePatientBase), f'`trial_simulation.sequence` models require input training data in `SequencePatientBase`, find {type(inputs)} instead.'