Source code for pytrial.tasks.trial_simulation.tabular.med_gan

import os
import joblib
import warnings

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import BatchNorm1d, Linear, Module, Sequential
from torch.nn.functional import cross_entropy, mse_loss
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

from pytrial.data.patient_data import TabularPatientBase
from pytrial.utils.tabular_utils import get_transformer
from pytrial.utils.check import check_checkpoint_file, check_model_dir, check_model_config_file, make_dir_if_not_exist


from .base import TabularSimulationBase

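# Fully-connected residual block: Linear -> BatchNorm -> activation, added back
# onto the input. The skip connection requires input_dim == output_dim.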
class ResidualFC(Module):
    def __init__(self, input_dim, output_dim, activate, bn_decay):
        super(ResidualFC, self).__init__()
        self.seq = Sequential(
            Linear(input_dim, output_dim),
            BatchNorm1d(output_dim, momentum=bn_decay),
            activate()
        )

    def forward(self, input):
        residual = self.seq(input)
        return input + residual


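# Maps random noise of size `random_dim` into the autoencoder's embedding space
# through a stack of residual blocks. Every entry of `hidden_dim` must equal
# `random_dim` for the residual connections to be shape-compatible, which is
# what the asserts below enforce.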
class Generator(Module):
    def __init__(self, random_dim, hidden_dim, bn_decay):
        super(Generator, self).__init__()

        dim = random_dim
        seq = []
        for item in list(hidden_dim)[:-1]:
            assert item == dim
            seq += [ResidualFC(dim, dim, nn.ReLU, bn_decay)]
        assert hidden_dim[-1] == dim
        seq += [
            Linear(dim, dim),
            BatchNorm1d(dim, momentum=bn_decay),
            nn.ReLU()
        ]
        self.seq = Sequential(*seq)

    def forward(self, input):
        return self.seq(input)


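# Discriminator with minibatch averaging (as in the MedGAN paper): each sample
# is concatenated with the batch mean before scoring, hence the data_dim * 2
# input width. This helps the discriminator detect mode collapse across a batch.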
class Discriminator(Module):
    def __init__(self, data_dim, hidden_dim):
        super(Discriminator, self).__init__()
        dim = data_dim * 2
        seq = []
        for item in list(hidden_dim):
            seq += [
                Linear(dim, item),
                nn.ReLU() if item > 1 else nn.Sigmoid()
            ]
            dim = item
        self.seq = Sequential(*seq)

    def forward(self, input):
        mean = input.mean(dim=0, keepdim=True)
        mean = mean.expand_as(input)
        inp = torch.cat((input, mean), dim=1)
        return self.seq(inp)


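# Autoencoder encoder: compresses records of size data_dim down to
# embedding_dim through the (optional) compress_dims hidden layers.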
class Encoder(Module):
    def __init__(self, data_dim, compress_dims, embedding_dim):
        super(Encoder, self).__init__()
        dim = data_dim
        seq = []
        for item in list(compress_dims) + [embedding_dim]:
            seq += [
                Linear(dim, item),
                nn.ReLU()
            ]
            dim = item
        self.seq = Sequential(*seq)

    def forward(self, input):
        return self.seq(input)


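# Autoencoder decoder: maps embeddings back to data space. `output_info` is
# accepted in forward() for interface symmetry with aeloss but is unused here.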
class Decoder(Module):
    def __init__(self, embedding_dim, decompress_dims, data_dim):
        super(Decoder, self).__init__()
        dim = embedding_dim
        seq = []
        for item in list(decompress_dims):
            seq += [
                Linear(dim, item),
                nn.ReLU()
            ]
            dim = item
        seq.append(Linear(dim, data_dim))
        self.seq = Sequential(*seq)

    def forward(self, input, output_info):
        return self.seq(input)


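# Autoencoder reconstruction loss, driven by the column metadata in output_info:
# single-column ('sigmoid') blocks use MSE on the sigmoid-activated logits,
# while one-hot ('softmax') blocks use cross-entropy against the argmax of the
# real one-hot slice. Losses are summed per column and averaged over the batch.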
def aeloss(fake, real, output_info):
    st = 0
    loss = []
    for item in output_info:
        if item[1] == 'sigmoid':
            ed = st + item[0]
            loss.append(mse_loss(torch.sigmoid(fake[:, st:ed]), real[:, st:ed], reduction='sum'))
            st = ed
        elif item[1] == 'softmax':
            ed = st + item[0]
            loss.append(cross_entropy(
                fake[:, st:ed], torch.argmax(real[:, st:ed], dim=-1), reduction='sum'))
            st = ed
        else:
            raise ValueError(f'unknown column type in output_info: {item[1]}')
    return sum(loss) / fake.size()[0]


class MedGANSynthesizer:
    """docstring for MedGAN."""

    def __init__(
        self,
        embedding_dim=128,
        random_dim=128,
        generator_dims=(128, 128),          # 128 -> 128 -> 128
        discriminator_dims=(256, 128, 1),   # datadim * 2 -> 256 -> 128 -> 1
        compress_dims=(),                   # datadim -> embedding_dim
        decompress_dims=(),                 # embedding_dim -> datadim
        bn_decay=0.99,
        l2scale=0.001,
        pretrain_epoch=200,
        batch_size=1000,
        epochs=2000,
        device='cpu',
        verbose=False,
        ):

        self.embedding_dim = embedding_dim
        self.random_dim = random_dim
        self.generator_dims = generator_dims
        self.discriminator_dims = discriminator_dims

        self.compress_dims = compress_dims
        self.decompress_dims = decompress_dims
        self.bn_decay = bn_decay
        self.l2scale = l2scale

        self.pretrain_epoch = pretrain_epoch
        self.batch_size = batch_size
        self.epochs = epochs

        self.device = device
        self.verbose = verbose

    def _get_metadata(self, data):
        self.output_info = []
        for k, v in data.metadata['transformed_col2col'].items():
            if len(v) == 1:
                self.output_info.append(
                    (1, 'sigmoid')
                )
            else:
                self.output_info.append(
                    (len(v), 'softmax')
                )

    def fit(self, data):
        data_val = data.df.values
        dataset = TensorDataset(torch.from_numpy(data_val.astype('float32')).to(self.device))
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)
        
        self._get_metadata(data)
        
        data_dim = data_val.shape[1]
        encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(self.device)
        self.decoder = Decoder(self.embedding_dim, self.decompress_dims, data_dim).to(self.device)
        optimizerAE = Adam(
            list(encoder.parameters()) + list(self.decoder.parameters()),
            weight_decay=self.l2scale
        )

        for i in range(self.pretrain_epoch):
            if self.verbose:
                print('Pretrain Epoch: {} / {}'.format(i+1, self.pretrain_epoch))
            for id_, batch in enumerate(loader):
                optimizerAE.zero_grad()
                real = batch[0].to(self.device)
                emb = encoder(real)
                rec = self.decoder(emb, self.output_info)
                loss = aeloss(rec, real, self.output_info)

                loss.backward()
                optimizerAE.step()

        self.generator = Generator(
            self.random_dim, self.generator_dims, self.bn_decay).to(self.device)
        discriminator = Discriminator(data_dim, self.discriminator_dims).to(self.device)
        optimizerG = Adam(
            list(self.generator.parameters()) + list(self.decoder.parameters()),
            weight_decay=self.l2scale
        )
        optimizerD = Adam(discriminator.parameters(), weight_decay=self.l2scale)

        mean = torch.zeros(self.batch_size, self.random_dim, device=self.device)
        std = mean + 1
        bs_larger_than_n_sample = False
        for i in range(self.epochs):
            n_d = 2
            n_g = 1
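            # perform n_g generator updates for every n_d discriminator updates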
            
            if len(loader) < n_d:
                # the batch size meets or exceeds the number of samples, so the
                # discriminator must be updated on every iteration
                bs_larger_than_n_sample = True
                warnings.warn(
                    'Batch size is larger than the number of training samples. '
                    f'The discriminator will be updated every iteration instead of every {n_d} iterations. '
                    'Consider using a smaller batch size.'
                )
            
            for id_, batch in enumerate(loader):
                real = batch[0].to(self.device)
                noise = torch.normal(mean=mean, std=std)
                emb = self.generator(noise)
                fake = self.decoder(emb, self.output_info)

                optimizerD.zero_grad()
                y_real = discriminator(real)
                y_fake = discriminator(fake)
                real_loss = -(torch.log(y_real + 1e-4).mean())
                fake_loss = -(torch.log(1.0 - y_fake + 1e-4).mean())
                loss_d = real_loss + fake_loss
                loss_d.backward()
                optimizerD.step()

                if id_ % n_d == 0 or bs_larger_than_n_sample:
                    for _ in range(n_g):
                        noise = torch.normal(mean=mean, std=std)
                        emb = self.generator(noise)
                        fake = self.decoder(emb, self.output_info)
                        optimizerG.zero_grad()
                        y_fake = discriminator(fake)
                        loss_g = -(torch.log(y_fake + 1e-4).mean())
                        loss_g.backward()
                        optimizerG.step()

            if self.verbose:
                print(f'Epoch {i+1}/{self.epochs} loss_d: {loss_d.item():.4f} loss_g: {loss_g.item():.4f}')


    def sample(self, n):
        self.generator.eval()
        self.decoder.eval()

        steps = n // self.batch_size + 1
        data = []
        for i in range(steps):
            mean = torch.zeros(self.batch_size, self.random_dim)
            std = mean + 1
            noise = torch.normal(mean=mean, std=std).to(self.device)
            emb = self.generator(noise)
            fake = self.decoder(emb, self.output_info)
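            # squash the decoder logits into (0, 1); MedGAN targets binary/count-style
            # features, so a single sigmoid is applied to every output column here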
            fake = torch.sigmoid(fake)
            data.append(fake.detach().cpu().numpy())
        data = np.concatenate(data, axis=0)
        data = data[:n]
        return data
    
# main functions
# ------------

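# Factory that instantiates a MedGANSynthesizer from a config dict; implemented
# via __new__ so that BuildModel(config) returns the synthesizer directly.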
class BuildModel:
    def __new__(cls, config) -> MedGANSynthesizer:
        model = MedGANSynthesizer(
            embedding_dim=config['embedding_dim'],
            random_dim=config['random_dim'],
            generator_dims=config['generator_dims'],
            discriminator_dims=config['discriminator_dims'],
            compress_dims=config['compress_dims'],
            decompress_dims=config['decompress_dims'],
            bn_decay=config['bn_decay'],
            l2scale=config['l2scale'],
            pretrain_epoch=config['pretrain_epoch'],
            epochs=config['epochs'],
            batch_size=config['batch_size'],
            device=config['device'],
            verbose=config['verbose'],
        )
        return model


class MedGAN(TabularSimulationBase):
    '''
    Implement MedGAN model for patient-level tabular data generation [1]_.

    Parameters
    ----------
    embedding_dim : int, default 128
        Dimension of embedding layer.

    random_dim : int, default 128
        Dimension of random noise.

    generator_dims : tuple, default (128, 128)
        Dimension of generator layers.

    discriminator_dims : tuple, default (256, 128, 1)
        Dimension of discriminator layers.

    compress_dims : tuple, default ()
        Dimension of compressed embedding layer. datadim -> embedding_dim

    decompress_dims : tuple, default ()
        Dimension of decompressed embedding layer. embedding_dim -> datadim

    bn_decay : float, default 0.99
        Decay rate of batch normalization.

    l2scale : float, default 0.001
        L2 regularization scale.

    pretrain_epoch : int, default 200
        Number of autoencoder pretraining epochs.

    batch_size : int, default 1000
        Batch size for training.

    epochs : int, default 2000
        Number of training epochs.

    device : str, default 'cpu'
        Device on which to run the model, e.g., 'cpu' or 'cuda:0'.

    experiment_id : str, default 'trial_simulation.tabular.medgan'
        Experiment id for logging.

    verbose : bool, default False
        Whether to print training information.

    Notes
    -----
    .. [1] Choi, E., Biswal, S., Malin, B., Duke, J., Stewart, W. F., & Sun, J. (2017, November).
        Generating multi-label discrete patient records using generative adversarial networks.
        In Machine Learning for Healthcare Conference (pp. 286-305). PMLR.
    '''
    def __init__(self,
        embedding_dim=128,
        random_dim=128,
        generator_dims=(128, 128),          # 128 -> 128 -> 128
        discriminator_dims=(256, 128, 1),   # datadim * 2 -> 256 -> 128 -> 1
        compress_dims=(),                   # datadim -> embedding_dim
        decompress_dims=(),                 # embedding_dim -> datadim
        bn_decay=0.99,
        l2scale=0.001,
        pretrain_epoch=200,
        batch_size=1000,
        epochs=2000,
        device='cpu',
        experiment_id='trial_simulation.tabular.medgan',
        verbose=False,
        ):
        super().__init__(experiment_id=experiment_id)
        self.config = {
            'embedding_dim': embedding_dim,
            'random_dim': random_dim,
            'generator_dims': generator_dims,
            'discriminator_dims': discriminator_dims,
            'compress_dims': compress_dims,
            'decompress_dims': decompress_dims,
            'bn_decay': bn_decay,
            'l2scale': l2scale,
            'pretrain_epoch': pretrain_epoch,
            'batch_size': batch_size,
            'epochs': epochs,
            'device': device,
            'verbose': verbose,
        }
        self._save_config(self.config)

    def fit(self, train_data):
        '''
        Train the MedGAN model to generate synthetic tabular patient data.

        Parameters
        ----------
        train_data : TabularPatientBase
            Training data.
        '''
        self._input_data_check(train_data)
        self._build_model()
        self._fit_model(train_data)

    def predict(self, n):
        '''
        Generate synthetic tabular patient data.

        Parameters
        ----------
        n : int
            Number of samples to generate.

        Returns
        -------
        data : np.ndarray
            Generated synthetic data.
        '''
        return self.model.sample(n)

    def load_model(self, checkpoint):
        '''
        Load model from checkpoint.

        Parameters
        ----------
        checkpoint : str
            Path to checkpoint. If a directory is given, will load the latest
            checkpoint in the directory. If a filepath is given, will load the
            checkpoint from the filepath. If set None, will load from the default
            directory `self.checkpoint_dir`.
        '''
        if checkpoint is None:
            checkpoint = self.checkpoint_dir
        checkpoint_filename = check_checkpoint_file(checkpoint, suffix='model')
        config_filename = check_model_config_file(checkpoint)
        self.model = joblib.load(checkpoint_filename)
        self.config = self._load_config(config_filename)

    def save_model(self, output_dir):
        '''
        Save model to checkpoint.

        Parameters
        ----------
        output_dir : str
            Output directory. If set None, will save to the default directory
            `self.checkpoint_dir`.
        '''
        if output_dir is None:
            output_dir = self.checkpoint_dir
        make_dir_if_not_exist(output_dir)
        self._save_config(self.config, output_dir=output_dir)
        ckpt_path = os.path.join(output_dir, 'medgan.model')
        joblib.dump(self.model, ckpt_path)

    def _build_model(self):
        self.model = BuildModel(self.config)

    def _fit_model(self, dataset):
        self.model.fit(dataset)
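

# ------------
# A minimal usage sketch (illustrative, not part of the library API): it assumes
# TabularPatientBase accepts a raw pandas DataFrame, as elsewhere in pytrial;
# the column names and the tiny hyperparameters below are hypothetical and
# chosen only so the demo runs quickly.
if __name__ == '__main__':
    demo_df = pd.DataFrame({
        'gender': ['M', 'F', 'F', 'M'] * 25,
        'smoker': [0, 1, 0, 1] * 25,
    })
    demo_data = TabularPatientBase(demo_df)
    model = MedGAN(epochs=5, pretrain_epoch=5, batch_size=20, verbose=True)
    model.fit(demo_data)
    samples = model.predict(n=50)  # np.ndarray in the transformed feature space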