"""Source code for ``pytrial.tasks.trial_outcome.hint``: HINT model for clinical trial outcome prediction."""

import pickle
from copy import deepcopy 
import os
import json
import joblib

from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, precision_score, recall_score, accuracy_score
import matplotlib.pyplot as plt
import numpy as np 
from tqdm import tqdm 
import torch 
from torch import nn 
import torch.nn.functional as F
from functools import reduce

from .model_utils.module import Highway, GCN 
from .model_utils.molecule_encode import MPNN, ADMET 
from .model_utils.icdcode_encode import GRAM, build_icdcode2ancestor_dict
from .model_utils.protocol_encode import Protocol_Embedding
from .model_utils.utils import smiles_lst2fp
from .model_utils.utils import plot_hist
from .model_utils.utils import replace_strange_symbol
from .model_utils.utils import trial_collate_fn
from .base import TrialOutcomeBase

class Interaction(nn.Sequential, TrialOutcomeBase):
    '''
    Fuse the molecule, disease (ICD code) and protocol (eligibility criteria)
    encoders into a single trial-level "interaction" embedding and predict a
    binary trial-outcome logit from it.

    Parameters
    ----------
    disease_embedding_dim: int
        Dimension of the GRAM disease-code embedding.
    protocol_output_dim: int
        Output dimension of the protocol (criteria) encoder.
    molecule_embedding_dim: int
        Hidden size of the molecule MPNN encoder.
    device: str or torch.device
        Device on which all submodules are placed.
    global_embed_size: int
        Dimension of the fused interaction embedding.
    highway_num_layer: int
        Number of layers in each Highway block.
    prefix_name: str
        Prefix used to build `save_name` for figures/results files.
    epoch: int
        Number of training epochs (stored; consumed by subclasses' `fit`).
    lr: float
        Learning rate (stored; consumed by subclasses' `fit`).
    weight_decay: float
        Weight decay coefficient (stored; consumed by subclasses' `fit`).
    '''
    def __init__(self, 
        disease_embedding_dim, 
        protocol_output_dim, 
        molecule_embedding_dim, 
        device, 
        global_embed_size,
        highway_num_layer,
        prefix_name, 
        epoch = 20,
        lr = 3e-4, 
        weight_decay = 0, 
        ):
        super(Interaction, self).__init__()
        icdcode2ancestor_dict = build_icdcode2ancestor_dict()
        self.disease_encoder = GRAM(embedding_dim = disease_embedding_dim, icdcode2ancestor = icdcode2ancestor_dict, device = device)
        self.protocol_encoder = Protocol_Embedding(output_dim = protocol_output_dim, highway_num=3, device = device)
        self.molecule_encoder = MPNN(mpnn_hidden_size = molecule_embedding_dim, mpnn_depth=3, device = device)
        self.global_embed_size = global_embed_size 
        self.highway_num_layer = highway_num_layer 
        # concatenated width of the three encoder outputs
        self.feature_dim = self.molecule_encoder.embedding_size + self.disease_encoder.embedding_size + self.protocol_encoder.embedding_size
        self.epoch = epoch 
        self.lr = lr 
        self.weight_decay = weight_decay 
        self.save_name = prefix_name + '_interaction'

        self.f = F.relu
        # BCEWithLogitsLoss: `forward` returns raw logits, not probabilities
        self.loss = nn.BCEWithLogitsLoss()

        ##### NN 
        self.encoder2interaction_fc = nn.Linear(self.feature_dim, self.global_embed_size).to(device)
        self.encoder2interaction_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
        self.pred_nn = nn.Linear(self.global_embed_size, 1)

        self.device = device 
        self = self.to(device)

    def feed_lst_of_module(self, input_feature, lst_of_module):
        '''Apply each module in `lst_of_module` in order, with the activation
        `self.f` after every module, and return the result.'''
        x = input_feature
        for single_module in lst_of_module:
            x = self.f(single_module(x))
        return x

    def forward_get_three_encoders(self, smiles_lst2, icdcode_lst3, criteria_lst):
        '''Run the three base encoders and return their batch embeddings
        (molecule, disease, protocol) as a tuple.'''
        molecule_embed = self.molecule_encoder.forward_smiles_lst_lst(smiles_lst2)
        icd_embed = self.disease_encoder.forward_code_lst3(icdcode_lst3)
        protocol_embed = self.protocol_encoder.forward(criteria_lst)
        return molecule_embed, icd_embed, protocol_embed

    def forward_encoder_2_interaction(self, molecule_embed, icd_embed, protocol_embed):
        '''Concatenate the three encoder outputs and project them through the
        FC + Highway stack into the interaction embedding.'''
        encoder_embedding = torch.cat([molecule_embed, icd_embed, protocol_embed], 1)
        h = self.encoder2interaction_fc(encoder_embedding)
        h = self.f(h)
        h = self.encoder2interaction_highway(h)
        interaction_embedding = self.f(h)
        return interaction_embedding 

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
        '''Return outcome logits of shape (batch_size, 1).'''
        molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
        interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
        output = self.pred_nn(interaction_embedding)
        return output ### (batch_size, 1)

    def evaluation(self, predict_all, label_all, threshold = 0.5):
        '''
        Compute evaluation metrics from continuous predictions and binary labels.

        Also dumps the raw (prediction, label) pairs to `predict_label.txt`
        in the working directory, as in the original implementation.

        Returns
        -------
        (auc, f1, prauc, precision, recall, accuracy,
         predict_1_ratio, label_1_ratio)
        '''
        with open("predict_label.txt", 'w') as fout:
            for i,j in zip(predict_all, label_all):
                fout.write(str(i)[:4] + '\t' + str(j)[:4]+'\n')
        # ranking metrics use the continuous scores
        auc_score = roc_auc_score(label_all, predict_all)
        label_all = [int(i) for i in label_all]
        # BUGFIX: PR-AUC must be computed on the continuous scores; the
        # original binarized `predict_all` first, which degenerates
        # average_precision_score to a single operating point.
        prauc_score = average_precision_score(label_all, predict_all)
        # threshold-based metrics use binarized predictions
        float2binary = lambda x: 0 if x < threshold else 1
        predict_all = list(map(float2binary, predict_all))
        f1score = f1_score(label_all, predict_all)
        precision = precision_score(label_all, predict_all)
        recall = recall_score(label_all, predict_all)
        accuracy = accuracy_score(label_all, predict_all)
        predict_1_ratio = sum(predict_all) / len(predict_all)
        label_1_ratio = sum(label_all) / len(label_all)
        return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio 

    def testloader_to_lst(self, dataloader):
        '''Flatten a dataloader into parallel lists
        (nctid, label, smiles, icdcode, criteria) plus their common length.'''
        nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst = [], [], [], [], []
        for nctid, label, smiles, icdcode, criteria in dataloader:
            nctid_lst.extend(nctid)
            label_lst.extend([i.item() for i in label])
            smiles_lst2.extend(smiles)
            icdcode_lst3.extend(icdcode)
            criteria_lst.extend(criteria)
        length = len(nctid_lst)
        assert length == len(smiles_lst2) and length == len(icdcode_lst3)
        return nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst, length 

    def generate_predict(self, dataloader):
        '''
        Run inference over `dataloader`.

        Returns
        -------
        (total_loss, sigmoid probabilities, labels, nctids) with the last
        three as flat Python lists aligned by index.
        '''
        whole_loss = 0 
        label_all, predict_all, nctid_all = [], [], []
        # inference only: disable autograd to save memory/time
        with torch.no_grad():
            for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
                nctid_all.extend(nctid_lst)
                label_vec = label_vec.to(self.device)
                output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst).view(-1)  
                loss = self.loss(output, label_vec.float())
                whole_loss += loss.item()
                predict_all.extend([i.item() for i in torch.sigmoid(output)])
                label_all.extend([i.item() for i in label_vec])
        return whole_loss, predict_all, label_all, nctid_all

    def bootstrap_test(self, dataloader, sample_num = 20):
        '''
        Evaluate with bootstrap resampling: draw `sample_num` resamples of the
        predictions, report mean/std of PR-AUC, F1 and ROC-AUC, print the
        misclassified trials, and dump nctid->prediction to
        `results/nctid2predict.pkl`.
        '''
        self.eval()
        best_threshold = 0.5 
        whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader)
        plt.clf()
        # make sure output folders exist before writing into them
        os.makedirs("figure", exist_ok=True)
        os.makedirs("results", exist_ok=True)
        prefix_name = "./figure/" + self.save_name 
        plot_hist(prefix_name, predict_all, label_all)
        def bootstrap(length, sample_num):
            # sample `length` indices with replacement, `sample_num` times
            idx = [i for i in range(length)]
            from random import choices 
            bootstrap_idx = [choices(idx, k = length) for i in range(sample_num)]
            return bootstrap_idx 
        results_lst = []
        bootstrap_idx_lst = bootstrap(len(predict_all), sample_num = sample_num)
        for bootstrap_idx in bootstrap_idx_lst: 
            bootstrap_label = [label_all[idx] for idx in bootstrap_idx]
            bootstrap_predict = [predict_all[idx] for idx in bootstrap_idx]
            results = self.evaluation(bootstrap_predict, bootstrap_label, threshold = best_threshold)
            results_lst.append(results)
        self.train() 
        auc = [results[0] for results in results_lst]
        f1score = [results[1] for results in results_lst]
        prauc_score = [results[2] for results in results_lst]
        print("PR-AUC   mean: "+str(np.mean(prauc_score))[:6], "std: "+str(np.std(prauc_score))[:6])
        print("F1       mean: "+str(np.mean(f1score))[:6], "std: "+str(np.std(f1score))[:6])
        print("ROC-AUC  mean: "+ str(np.mean(auc))[:6], "std: " + str(np.std(auc))[:6])

        # report misclassified trials at the 0.5 operating point
        for nctid, label, predict in zip(nctid_all, label_all, predict_all):
            if (predict > 0.5 and label == 0) or (predict < 0.5 and label == 1):
                print(nctid, label, str(predict)[:5])

        nctid2predict = {nctid:predict for nctid, predict in zip(nctid_all, predict_all)} 
        pickle.dump(nctid2predict, open('results/nctid2predict.pkl', 'wb'))
        return nctid_all, predict_all 

    def ongoing_test(self, dataloader, sample_num = 20):
        '''Predict for ongoing (unlabeled) trials; returns (nctids, probabilities).'''
        self.eval()
        best_threshold = 0.5 
        whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader) 
        self.train() 
        return nctid_all, predict_all

    def test(self, dataloader, return_loss = True, validloader=None):
        '''
        Evaluate on `dataloader`. If `return_loss`, return only the summed
        loss; otherwise print and return the full metric tuple from
        `evaluation` at threshold 0.5.
        '''
        self.eval()
        best_threshold = 0.5 
        whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader)
        self.train()
        if return_loss:
            return whole_loss
        else:
            print_num = 5
            auc_score, f1score, prauc_score, precision, recall, accuracy, \
            predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold)
            print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num] \
                 + "\nPR-AUC: " + str(prauc_score)[:print_num] \
                 + "\nPrecision: " + str(precision)[:print_num] \
                 + "\nrecall: "+str(recall)[:print_num] + "\naccuracy: "+str(accuracy)[:print_num] \
                 + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num] \
                 + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num])
            return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio

    def plot_learning_curve(self, train_loss_record, valid_loss_record):
        '''Save train/valid loss curves under ./figure/ (directory assumed to exist).'''
        plt.plot(train_loss_record)
        plt.savefig("./figure/" + self.save_name + '_train_loss.jpg')
        plt.clf() 
        plt.plot(valid_loss_record)
        plt.savefig("./figure/" + self.save_name + '_valid_loss.jpg')
        plt.clf() 

    def select_threshold_for_binary(self, validloader):
        '''
        Choose the decision threshold that maximizes F1 on the validation
        predictions, scanning every predicted probability as a candidate.
        '''
        _, prediction, label_all, nctid_all = self.generate_predict(validloader)
        # BUGFIX: initialize best_threshold so it is always bound even when no
        # candidate improves on best_f1 (original could raise NameError).
        best_f1, best_threshold = 0, 0.5
        for threshold in prediction:
            float2binary = lambda x: 0 if x < threshold else 1
            predict_all = list(map(float2binary, prediction))
            # BUGFIX: original called precision_score although the variable
            # name and selection criterion are F1.
            f1score = f1_score(label_all, predict_all)
            if f1score > best_f1:
                best_f1 = f1score 
                best_threshold = threshold
        return best_threshold 



class HINT_nograph(Interaction):
    '''
    HINT without the graph neural network stage.

    On top of the three base encoders inherited from `Interaction`, this
    builds the remaining trial-graph node embeddings (disease risk, augmented
    interaction, five ADMET heads, PK, trial node) and predicts the outcome
    logit from the final trial embedding. The subclass `HINT` reuses
    `forward(..., if_gnn=True)` to obtain all node embeddings for GNN message
    passing, so the construction order here is load-bearing.
    '''
    def __init__(self, 
                disease_embedding_dim, 
                protocol_output_dim, 
                molecule_embedding_dim, 		
                device, 
                global_embed_size, 
                highway_num_layer,
                prefix_name, 
                epoch = 20,
                lr = 3e-4, 
                weight_decay = 0, ):
        # build encoders, interaction stack and training hyperparameters
        super(HINT_nograph, self).__init__(					
            disease_embedding_dim = disease_embedding_dim, 
            protocol_output_dim = protocol_output_dim, 
            molecule_embedding_dim = molecule_embedding_dim, 
            device = device,  
            global_embed_size = global_embed_size, 
            prefix_name = prefix_name, 
            highway_num_layer = highway_num_layer,
            epoch = epoch,
            lr = lr, 
            weight_decay = weight_decay, 
            ) 
        self.save_name = prefix_name + '_HINT_nograph'
        '''	### interaction model 
        self.molecule_encoder = molecule_encoder 
        self.disease_encoder = disease_encoder 
        self.protocol_encoder = protocol_encoder 
        self.global_embed_size = global_embed_size 
        self.highway_num_layer = highway_num_layer 
        self.feature_dim = self.molecule_encoder.embedding_size + self.disease_encoder.embedding_size + self.protocol_encoder.embedding_size
        self.epoch = epoch 
        self.lr = lr 
        self.weight_decay = weight_decay 
        self.save_name = save_name

        self.f = F.relu
        self.loss = nn.BCEWithLogitsLoss()

        ##### NN 
        self.encoder2interaction_fc = nn.Linear(self.feature_dim, self.global_embed_size)
        self.encoder2interaction_highway = Highway(self.global_embed_size, self.highway_num_layer)
        self.pred_nn = nn.Linear(self.global_embed_size, 1)
        '''


        #### risk of disease 
        # NOTE(review): 'higway' is a long-standing typo in the attribute name;
        # renaming would break loading of existing checkpoints, so it stays.
        self.risk_disease_fc = nn.Linear(self.disease_encoder.embedding_size, self.global_embed_size)
        self.risk_disease_higway = Highway(self.global_embed_size, self.highway_num_layer)

        #### augment interaction: combines interaction + disease-risk embeddings
        self.augment_interaction_fc = nn.Linear(self.global_embed_size*2, self.global_embed_size)
        self.augment_interaction_highway = Highway(self.global_embed_size, self.highway_num_layer)

        #### ADMET: five parallel heads (Absorption, Distribution, Metabolism,
        #### Excretion, Toxicity), each a Linear + Highway pair over the
        #### molecule embedding
        self.admet_model = []
        for i in range(5):
            admet_fc = nn.Linear(self.molecule_encoder.embedding_size, self.global_embed_size).to(device)
            admet_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
            self.admet_model.append(nn.ModuleList([admet_fc, admet_highway])) 
        self.admet_model = nn.ModuleList(self.admet_model)

        #### PK: fuses the five ADMET embeddings
        self.pk_fc = nn.Linear(self.global_embed_size*5, self.global_embed_size)
        self.pk_highway = Highway(self.global_embed_size, self.highway_num_layer)

        #### trial node: fuses PK + augmented-interaction embeddings
        self.trial_fc = nn.Linear(self.global_embed_size*2, self.global_embed_size)
        self.trial_highway = Highway(self.global_embed_size, self.highway_num_layer)

        ## self.pred_nn = nn.Linear(self.global_embed_size, 1)

        self.device = device 
        self = self.to(device)


    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst, if_gnn = False):
        '''
        Compute outcome logits of shape (batch_size, 1), or — when
        `if_gnn=True` — return the ordered list of the 13 node embeddings
        (molecule, disease, criteria, interaction, risk-of-disease, augmented
        interaction, 5x ADMET, PK, trial). This order must match
        `HINT.generate_adj`'s node list.
        '''
        ### encoder for molecule, disease and protocol
        molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
        ### interaction 
        interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
        ### risk of disease 
        risk_of_disease_embedding = self.feed_lst_of_module(input_feature = icd_embed, 
                                                            lst_of_module = [self.risk_disease_fc, self.risk_disease_higway])
        ### augment interaction   
        augment_interaction_input = torch.cat([interaction_embedding, risk_of_disease_embedding], 1)
        augment_interaction_embedding = self.feed_lst_of_module(input_feature = augment_interaction_input, 
                                                                lst_of_module = [self.augment_interaction_fc, self.augment_interaction_highway])
        ### admet: all five heads read the same molecule embedding
        admet_embedding_lst = []
        for idx in range(5):
            admet_embedding = self.feed_lst_of_module(input_feature = molecule_embed, 
                                                      lst_of_module = self.admet_model[idx])
            admet_embedding_lst.append(admet_embedding)
        ### pk 
        pk_input = torch.cat(admet_embedding_lst, 1)
        pk_embedding = self.feed_lst_of_module(input_feature = pk_input, 
                                               lst_of_module = [self.pk_fc, self.pk_highway])
        ### trial 
        trial_input = torch.cat([pk_embedding, augment_interaction_embedding], 1)
        trial_embedding = self.feed_lst_of_module(input_feature = trial_input, 
                                                  lst_of_module = [self.trial_fc, self.trial_highway])
        output = self.pred_nn(trial_embedding)
        if if_gnn == False:
            return output 
        else:
            embedding_lst = [molecule_embed, icd_embed, protocol_embed, interaction_embedding, risk_of_disease_embedding, \
                             augment_interaction_embedding] + admet_embedding_lst + [pk_embedding, trial_embedding]
            return embedding_lst

class HINT(HINT_nograph):
    '''
    Implement Hierarchical Interaction Network (HINT) model for clinical
    trial outcome prediction [1]_.

    Parameters
    ----------
    disease_embedding_dim: int
        dimension of disease code embedding, e.g., 50
    protocol_output_dim: int
        dimension of protocol (eligibility criteria) embedding, e.g., 50
    molecule_embedding_dim: int
        dimension of molecule embedding, e.g., 50
    global_embed_size: int
        dimension of trial component embedding, e.g., 50
    highway_num_layer: int
        number of highway layers, e.g., 3
    gnn_hidden_size: int
        dimension of GNN hidden size, e.g., 50
    epoch: int
        epoch number during training, e.g., 5
    lr: float
        learning rate of optimizer (we use Adam) during training, e.g., 3e-4
    batch_size: int
        batch size during training, e.g., 32
    weight_decay: float
        weight decay coefficient, e.g., 0.
    prefix_name: str
        name of trial phase as prefix name of the model, e.g., `phase_I`, `phase_II`
    device: str or torch.device
        Target device to train the model, as `cuda:0` or `cpu`.

    Notes
    -----
    .. [1] Fu et al. HINT: Hierarchical Interaction Network for Clinical
       Trial Outcome Prediction. Cell Patterns, 2022.
    '''
    def __init__(self,
        disease_embedding_dim=50,
        protocol_output_dim=50,
        molecule_embedding_dim=50,
        global_embed_size=50,
        highway_num_layer=3,
        gnn_hidden_size=50,
        epoch = 20,
        lr = 3e-4,
        batch_size = 32,
        weight_decay = 0,
        prefix_name='phase_I',
        device='cuda:0',
        ):
        super(HINT, self).__init__(
            disease_embedding_dim = disease_embedding_dim,
            protocol_output_dim = protocol_output_dim,
            molecule_embedding_dim = molecule_embedding_dim,
            device = device,
            prefix_name = prefix_name,
            global_embed_size = global_embed_size,
            highway_num_layer = highway_num_layer,
            epoch = epoch,
            lr = lr,
            weight_decay = weight_decay)
        self.save_name = prefix_name
        self.gnn_hidden_size = gnn_hidden_size

        #### GNN over the 13-node trial component graph
        self.adj = self.generate_adj()
        self.gnn = GCN(
            nfeat = self.global_embed_size,
            nhid = self.gnn_hidden_size,
            nclass = 1,
            dropout = 0.6,
            init = 'uniform')

        ### gnn's attention: one attention module per existing edge (adj==1),
        ### None placeholders elsewhere to keep the matrix shape
        self.node_size = self.adj.shape[0]
        self.graph_attention_model_mat = nn.ModuleList(
            [nn.ModuleList([self.gnn_attention() if self.adj[i,j]==1 else None
                            for j in range(self.node_size)])
             for i in range(self.node_size)])

        self.device = device
        self = self.to(device)

        # hyperparameters persisted alongside the model (see save_model)
        self.config = {
            'disease_embedding_dim': disease_embedding_dim,
            'protocol_output_dim': protocol_output_dim,
            'molecule_embedding_dim': molecule_embedding_dim,
            'global_embed_size': global_embed_size,
            'highway_num_layer': highway_num_layer,
            'gnn_hidden_size': gnn_hidden_size,
            'epoch': epoch,
            'lr': lr,
            'batch_size': batch_size,
            'weight_decay': weight_decay,
            'prefix_name': prefix_name,
            }
[docs] def predict(self, test_data): ''' Make trial outcome prediction for test data. Parameters ---------- test_data: TrialOutcomeDatasetBase Testing data, should be a `TrialOutcomeDatasetBase` object. ''' # build dataloader using test_data testloader = self._build_dataloader_from_dataset(test_data, num_workers=0, batch_size=self.config['batch_size'], shuffle=False, collate_fn=trial_collate_fn) self.eval() # best_threshold = 0.5 whole_loss, predict_all, label_all, nctid_all = self.generate_predict(testloader) predict_result = list(zip(nctid_all, predict_all)) self.train() return predict_result
[docs] def fit(self, train_data, valid_data=None): ''' Train HINT model to predict clinical trial outcome (approval rate) Parameters ---------- train_data: TrialOutcomeDatasetBase Training data, should be a `TrialOutcomeDatasetBase` object. valid_data: TrialOutcomeDatasetBase Validation data, should be a `TrialOutcomeDatasetBase` object. ''' # build dataloader using train_data train_loader = self._build_dataloader_from_dataset(train_data, num_workers=0, batch_size=self.config['batch_size'], shuffle=True, collate_fn=trial_collate_fn) if valid_data is not None: valid_loader = self._build_dataloader_from_dataset(valid_data, num_workers=0, batch_size=self.config['batch_size'], shuffle=False, collate_fn=trial_collate_fn) opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay) train_loss_record = [] valid_loss = self.test(valid_loader, return_loss=True) valid_loss_record = [valid_loss] best_valid_loss = valid_loss best_model = deepcopy(self) for ep in tqdm(range(self.epoch)): for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in tqdm(train_loader): label_vec = label_vec.to(self.device) output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst).view(-1) #### 32, 1 -> 32, || label_vec 32, loss = self.loss(output, label_vec.float()) train_loss_record.append(loss.item()) opt.zero_grad() loss.backward() opt.step() print('epoch: {}, loss: {}'.format(ep, loss.item())) if valid_data is not None: # only check valid loss when valid_data is not None valid_loss = self.test(valid_loader, return_loss=True) valid_loss_record.append(valid_loss) if valid_loss < best_valid_loss: print('best valid loss: {} -> {}'.format(best_valid_loss, valid_loss)) best_valid_loss = valid_loss best_model = deepcopy(self) # self.plot_learning_curve(train_loss_record, valid_loss_record) self = deepcopy(best_model)
# auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader)
[docs] def save_model(self, output_dir = None): ''' Save the learned HINT model to the disk. Parameters ---------- output_dir: str or None The output folder to save the learned model. If set None, will save model to `checkpoints/model.ckpt`. ''' if output_dir is None: output_dir = 'checkpoints' if not os.path.exists(output_dir): os.makedirs(output_dir) filename = os.path.join(output_dir, 'model.pkl') # save self using joblib joblib.dump(self, filename) config_filename = os.path.join(output_dir, 'config.json') with open(config_filename, 'w') as f: json.dump(self.config, f)
[docs] def load_model(self, checkpoint=None): ''' Load the learned HINT model from the disk. Parameters ---------- checkpoint: str The checkpoint folder to load the learned model. The checkpoint under this folder should be `model.ckpt`. ''' if checkpoint is None: ckpt_dir = 'checkpoints' checkpoint = os.path.join(ckpt_dir, 'model.pkl') else: checkpoint = os.path.join(checkpoint, 'model.pkl') # load model using joblib model = joblib.load(checkpoint) ckpt_dir = os.path.dirname(checkpoint) config_filename = os.path.join(ckpt_dir, 'config.json') with open(config_filename, 'r') as f: model.config = json.load(f) # replace self with the loaded object self.__dict__.update(model.__dict__)
@staticmethod
def from_pretrained(checkpoint=None):
    '''
    Load the learned HINT model from the disk.

    Parameters
    ----------
    checkpoint: str
        The checkpoint folder to load the learned model.
        The checkpoint under this folder should be `model.ckpt`.

    Returns
    -------
    The deserialized model object with its `config` attribute refreshed
    from the sibling `config.json`.
    '''
    if checkpoint is None:
        ckpt_dir = 'checkpoints'
        checkpoint = os.path.join(ckpt_dir, 'model.pkl')
    else:
        checkpoint = os.path.join(checkpoint, 'model.pkl')
    # load model using joblib
    self = joblib.load(checkpoint)
    ckpt_dir = os.path.dirname(checkpoint)
    config_filename = os.path.join(ckpt_dir, 'config.json')
    with open(config_filename, 'r') as f:
        self.config = json.load(f)
    return self

def generate_adj(self):
    '''Build the fixed, symmetric adjacency matrix of the HINT interaction graph.

    Node order must stay consistent with HINT_nograph.forward. Self-loops carry
    weight len(lst); every listed edge carries weight 1 in both directions.
    '''
    ##### consistent with HINT_nograph.forward
    lst = ["molecule", "disease", "criteria", 'INTERACTION', 'risk_disease',
           'augment_interaction', 'A', 'D', 'M', 'E', 'T', 'PK', "final"]
    edge_lst = [("disease", "molecule"), ("disease", "criteria"), ("molecule", "criteria"),
                ("disease", "INTERACTION"), ("molecule", "INTERACTION"), ("criteria", "INTERACTION"),
                ("disease", "risk_disease"), ('risk_disease', 'augment_interaction'),
                ('INTERACTION', 'augment_interaction'),
                ("molecule", "A"), ("molecule", "D"), ("molecule", "M"),
                ("molecule", "E"), ("molecule", "T"),
                ('A', 'PK'), ('D', 'PK'), ('M', 'PK'), ('E', 'PK'), ('T', 'PK'),
                ('augment_interaction', 'final'), ('PK', 'final')]
    # NOTE(review): the next assignment is immediately overwritten by the
    # torch.eye line below — it appears to be dead code.
    adj = torch.zeros(len(lst), len(lst))
    adj = torch.eye(len(lst)) * len(lst)
    num2str = {k: v for k, v in enumerate(lst)}
    str2num = {v: k for k, v in enumerate(lst)}
    for i, j in edge_lst:
        n1, n2 = str2num[i], str2num[j]
        # undirected edge: set both (n1, n2) and (n2, n1)
        adj[n1, n2] = 1
        adj[n2, n1] = 1
    return adj.to(self.device)

def generate_attention_matrx(self, node_feature_mat):
    '''Compute pairwise edge attention scores for one trial's node features.

    Parameters
    ----------
    node_feature_mat: tensor
        Stacked node embeddings, one row per graph node (rows indexed to match
        the adjacency built by generate_adj).

    Returns
    -------
    attention_mat: (node_size, node_size) tensor with a sigmoid score for every
        edge where adj == 1, and 0 elsewhere (self-loops, weighted len(lst) in
        adj, are skipped by the != 1 test).
    '''
    attention_mat = torch.zeros(self.node_size, self.node_size).to(self.device)
    for i in range(self.node_size):
        for j in range(self.node_size):
            # only score actual edges; adj self-loops are not == 1 so they are skipped
            if self.adj[i, j] != 1:
                continue
            # edge feature = concatenation of the two endpoint embeddings
            feature = torch.cat([node_feature_mat[i].view(1, -1),
                                 node_feature_mat[j].view(1, -1)], 1)
            # per-edge attention network, indexed by the node pair
            attention_model = self.graph_attention_model_mat[i][j]
            attention_mat[i, j] = torch.sigmoid(
                self.feed_lst_of_module(input_feature=feature, lst_of_module=attention_model))
    return attention_mat

##### self.global_embed_size*2 -> 1
def gnn_attention(self):
    '''Build one edge-attention head: Highway(2*embed) followed by a 2*embed -> 1 linear.'''
    highway_nn = Highway(size=self.global_embed_size * 2, num_layers=self.highway_num_layer).to(self.device)
    highway_fc = nn.Linear(self.global_embed_size * 2, 1).to(self.device)
    return nn.ModuleList([highway_nn, highway_fc])

def forward(self, smiles_lst2, icdcode_lst3, criteria_lst, return_attention_matrix=False):
    '''Run the GNN variant of HINT on one batch.

    Delegates the per-node embedding computation to HINT_nograph.forward
    (with if_gnn=True), then, per trial in the batch, builds the node feature
    matrix, computes edge attention, and runs the GNN over the
    attention-weighted adjacency. Returns the stacked per-trial outputs
    (plus the per-trial attention matrices when requested).
    '''
    embedding_lst = HINT_nograph.forward(self, smiles_lst2, icdcode_lst3, criteria_lst, if_gnn=True)
    ### length is 13, each is 32,50
    batch_size = embedding_lst[0].shape[0]
    output_lst = []
    if return_attention_matrix:
        attention_mat_lst = []
    for i in range(batch_size):
        # gather the i-th trial's embedding from every node, one row each
        node_feature_lst = [embedding[i].view(1, -1) for embedding in embedding_lst]
        node_feature_mat = torch.cat(node_feature_lst, 0)  ### 13, 50
        attention_mat = self.generate_attention_matrx(node_feature_mat)
        # GNN message passing over the attention-weighted adjacency
        output = self.gnn(node_feature_mat, self.adj * attention_mat)
        # keep only the last ("final") node's representation as the trial output
        output = output[-1].view(1, -1)
        output_lst.append(output)
        if return_attention_matrix:
            attention_mat_lst.append(attention_mat)
    output_mat = torch.cat(output_lst, 0)
    if not return_attention_matrix:
        return output_mat
    else:
        return output_mat, attention_mat_lst

def init_pretrain(self, admet_model):
    '''Adopt the molecule encoder from a pretrained ADMET model (weight sharing).'''
    self.molecule_encoder = admet_model.molecule_encoder
def dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, global_icd):
    '''Convert one dataloader batch into a dense feature matrix and label vector.

    Features concatenate (a) a fingerprint vector per trial (via smiles_lst2fp
    over the trial's SMILES list) and (b) a multi-hot vector over the top-level
    ICD codes listed in `global_icd`.

    Parameters
    ----------
    label_vec: tensor, shape (n,) — returned unchanged as y.
    smiles_lst2: list of n SMILES lists, one per trial.
    icdcode_lst3: list of n nested ICD-code lists, one per trial.
    global_icd: ordered list of known top-level ICD codes.

    Returns
    -------
    X: float tensor, shape (n, fp_dim + len(global_icd)).
    y: the label vector, unchanged.
    '''
    y = label_vec
    num_icd = len(global_icd)
    # one fingerprint row per trial
    fp_lst = [smiles_lst2fp(smiles_lst).reshape(1, -1) for smiles_lst in smiles_lst2]
    fp_mat = np.concatenate(fp_lst, 0)
    icdcode_lst = []
    for lst2 in icdcode_lst3:
        # flatten the nested code lists and keep only the top-level code (before '.')
        codes = list(reduce(lambda x, y: x + y, lst2))
        codes = {c.split('.')[0] for c in codes}
        icd_feature = np.zeros((1, num_icd), np.int32)
        for ele in codes:
            if ele in global_icd:
                icd_feature[0, global_icd.index(ele)] = 1
        icdcode_lst.append(icd_feature)
    icdcode_mat = np.concatenate(icdcode_lst, 0)
    X = torch.from_numpy(np.concatenate([fp_mat, icdcode_mat], 1)).float()
    return X, y


class FFNN(nn.Sequential):
    '''Feed-forward baseline for trial-outcome prediction on fingerprint + ICD features.'''

    def __init__(self,
                 molecule_dim,
                 diseasecode_dim,
                 global_icd,
                 protocol_dim=0,
                 prefix_name='FFNN',
                 epoch=10,
                 lr=3e-4,
                 weight_decay=0,
                 ):
        super(FFNN, self).__init__()
        self.molecule_dim = molecule_dim
        self.diseasecode_dim = diseasecode_dim
        self.protocol_dim = protocol_dim
        self.prefix_name = prefix_name
        self.epoch = epoch
        self.lr = lr
        self.weight_decay = weight_decay
        self.global_icd = global_icd
        self.num_icd = len(global_icd)
        self.fc_dims = [self.molecule_dim + self.diseasecode_dim + self.protocol_dim,
                        2000, 1000, 200, 50, 1]
        self.fc_layers = nn.ModuleList(
            [nn.Linear(v, self.fc_dims[i + 1]) for i, v in enumerate(self.fc_dims[:-1])])
        # BUGFIX: forward() outputs probabilities (sigmoid), so the criterion must
        # be BCELoss. The original BCEWithLogitsLoss applied a second sigmoid to
        # the already-sigmoided output, and generate_predict then applied a third,
        # pushing every score above 0.5 so thresholded predictions were all 1.
        self.loss = nn.BCELoss()
        self.save_name = prefix_name

    def forward(self, X):
        '''Return per-trial approval probabilities, shape (n, 1).'''
        # NOTE(review): hidden layers have no activation between them (a stack of
        # linear maps) — preserved as-is to keep behavior; confirm if intentional.
        for fc_layer in self.fc_layers[:-1]:
            X = fc_layer(X)
        # torch.sigmoid replaces the deprecated F.sigmoid
        pred = torch.sigmoid(self.fc_layers[-1](X))
        return pred

    def fit(self, train_loader, valid_loader, test_loader):
        '''Train with Adam, track validation loss per epoch, keep the best weights.'''
        opt = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        train_loss_record = []
        valid_loss = self.test(valid_loader, return_loss=True)
        valid_loss_record = [valid_loss]
        best_valid_loss = valid_loss
        best_model = deepcopy(self)
        for ep in tqdm(range(self.epoch)):
            for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader:
                X, _ = dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3,
                                     criteria_lst, self.global_icd)
                # (n, 1) -> (n,) to match label_vec
                output = self.forward(X).view(-1)
                loss = self.loss(output, label_vec.float())
                train_loss_record.append(loss.item())
                opt.zero_grad()
                loss.backward()
                opt.step()
            valid_loss = self.test(valid_loader, return_loss=True)
            valid_loss_record.append(valid_loss)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_model = deepcopy(self)
        self.plot_learning_curve(train_loss_record, valid_loss_record)
        # BUGFIX: `self = deepcopy(best_model)` only rebound the local name;
        # restore the best weights into this instance instead.
        self.load_state_dict(best_model.state_dict())
        auc_score, f1score, prauc_score, precision, recall, accuracy, \
            predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss=False,
                                                       validloader=valid_loader)

    def evaluation(self, predict_all, label_all, threshold=0.5):
        '''Compute ROC-AUC on raw scores, then threshold and compute binary metrics.

        Side effect: dumps (score, label) pairs to `predict_label.txt`.
        '''
        with open("predict_label.txt", 'w') as fout:
            for i, j in zip(predict_all, label_all):
                fout.write(str(i)[:4] + '\t' + str(j)[:4] + '\n')
        # AUC is computed on the continuous scores, before thresholding
        auc_score = roc_auc_score(label_all, predict_all)
        label_all = [int(i) for i in label_all]
        float2binary = lambda x: 0 if x < threshold else 1
        predict_all = list(map(float2binary, predict_all))
        f1score = f1_score(label_all, predict_all)
        prauc_score = average_precision_score(label_all, predict_all)
        precision = precision_score(label_all, predict_all)
        recall = recall_score(label_all, predict_all)
        accuracy = accuracy_score(label_all, predict_all)
        predict_1_ratio = sum(predict_all) / len(predict_all)
        label_1_ratio = sum(label_all) / len(label_all)
        return auc_score, f1score, prauc_score, precision, recall, accuracy, \
            predict_1_ratio, label_1_ratio

    def generate_predict(self, dataloader):
        '''Run inference over a dataloader; return (summed loss, scores, labels).'''
        whole_loss = 0
        label_all, predict_all = [], []
        for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
            X, _ = dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3,
                                 criteria_lst, self.global_icd)
            output = self.forward(X).view(-1)
            loss = self.loss(output, label_vec.float())
            whole_loss += loss.item()
            # BUGFIX: forward() already returns probabilities — the original
            # applied torch.sigmoid() here a second time.
            predict_all.extend([i.item() for i in output])
            label_all.extend([i.item() for i in label_vec])
        return whole_loss, predict_all, label_all

    def bootstrap_test(self, dataloader, sample_num=20):
        '''Evaluate with bootstrap resampling; print mean/std of PR-AUC, F1, ROC-AUC.'''
        self.eval()
        best_threshold = 0.5
        whole_loss, predict_all, label_all = self.generate_predict(dataloader)
        plt.clf()
        prefix_name = "./figure/" + self.save_name
        plot_hist(prefix_name, predict_all, label_all)

        def bootstrap(length, sample_num):
            # sample `sample_num` index lists, each drawn with replacement
            from random import choices
            idx = list(range(length))
            return [choices(idx, k=length) for _ in range(sample_num)]

        results_lst = []
        for bootstrap_idx in bootstrap(len(predict_all), sample_num=sample_num):
            bootstrap_label = [label_all[idx] for idx in bootstrap_idx]
            bootstrap_predict = [predict_all[idx] for idx in bootstrap_idx]
            results_lst.append(self.evaluation(bootstrap_predict, bootstrap_label,
                                               threshold=best_threshold))
        self.train()
        auc = [results[0] for results in results_lst]
        f1score = [results[1] for results in results_lst]
        prauc_score = [results[2] for results in results_lst]
        print("PR-AUC mean: " + str(np.mean(prauc_score))[:6], "std: " + str(np.std(prauc_score))[:6])
        print("F1 mean: " + str(np.mean(f1score))[:6], "std: " + str(np.std(f1score))[:6])
        print("ROC-AUC mean: " + str(np.mean(auc))[:6], "std: " + str(np.std(auc))[:6])

    def test(self, dataloader, return_loss=True, validloader=None):
        '''Evaluate on a dataloader; return the summed loss or the full metric tuple.'''
        self.eval()
        best_threshold = 0.5
        whole_loss, predict_all, label_all = self.generate_predict(dataloader)
        self.train()
        if return_loss:
            return whole_loss
        print_num = 5
        auc_score, f1score, prauc_score, precision, recall, accuracy, \
            predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all,
                                                             threshold=best_threshold)
        print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num]
              + "\nPR-AUC: " + str(prauc_score)[:print_num]
              + "\nPrecision: " + str(precision)[:print_num]
              + "\nrecall: " + str(recall)[:print_num] + "\naccuracy: " + str(accuracy)[:print_num]
              + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num]
              + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num])
        return auc_score, f1score, prauc_score, precision, recall, accuracy, \
            predict_1_ratio, label_1_ratio

    def plot_learning_curve(self, train_loss_record, valid_loss_record):
        '''Save train/valid loss curves under ./figure/.'''
        plt.plot(train_loss_record)
        plt.savefig("./figure/" + self.save_name + '_train_loss.jpg')
        plt.clf()
        plt.plot(valid_loss_record)
        plt.savefig("./figure/" + self.save_name + '_valid_loss.jpg')
        plt.clf()


class ADMET(nn.Sequential):
    '''Five-task ADMET prediction head on top of a shared MPNN molecule encoder.

    NOTE(review): this class shadows the ADMET imported from
    .model_utils.molecule_encode at the top of the file — confirm which one
    downstream code intends to use.
    '''

    def __init__(self, mpnn_model, device):
        super(ADMET, self).__init__()
        self.num = 5  # number of ADMET tasks (A, D, M, E, T)
        self.mpnn_model = mpnn_model
        self.device = device
        self.mpnn_dim = mpnn_model.mpnn_hidden_size
        self.global_embed_size = self.mpnn_dim
        self.highway_num_layer = 2
        # one (linear -> highway) trunk per task
        admet_model = []
        for _ in range(5):
            admet_fc = nn.Linear(self.mpnn_model.mpnn_hidden_size, self.global_embed_size).to(device)
            admet_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
            admet_model.append(nn.ModuleList([admet_fc, admet_highway]))
        self.admet_model = nn.ModuleList(admet_model)
        # one scalar prediction layer per task
        self.admet_pred = nn.ModuleList(
            [nn.Linear(self.global_embed_size, 1).to(device) for _ in range(5)])
        self.f = F.relu
        # Module.to() is in-place; no need to rebind self
        self.to(device)

    def feed_lst_of_module(self, input_feature, lst_of_module):
        '''Apply each module in sequence, with the shared activation after each.'''
        x = input_feature
        for single_module in lst_of_module:
            x = self.f(single_module(x))
        return x

    def forward(self, smiles_lst, idx):
        '''Predict ADMET task `idx` (0-4) for a batch of SMILES lists.'''
        assert idx in list(range(5))
        embeds = self.mpnn_model.forward_smiles_lst_lst(smiles_lst)
        embeds = self.feed_lst_of_module(embeds, self.admet_model[idx])
        output = self.admet_pred[idx](embeds)
        return output

    def test(self, valid_loader):
        # placeholder: evaluation not implemented for the ADMET head
        pass