# Source code for pytrial.tasks.trial_simulation.sequence.synteg

import pdb
import math
from collections import defaultdict

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import numpy as np
from tqdm import trange

from pytrial.data.patient_data import SequencePatientBase, SeqPatientCollator
from pytrial.utils.check import (
    check_checkpoint_file, check_model_dir, check_model_config_file, make_dir_if_not_exist
)
from .base import SequenceSimulationBase
from ..losses import GeneratorLoss, DiscriminatorLossGP, MultilabelBinaryXentLoss
from ..trainer import SeqSimSyntegTrainer
from ..data import SequencePatient

class Embedding(nn.Module):
    def __init__(self, total_vocab_size, emb_size):
        """Construct an embedding matrix to embed sparse codes"""
        super(Embedding, self).__init__()
        self.code_embed = nn.Embedding(total_vocab_size+3, emb_size)

    def forward(self, codes): # batch_size * visits * codes
        code_embeds = self.code_embed(codes)
        return code_embeds

class SingleVisitTransformer(nn.Module):
    """An Encoder Transformer to turn code embeddings into a visit embedding"""
    def __init__(self, emb_size, n_head, hidden_dim):
        super(SingleVisitTransformer, self).__init__()
        encoderLayer = nn.TransformerEncoderLayer(emb_size, n_head, 
                        dim_feedforward=hidden_dim, dropout=0.1, activation="relu", 
                        layer_norm_eps=1e-08, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoderLayer, 2)

    def forward(self, code_embeddings, visit_lengths):
        bs, vs, cs, ed = code_embeddings.shape
        # key padding mask over the code dimension: 0 marks real codes, 1 marks padded slots
        mask = torch.ones((bs, vs, cs)).to(code_embeddings.device)
        for i in range(bs):
            for j in range(vs):
                mask[i,j,:visit_lengths[i,j]] = 0
        visits = torch.reshape(code_embeddings, (bs*vs,cs,ed))
        mask = torch.reshape(mask, (bs*vs,cs))
        encodings = self.transformer(visits, src_key_padding_mask=mask)
        encodings = torch.reshape(encodings, (bs,vs,cs,ed))
        # use the encoding at the start-token slot (position 0) as the visit representation
        visit_representations = encodings[:,:,0,:]
        return visit_representations

class RecurrentLayer(nn.Module):
    """An Recurrent Layer to predict the next visit based on the visit embeddings"""
    def __init__(self, hidden_dim, n_rnn_layer):
        super(RecurrentLayer, self).__init__()
        self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=n_rnn_layer, dropout=0.1, batch_first=True)

    def forward(self, visit_embeddings):   
        output, _ = self.lstm(visit_embeddings)
        return output

class DependencyModel(nn.Module):
    """The entire Dependency Model component of SynTEG"""
    def __init__(self, emb_size, hidden_dim, condition_dim, total_vocab_size, n_head, n_rnn_layer):
        super(DependencyModel, self).__init__()
        self.embeddings = Embedding(total_vocab_size, emb_size)
        self.visit_att = SingleVisitTransformer(emb_size, n_head, hidden_dim)
        self.proj1 = nn.Linear(emb_size, hidden_dim)
        self.lstm = RecurrentLayer(hidden_dim, n_rnn_layer)
        self.proj2 = nn.Linear(hidden_dim, condition_dim)
        self.proj3 = nn.Linear(condition_dim, total_vocab_size)
        
    def forward(self, inputs_word, visit_lengths, export=False):  # bs * visits * codes, bs * visits * 1 
        inputs = self.embeddings(inputs_word) # bs * visits * codes * emb_size
        inputs = self.visit_att(inputs, visit_lengths) # bs * visits * emb_size
        inputs = self.proj1(inputs) # bs * visits * hidden_dim
        output = self.lstm(inputs) # bs * visits * hidden_dim
        if export:
            return self.proj2(output)[:, :-1, :] # bs * visit * condition
        else:
            output = self.proj3(torch.relu(self.proj2(output))) # bs * visits * total_vocab_size
            diagnosis_output = output[:, :-1, :]
            return diagnosis_output
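
# Illustrative usage sketch of DependencyModel on toy inputs: it shows how padded code
# sequences map to next-visit logits (default path) and to condition vectors (export=True).
# The sizes and the helper name `_demo_dependency_model_shapes` are assumed for illustration;
# the start-of-visit token (index = total_vocab_size) and padding index (total_vocab_size + 1)
# follow the conventions used by BuildModel below.
def _demo_dependency_model_shapes():
    total_vocab_size, bs, n_visit, n_code = 10, 2, 3, 4
    model = DependencyModel(emb_size=16, hidden_dim=16, condition_dim=8,
                            total_vocab_size=total_vocab_size, n_head=2, n_rnn_layer=2)
    codes = torch.full((bs, n_visit, n_code + 1), total_vocab_size + 1, dtype=torch.long)  # padding index
    codes[:, :, 0] = total_vocab_size                                   # start-of-visit token
    codes[:, :, 1] = torch.randint(0, total_vocab_size, (bs, n_visit))  # one real code per visit
    lengths = torch.full((bs, n_visit, 1), 2, dtype=torch.long)         # start token + one code
    logits = model(codes, lengths)                  # (bs, n_visit - 1, total_vocab_size)
    condition = model(codes, lengths, export=True)  # (bs, n_visit - 1, condition_dim)
    return logits.shape, condition.shape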

#######
### Conditional GAN Model
#######

class PointWiseLayer(nn.Module):
    def __init__(self, num_outputs):
        """Construct an embedding matrix to embed sparse codes"""
        super(PointWiseLayer, self).__init__()
        self.bias = nn.Parameter(torch.zeros(num_outputs).uniform_(-math.sqrt(num_outputs), math.sqrt(num_outputs)))

    def forward(self, x1, x2):
        return x1 * x2 + self.bias

class Generator(nn.Module):
    def __init__(self, g_dims, z_dim, condition_dim):
        super(Generator, self).__init__()
        self.dense_layers = nn.Sequential(*[nn.Linear(g_dims[i-1] if i > 0 else z_dim, g_dims[i]) for i in range(len(g_dims[:-1]))])
        self.batch_norm_layers = nn.Sequential(*[nn.BatchNorm1d(dim, eps=1e-5) for dim in g_dims[:-1]])
        self.output_layer = nn.Linear(g_dims[-2], g_dims[-1])
        self.output_sigmoid = nn.Sigmoid()
        self.condition_layers = nn.Sequential(*[nn.Linear(condition_dim, dim) for dim in g_dims[:-1]])
        self.pointwiselayers = nn.Sequential(*[PointWiseLayer(dim) for dim in g_dims[:-1]])

    def forward(self, x, condition):
        for i in range(len(self.dense_layers)):
            h = self.dense_layers[i](x)
            x = nn.functional.relu(self.pointwiselayers[i](self.batch_norm_layers[i](h), self.condition_layers[i](condition)))
        x = self.output_layer(x)
        x = self.output_sigmoid(x)
        return x

class Discriminator(nn.Module):
    def __init__(self, d_dims, g_dims, condition_dim):
        super(Discriminator, self).__init__()
        # the discriminator input is the visit vector plus one appended sparsity feature, hence g_dims[-1] + 1
        self.dense_layers = nn.Sequential(*[nn.Linear(d_dims[i-1] if i > 0 else g_dims[-1] + 1, d_dims[i]) for i in range(len(d_dims))])
        self.layer_norm_layers = nn.Sequential(*[nn.LayerNorm(dim, eps=1e-5) for dim in d_dims])
        self.output_layer = nn.Linear(d_dims[-1], 1)
        self.condition_layers = nn.Sequential(*[nn.Linear(condition_dim, dim) for dim in d_dims])
        self.pointwiselayers = nn.Sequential(*[PointWiseLayer(dim) for dim in d_dims])

    def forward(self, x, condition):
        # soft code count: (2x)^15 / ((2x)^15 + 1) approximates the indicator x > 0.5,
        # so `sparsity` estimates how many codes are active in the visit vector
        a = (2 * x) ** 15
        sparsity = torch.sum(a / (a + 1), dim=-1, keepdim=True)
        x = torch.cat((x, sparsity), dim=-1)
        for i in range(len(self.dense_layers)):
            h = self.dense_layers[i](x)
            x = self.pointwiselayers[i](self.layer_norm_layers[i](h), self.condition_layers[i](condition))
        x = self.output_layer(x)
        return x
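
# Illustrative usage sketch of one conditional generator / discriminator pass with arbitrary
# toy dimensions. The helper name `_demo_conditional_gan_pass` and the values are assumptions;
# the last entry of g_dims plays the role of the multi-hot visit size, matching how SynTEG
# appends total_vocab_size to g_dims.
def _demo_conditional_gan_pass():
    z_dim, condition_dim, g_dims, d_dims, bs = 4, 8, [8, 8, 10], [8, 8], 6
    generator = Generator(g_dims=g_dims, z_dim=z_dim, condition_dim=condition_dim)
    discriminator = Discriminator(d_dims=d_dims, g_dims=g_dims, condition_dim=condition_dim)
    condition = torch.randn(bs, condition_dim)    # stands in for the dependency model's history encoding
    z = torch.randn(bs, z_dim)
    fake_visit = generator(z, condition)          # (bs, g_dims[-1]), values in (0, 1)
    score = discriminator(fake_visit, condition)  # (bs, 1) critic score
    return fake_visit.shape, score.shape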
      
class BuildModel(nn.Module):
    def __init__(self,
        emb_size,
        hidden_dim,
        condition_dim,
        n_head,
        n_rnn_layer,
        total_vocab_size,
        z_dim,
        g_dims,
        d_dims,
        device,
        max_visit,
        max_code_per_visit,
        **kwargs,
        ) -> None:
        super().__init__()
        
        self.device = device
        self.total_vocab_size = total_vocab_size
        self.max_visit = max_visit
        self.max_code_per_visit = max_code_per_visit
        self.z_dim = z_dim
        
        self.dependency_module = DependencyModel(
            emb_size=emb_size,
            hidden_dim=hidden_dim,
            condition_dim=condition_dim,
            total_vocab_size=total_vocab_size,
            n_head=n_head,
            n_rnn_layer=n_rnn_layer
        )

        self.generator_module = Generator(
            z_dim=z_dim,
            g_dims=g_dims,
            condition_dim=condition_dim
        )
        
        self.discriminator_module = Discriminator(
            d_dims=d_dims,
            g_dims=g_dims,
            condition_dim=condition_dim
        )
        
    def dependency_forward(self, inputs, export=False):
        logits = self.dependency_module(inputs['v'], inputs['c_lengths'], export)
        if export:
            return logits
        
        # multi-hot targets for each next visit (drop the first visit and the start/padding token dimensions)
        labels = torch.sum(nn.functional.one_hot(inputs['v'].long(), num_classes=self.total_vocab_size+2), dim=2).float()[:, 1:, :-2]
        for idx in range(len(inputs['v_lengths'])):
            logits[idx, inputs['v_lengths'][idx]:] = 0
        
        return logits, labels
    
    def gan_forward(self, inputs):
        # multi-hot encoding of each real visit (drop the start/padding token dimensions)
        x_real = torch.sum(nn.functional.one_hot(inputs['v'].long(), num_classes=self.total_vocab_size+2), dim=1).float()[:, :-2]
        z = torch.randn((len(x_real), self.z_dim)).to(self.device)
        x_fake = self.generator_module(z, inputs['conditions'])
        y_fake = self.discriminator_module(x_fake, inputs['conditions'])
        y_real = self.discriminator_module(x_real, inputs['conditions'])
        alpha = torch.rand((x_real.size(0), 1)).to(self.device)
        interpolates = (alpha * x_real + (1-alpha)*x_fake).requires_grad_(True).float()
        d_interpolates = self.discriminator_module(interpolates, inputs['conditions'])
        return {
            'y_fake': y_fake,
            'y_real': y_real,
            'interpolates': interpolates,
            'd_interpolates': d_interpolates
        }
    
    def forward(self, inputs, export=False):
        if 'conditions' in inputs:
            return self.gan_forward(inputs)
        else:
            return self.dependency_forward(inputs, export)
        
    def add_condition(self, inputs):
        condition = self.dependency_forward(inputs, True)
        visits = inputs['v'][:,1:]
        v_new = []
        condition_new = []
        for i in range(len(inputs['v_lengths'])):
            for j in range(inputs['v_lengths'][i]):
                v_new.append(visits[i, j, :])
                condition_new.append(condition[i,j,:])
        
        inputs['v'] = torch.stack(v_new)
        inputs['conditions'] = torch.stack(condition_new)
        return inputs
      
    def sample(self, n_samples):
        ehr = torch.zeros((n_samples, 1, self.total_vocab_size), device=self.device)
        batch_ehr = torch.ones((n_samples, self.max_visit+1, self.max_code_per_visit+1), device=self.device, dtype=torch.int) * (self.total_vocab_size + 1)
        batch_ehr[:,:,0] = self.total_vocab_size
        batch_lens = torch.zeros((n_samples, self.max_visit+1, 1), dtype=torch.int, device=self.device)
        batch_lens[:,0,0] = 1
        with torch.no_grad():
            for j in trange(self.max_visit):
                for i in range(n_samples):
                    codes = torch.nonzero(ehr[i,j]).squeeze(1)
                    n_codes = min(len(codes), self.max_code_per_visit)
                    # slot 0 keeps the start-of-visit token; sampled codes fill slots 1..n_codes
                    batch_ehr[i,j,1:n_codes+1] = codes[:n_codes]
                    batch_lens[i,j] = 1 + n_codes
                
                condition_vector = self.dependency_module(batch_ehr, batch_lens, export=True)
                condition = condition_vector[:,j,:]
                z = torch.randn((n_samples, self.z_dim), device=self.device)
                visit = self.generator_module(z, condition)
                visit = torch.bernoulli(visit).unsqueeze(1)
                ehr = torch.cat((ehr, visit), dim=1)
        return ehr[:,1:]
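
# Illustrative usage sketch of how the two training phases of SynTEG consume a batch. The toy
# tensors, sizes, and the helper name `_demo_two_phase_forward` are assumptions for illustration;
# the input layout ('v', 'c_lengths', 'v_lengths') mirrors what SynTEG._prepare_input builds.
def _demo_two_phase_forward():
    total_vocab_size, bs, n_visit, n_code = 10, 2, 4, 3
    model = BuildModel(
        emb_size=16, hidden_dim=16, condition_dim=8, n_head=2, n_rnn_layer=2,
        total_vocab_size=total_vocab_size, z_dim=4, g_dims=[8, 8, total_vocab_size],
        d_dims=[8, 8], device='cpu', max_visit=n_visit, max_code_per_visit=n_code,
    )
    v = torch.full((bs, n_visit + 1, n_code + 1), total_vocab_size + 1, dtype=torch.int)  # padding index
    v[:, :, 0] = total_vocab_size                                       # start-of-visit token
    v[:, :, 1] = torch.randint(0, total_vocab_size, (bs, n_visit + 1))  # one real code per visit
    c_lengths = torch.full((bs, n_visit + 1, 1), 2, dtype=torch.int)
    c_lengths[:, 0] = 1
    inputs = {'v': v, 'c_lengths': c_lengths, 'v_lengths': [n_visit, n_visit]}
    logits, labels = model(inputs)        # phase 1: predict next-visit codes from history
    inputs = model.add_condition(inputs)  # attach per-visit condition vectors from the dependency model
    gan_out = model(inputs)               # phase 2: conditional GAN forward (critic scores, interpolates)
    return logits.shape, labels.shape, gan_out['y_fake'].shape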
    

class SynTEG(SequenceSimulationBase):
    '''
    Implement a GAN based model for longitudinal patient records simulation [1]_.

    Parameters
    ----------
    vocab_size: list[int]
        A list of vocabulary sizes for different types of events, e.g., for diagnosis, procedure, medication.

    order: list[str]
        The order of event types in each visit, e.g., ``['diag', 'prod', 'med']``.
        Visit = [diag_events, prod_events, med_events], where each event is a list of codes.

    max_visit: int
        Maximum number of visits.

    max_code_per_visit: int
        Maximum number of medical codes in a single visit.

    emb_size: int
        Embedding size for encoding input event codes.

    hidden_dim: int
        Size of intermediate hidden dimension for RNN and feed-forward layers.

    condition_dim: int
        Size of intermediate dimension for encoding medical history to condition the GAN.

    n_head: int
        Number of attention heads.

    n_rnn_layer: int
        Number of RNN layers for encoding historical events.

    z_dim: int
        Dimension of the noise vector passed into the GAN generator.

    g_dims: list[int]
        Intermediate GAN generator dimensionalities.

    d_dims: list[int]
        Intermediate GAN discriminator dimensionalities.

    learning_rate: float
        Learning rate for optimization based on SGD. Uses ``torch.optim.Adam`` by default.

    batch_size: int
        Batch size when doing SGD optimization.

    epochs: int
        Maximum number of iterations taken for the solvers to converge.

    num_worker: int
        Number of workers used for dataloading during training.

    Notes
    -----
    .. [1] Zhang, Ziqi, et al. (2021, March). SynTEG: a framework for temporal structured electronic health data simulation. Journal of the American Medical Informatics Association 28.3.
    '''
    def __init__(self,
        vocab_size,
        order,
        max_visit=20,
        max_code_per_visit=20,
        emb_size=64,
        hidden_dim=32,
        condition_dim=32,
        n_head=4,
        n_rnn_layer=2,
        z_dim=64,
        g_dims=[32, 32, 64, 64],
        d_dims=[64, 32, 32],
        learning_rate=1e-3,
        batch_size=64,
        epochs=20,
        num_worker=0,
        device='cpu', # 'cuda:0',
        experiment_id='trial_simulation.sequence.synteg',
        ):
        super().__init__(experiment_id)
        self.config = {
            'vocab_size': vocab_size,
            'orders': order,
            'max_visit': max_visit,
            'max_code_per_visit': max_code_per_visit,
            'emb_size': emb_size,
            'hidden_dim': hidden_dim,
            'condition_dim': condition_dim,
            'n_head': n_head,
            'n_rnn_layer': n_rnn_layer,
            'z_dim': z_dim,
            'g_dims': list(g_dims), # copy so appending the total vocab size below does not mutate the caller's (or default) list
            'd_dims': list(d_dims),
            'learning_rate': learning_rate,
            'batch_size': batch_size,
            'epochs': epochs,
            'num_worker': num_worker,
            'device': device,
        }
        self.config['total_vocab_size'] = sum(vocab_size)
        self.config['g_dims'].append(self.config['total_vocab_size'])
        self.device = device
        self._build_model()

    def _build_model(self):
        self.model = BuildModel(
            emb_size=self.config['emb_size'],
            hidden_dim=self.config['hidden_dim'],
            condition_dim=self.config['condition_dim'],
            n_head=self.config['n_head'],
            n_rnn_layer=self.config['n_rnn_layer'],
            total_vocab_size=self.config['total_vocab_size'],
            z_dim=self.config['z_dim'],
            g_dims=self.config['g_dims'],
            d_dims=self.config['d_dims'],
            device=self.device,
            max_visit=self.config['max_visit'],
            max_code_per_visit=self.config['max_code_per_visit'],
        )
        self.model = self.model.to(self.device)
    def fit(self, train_data):
        '''
        Train model with sequential patient records.

        Parameters
        ----------
        train_data: SequencePatientBase
            A `SequencePatientBase` containing patient records, where 'v' corresponds to
            the visit sequences of different events.
        '''
        self._input_data_check(train_data)
        self._fit_model(train_data)
    def predict(self, n, return_tensor=True):
        '''
        Generate synthetic records.

        Parameters
        ----------
        n: int
            How many samples in total will be generated.

        return_tensor: bool
            If `True`, return the generated records as a tensor of shape (n, n_visit, n_event),
            which is convenient for later predictive modeling.
            If `False`, return records in `SequencePatient` format.
        '''
        assert isinstance(n, int), 'Input `n` should be integer.'
        outputs = self._predict(n)
        if not return_tensor:
            outputs = self._translate_sparse_visits_to_dense(outputs)
        return outputs
    def save_model(self, output_dir):
        '''
        Save the learned simulation model to the disk.

        Parameters
        ----------
        output_dir: str or None
            The dir to save the learned model.
            If set None, will save model to `self.checkout_dir`.
        '''
        if output_dir is not None:
            make_dir_if_not_exist(output_dir)
        else:
            output_dir = self.checkout_dir

        self._save_config(self.config, output_dir=output_dir)
        self._save_checkpoint({
            'dependency': self.model.dependency_module.state_dict(),
            'generator': self.model.generator_module.state_dict(),
            'discriminator': self.model.discriminator_module.state_dict()
            }, output_dir=output_dir)
    def load_model(self, checkpoint):
        '''
        Load model and the pre-encoded trial embeddings from the given checkpoint dir.

        Parameters
        ----------
        checkpoint: str
            The input dir that stores the pretrained model.

            - If a directory, the only checkpoint file `*.pth.tar` will be loaded.
            - If a filepath, will load from this file.
        '''
        checkpoint_filename = check_checkpoint_file(checkpoint)
        config_filename = check_model_config_file(checkpoint)
        state_dict = torch.load(checkpoint_filename, map_location='cpu')
        if config_filename is not None:
            config = self._load_config(config_filename)
            self.config.update(config)
        self.model.dependency_module.load_state_dict(state_dict['dependency'])
        self.model.generator_module.load_state_dict(state_dict['generator'])
        self.model.discriminator_module.load_state_dict(state_dict['discriminator'])
    def get_train_dataloader(self, train_data):
        dataloader = DataLoader(train_data,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'],
            pin_memory=True,
            shuffle=True,
            collate_fn=SeqPatientCollator(
                config={
                    'visit_mode':train_data.metadata['visit']['mode'],
                    'label_mode':train_data.metadata['label']['mode'],
                    }
                ),
            )
        return dataloader

    def get_test_dataloader(self, test_data):
        dataloader = DataLoader(test_data,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'],
            pin_memory=False,
            shuffle=False,
            collate_fn=SeqPatientCollator(
                config={
                    'visit_mode':test_data.metadata['visit']['mode'],
                    'label_mode':test_data.metadata['label']['mode'],
                    }
                ),
            )
        return dataloader

    def _build_simulation_loss_model(self):
        return [MultilabelBinaryXentLoss(self.model)]

    def _build_gan_loss_model(self):
        return [DiscriminatorLossGP(self.model), GeneratorLoss(self.model)]

    def _fit_model(self, train_data):
        # PHASE 1: Simulation - train the dependency model on next-visit code prediction
        train_dataloader = self.get_train_dataloader(train_data)
        simulation_loss_models = self._build_simulation_loss_model()
        simulation_train_objectives = [(train_dataloader, loss_model) for loss_model in simulation_loss_models]
        trainer = SeqSimSyntegTrainer(
            model=self,
            train_objectives=simulation_train_objectives
        )
        trainer.train(**self.config)

        # PHASE 2: GAN - train the conditional GAN on visits conditioned on the dependency model's history encodings
        gan_loss_models = self._build_gan_loss_model()
        gan_train_objectives = [(train_dataloader, loss_model) for loss_model in gan_loss_models]
        trainer = SeqSimSyntegTrainer(
            model=self,
            train_objectives=gan_train_objectives,
            condition=True
        )
        trainer.train(**self.config)

    @torch.no_grad()
    def _predict(self, n):
        return self.model.sample(n)

    def _translate_sparse_visits_to_dense(self, visits):
        def _map_func(x):
            res = np.where(x > 0)[0].tolist()
            return [0] if len(res) == 0 else res # pad if nothing happened

        outputs = defaultdict(list)
        for batchv in visits:
            voc_offset = 0
            for i, o in enumerate(self.config['orders']):
                voc_size = self.config['vocab_size'][i]
                visit = batchv[...,voc_offset:voc_offset+voc_size]
                visit = visit.cpu().numpy()
                res = list(map(_map_func, visit))
                outputs[o].append(res)
                voc_offset += voc_size

        n_total = len(outputs[o])
        sample_list = []
        for i in range(n_total):
            sample = []
            for numv in range(len(outputs[o][i])):
                visit = []
                for o in self.config['orders']:
                    visit.append(outputs[o][i][numv])
                sample.append(visit)
            sample_list.append(sample)

        # create seqpatient data
        return SequencePatient(
            data={'v':sample_list},
            metadata={
                'visit':{'mode':'dense','order':self.config['orders']},
                }
            )

    def _translate_dense_visits_to_sparse(self, visits):
        total_vocab_size = sum(self.config['vocab_size'])
        num_visits = len(visits[self.config['orders'][0]])
        outputs = np.zeros((num_visits, total_vocab_size))

        for i, o in enumerate(self.config['orders']):
            for j in range(num_visits):
                raw = visits[o][j]
                if isinstance(raw, torch.Tensor):
                    raw = raw.detach().cpu().numpy()
                if i > 0:
                    # offset codes of event type `i` by the vocab sizes of all preceding event types
                    voc_size = sum(self.config['vocab_size'][:i])
                    if isinstance(raw, list):
                        raw = [r + voc_size for r in raw]
                    else:
                        raw += voc_size
                outputs[j, raw] = 1

        return outputs

    def _pad_multiple_tensor_visits(self, visits):
        new_list = []
        for v in visits:
            new_list.extend([torch.tensor(x).squeeze(0) for x in np.array_split(v, len(v))])
        return pad_sequence(new_list, batch_first=True)

    def _prepare_input(self, data):
        '''
        Prepare inputs for sequence simulation models.

        Parameters
        ----------
        data: dict[list]
            A batch of patient records.
        '''
        visits = data['v']
        feature = data['x']

        if not isinstance(feature, torch.Tensor):
            feature = torch.tensor(feature)
        feature = feature.to(self.device)

        inputs = {
            'v':{},
            'c_lengths':{},
            'v_lengths':[],
            'x':feature, # baseline feature
            }

        v_lengths = [len(visits[self.config['orders'][0]][idx][:self.config['max_visit']]) for idx in range(len(visits[self.config['orders'][0]]))]
        inputs['v_lengths'] = v_lengths

        v = torch.ones(len(v_lengths), max(v_lengths) + 1, self.config['max_code_per_visit'] + 1, dtype=torch.int) * (self.config['total_vocab_size'] + 1)
        c_lengths = torch.ones(len(v_lengths), max(v_lengths) + 1, 1, dtype=torch.int)
        for i in range(len(v_lengths)):
            for j in range(min(v_lengths[i], self.config['max_visit'])):
                # offset each code by the total vocab size of all preceding event types
                visit = torch.IntTensor([r + sum(self.config['vocab_size'][:n]) for n, o in enumerate(self.config['orders']) for r in visits[o][i][j]][:self.config['max_code_per_visit']])
                v[i, j, 1:len(visit)+1] = visit
                c_lengths[i, j] = len(visit) + 1
        v[:,:,0] = self.config['total_vocab_size'] # start-of-visit token
        c_lengths[:,0] = 1
        v = v.to(self.device)
        c_lengths = c_lengths.to(self.device)
        inputs['v'] = v
        inputs['c_lengths'] = c_lengths

        return inputs

    def _add_condition(self, inputs):
        return self.model.add_condition(inputs)

    def _input_data_check(self, inputs):
        assert isinstance(inputs, SequencePatientBase), f'`trial_simulation.sequence` models require input training data in `SequencePatientBase`, find {type(inputs)} instead.'
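
# Illustrative usage sketch: build a small SynTEG model and draw a few synthetic visit
# sequences. The vocabulary sizes, event order, and the helper name `_demo_synteg_sampling`
# are made-up values for illustration; with an untrained model the samples are random, so in
# practice `fit` would be called first with a `SequencePatientBase` training set.
def _demo_synteg_sampling():
    model = SynTEG(
        vocab_size=[5, 5],        # e.g. 5 diagnosis codes and 5 medication codes
        order=['diag', 'med'],
        max_visit=4,
        max_code_per_visit=3,
        epochs=1,
        device='cpu',
    )
    # model.fit(train_data)  # train_data: a SequencePatientBase of real visit sequences
    records = model.predict(n=4, return_tensor=True)  # multi-hot tensor of shape (4, max_visit, total_vocab_size)
    return records.shape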