Source code for pytrial.tasks.trial_search.models.whiten_bert

import os
import pickle

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch import nn
from datasets import Dataset
from tqdm import tqdm

from pytrial.model_utils.bert import BERT
from pytrial.utils.check import make_dir_if_not_exist
from .base import TrialSearchBase
from .base import whitening_torch_final
from ..metrics import precision, recall, ndcg
from ..data import TrialSearchCollator, batch_to_device


class BuildModel:
    def __new__(cls, bert_name):
        return BERT(bert_name)

class WhitenBERT(TrialSearchBase):
    '''
    Implement a postprocessing method to improve BERT embeddings for similarity search [1]_.

    Parameters
    ----------
    layer_mode: {'last_first', 'last'}
        The mode of layers of embeddings to aggregate.
        'last_first' means use the last layer and the first layer.
        'last' means use the last layer only.

    bert_name: str, optional (default = 'emilyalsentzer/Bio_ClinicalBERT')
        The name of the base BERT model used for encoding input texts.

    device: str, optional (default = 'cuda:0')
        The device of this model, typically 'cpu' or 'cuda:0'.

    experiment_id: str, optional (default = 'test')
        The name of the current experiment.

    Notes
    -----
    .. [1] Huang, J., Tang, D., Zhong, W., Lu, S., Shou, L., Gong, M., ... & Duan, N. (2021, November).
           WhiteningBERT: An Easy Unsupervised Sentence Embedding Approach.
           In Findings of the Association for Computational Linguistics: EMNLP 2021 (pp. 238-244).
    '''
    trial_embs = {}

    def __init__(self,
        layer_mode='last_first',
        bert_name='emilyalsentzer/Bio_ClinicalBERT',
        device='cuda:0',
        experiment_id='test'):
        super().__init__(experiment_id)
        self.layer_mode = layer_mode
        self.bert_name = bert_name
        self.device = device
        self._build_model()
        self.config = {
            'layer_mode': self.layer_mode,
            'bert_name': self.bert_name,
            'device': self.device,
            'batch_size': 32, # default batch size for encoding
            'num_workers': 4, # default num_workers for encoding
            'max_seq_length': 512, # default max_seq_length for encoding
            'tag_field': 'nctid', # default tag_field for encoding
        }

    def __getitem__(self, tag):
        return self.trial_embs[tag]
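    # Usage sketch (illustrative, not part of the library). `trial_df` and the
    # trial id below are hypothetical placeholders; after `fit`, an embedding
    # can be looked up by its tag via `__getitem__`.
    #
    # >>> model = WhitenBERT(layer_mode='last_first', device='cpu')
    # >>> model.fit({'x': trial_df, 'fields': ['title'], 'tag': 'nctid'})
    # >>> emb = model['nct01']  # one trial embedding, np.ndarray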
    def fit(self, train_data, valid_data=None):
        '''
        Go over all trials and encode them into embeddings. Note that this is a
        postprocessing method based on a pretrained BERT model, so it does `NOT`
        need to be trained.

        Parameters
        ----------
        train_data: dict
            The data for encoding.

            - 'x' is the dataframe that contains multiple sections of a trial.
            - 'fields' is the list of fields to be encoded.
            - 'tag' is the unique index column name of each document, e.g., 'nctid'.

            train_data = {
                'x': pd.DataFrame,
                'fields': list[str],
                'tag': str,
            }

        valid_data: Not used. This is a placeholder because this model does not need training.
        '''
        embs = self.encode(train_data, return_dict=True)
        self.trial_embs = embs
        return embs
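    # Usage sketch for `fit` (assumed inputs): `trial_df` is a hypothetical
    # dataframe whose text columns 'title' and 'criteria' are encoded, and
    # whose 'nctid' column uniquely tags each trial.
    #
    # >>> train_data = {'x': trial_df, 'fields': ['title', 'criteria'], 'tag': 'nctid'}
    # >>> embs = model.fit(train_data)  # dict: nctid -> np.ndarray embedding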
    def predict(self, test_data, top_k=10, return_df=True):
        '''
        Predict the top-k relevant trials for the input documents.

        Parameters
        ----------
        test_data: dict
            Shares the same input format as the `train_data` in the `fit` function.
            If `fields` and `tag` are not given, the ones used during training are reused.

            test_data = {
                'x': pd.DataFrame,
                'fields': list[str],
                'tag': str,
            }

        top_k: int
            Number of retrieved candidates.

        return_df: bool
            - If set True, return dataframes for the computed similarity ranking.
            - Else, return rank_list=[[(doc1,sim1),(doc2,sim2)], [(doc1,sim1),...]].

        Returns
        -------
        rank: pd.DataFrame
            A dataframe containing the top ranked NCT ids for each test document.

        sim: pd.DataFrame
            A dataframe containing the corresponding similarities.

        rank_list: list[list[tuple]]
            A list of tuples of top ranked docs and similarities.
        '''
        self._input_data_check(test_data)
        tag_field = self.config['tag_field']
        assert len(self.trial_embs) > 0, 'No embeddings found. Please run `fit` first, or load embeddings with `load_model`.'
        all_embs = np.stack(list(self.trial_embs.values()))
        all_tags = np.array(list(self.trial_embs.keys()))

        # TODO: skip encoding those already stored in model
        # to_encode_test_trial = test_trial[~test_trial[tag_field].isin(all_tags)]
        tags, embs = self.encode(test_data, return_dict=False)

        # rank all stored trials for each test document
        sim = embs.dot(all_embs.T)
        rank = np.argsort(sim, 1)[:, ::-1] # flip to descending similarity
        rank = rank[:, 1:top_k+1] # drop the top hit (typically the document itself)

        if return_df:
            tag_list, sim_list = [], []
            for i, row in enumerate(rank):
                tag_list.append(all_tags[row])
                sim_list.append(sim[i][row])
            tag_ar = np.array(tag_list)
            sim_ar = np.array(sim_list)
            rank_df = pd.DataFrame(tag_ar, columns=[f'rank_{i}' for i in range(top_k)], index=tags)
            sim_df = pd.DataFrame(sim_ar, columns=[f'rank_{i}' for i in range(top_k)], index=tags)
            return rank_df, sim_df
        else:
            rank_list = []
            for i, row in enumerate(rank):
                tag_ar = all_tags[row]
                sim_ar = sim[i][row]
                rank_list.append(list(zip(tag_ar, sim_ar)))
            return rank_list
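    # Usage sketch for `predict` (same hypothetical inputs as `fit`):
    #
    # >>> rank_df, sim_df = model.predict(test_data, top_k=10, return_df=True)
    # >>> rank_df.loc['nct01']  # top-10 most similar trial ids, best first
    # >>> rank_list = model.predict(test_data, top_k=10, return_df=False)
    # >>> rank_list[0]  # [(doc1, sim1), (doc2, sim2), ...] for the first test doc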
    def evaluate(self, test_data):
        '''
        Evaluate ranking quality on the given target trials and their
        corresponding candidate trials.

        Parameters
        ----------
        test_data: dict
            The provided labeled dataset for test trials. Follows the format listed below.

            test_data = {
                'x': pd.DataFrame,
                'y': pd.DataFrame,
            }

        Returns
        -------
        results: dict[float]
            A dict of metrics and their values.

        Notes
        -----
        The input dataframes follow this layout::

            x = | target_trial | trial1 | trial2 | trial3 |
                | nct01        | nct02  | nct03  | nct04  |

            y = | label1 | label2 | label3 |
                | 0      | 0      | 1      |
        '''
        test_df = test_data['x']
        label_df = test_data['y']
        ranked_label_list = []
        for idx, row in test_df.iterrows():
            target_trial = row['target_trial']
            if target_trial in self.trial_embs:
                target_emb = self.__getitem__(target_trial)
            else:
                raise ValueError(f'The embeddings of trial {target_trial} are not found.')
            candidate_embs = np.stack([self.__getitem__(tag) for tag in row.values[1:]])
            sim = target_emb[None].dot(candidate_embs.T)[0]
            labels = label_df.iloc[idx].to_numpy()
            if labels.sum() == 0:
                continue
            ranked_label = labels[np.argsort(sim)[::-1]]
            ranked_label_list.append(ranked_label)
        ranked_label_list = np.array(ranked_label_list)
        return_dict = {}
        for k in [1, 2, 5]:
            return_dict[f'precision@{k}'] = precision(ranked_label_list, k)
            return_dict[f'recall@{k}'] = recall(ranked_label_list, k)
            return_dict[f'ndcg@{k}'] = ndcg(ranked_label_list, k)
        return return_dict
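    # Usage sketch for `evaluate`, using the ids from the docstring layout
    # above: `x` pairs each target trial with its candidates, `y` holds the
    # 0/1 relevance labels of the candidates.
    #
    # >>> x = pd.DataFrame({'target_trial': ['nct01'], 'trial1': ['nct02'],
    # ...                   'trial2': ['nct03'], 'trial3': ['nct04']})
    # >>> y = pd.DataFrame({'label1': [0], 'label2': [0], 'label3': [1]})
    # >>> metrics = model.evaluate({'x': x, 'y': y})  # {'precision@1': ..., ...}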
    def save_model(self, output_dir):
        '''
        Only save the embeddings. Do not save the model.

        Parameters
        ----------
        output_dir: str
            The output directory to save the embeddings to.
        '''
        if not os.path.exists(output_dir):
            make_dir_if_not_exist(output_dir)
        with open(os.path.join(output_dir, 'embeddings.pkl'), 'wb') as f:
            pickle.dump(self.trial_embs, f)
        print(f'Embeddings saved to {output_dir}')
    def load_model(self, input_dir):
        '''
        Only load the embeddings. Do not load the model.

        Parameters
        ----------
        input_dir: str
            The input directory to load the embeddings from.
        '''
        with open(os.path.join(input_dir, 'embeddings.pkl'), 'rb') as f:
            self.trial_embs = pickle.load(f)
        print(f'Embeddings loaded from {input_dir}')
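    # Usage sketch: persist and restore the embedding table without
    # re-encoding. The directory path is a hypothetical example.
    #
    # >>> model.save_model('./checkpoints/whiten_bert')  # writes embeddings.pkl
    # >>> model.load_model('./checkpoints/whiten_bert')  # restores model.trial_embs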
    def encode(self,
        inputs,
        batch_size=None,
        num_workers=None,
        return_dict=True,
        verbose=True,
        ):
        '''
        Encode input documents and output the document embeddings.

        Parameters
        ----------
        inputs: dict
            The input documents to encode:

            - 'fields' is the list of fields to be encoded.
            - 'tag' is the unique index column name of each document, e.g., 'nctid'.

            inputs = {
                'x': pd.DataFrame,
                'fields': list[str],
                'tag': str,
            }

        batch_size: int, optional
            The batch size when encoding trials.

        num_workers: int, optional
            The number of workers when building the val dataloader.

        return_dict: bool
            Whether to return a dict of results.

            - If set True, return dict[np.ndarray] keyed by tag.
            - Else, return a tuple (tags, embeddings) in the same order as the input documents.

        verbose: bool
            Whether to plot a progress bar or not.

        Returns
        -------
        embs: dict[np.ndarray]
            Encoded trial-level embeddings with key (tag) and value (embedding).
        '''
        self._input_data_check(inputs)
        if batch_size is not None:
            self.config['batch_size'] = batch_size
        if num_workers is not None:
            self.config['num_workers'] = num_workers
        # build dataset and dataloader
        dataloader = self.get_val_dataloader(inputs)
        # encode all trials
        embs = self._encode_for_dataloader(dataloader, verbose=verbose, return_dict=return_dict)
        return embs
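    # Usage sketch for `encode` (same hypothetical inputs as `fit`):
    #
    # >>> embs = model.encode(inputs, batch_size=64, return_dict=True)
    # >>> tags, emb_mat = model.encode(inputs, return_dict=False)
    # >>> emb_mat.shape  # (n_trials, hidden_dim), row order matches `tags`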
    def get_val_dataloader(self, inputs):
        df = inputs['x']
        self.config['fields'] = inputs['fields']
        self.config['tag_field'] = inputs['tag']
        dataset = self._build_val_dataset(df=df)
        collate_fn = self._build_val_collator()
        dataloader = self._build_val_dataloader(dataset=dataset, collate_fn=collate_fn)
        return dataloader

    def _build_model(self):
        self.model = BuildModel(self.bert_name)
        self.model.to(self.device)

    def _encode_for_dataloader(self, dataloader, verbose=True, return_dict=True):
        emb_list = []
        tag_list = []
        self.model.eval()
        with torch.no_grad():
            for data in tqdm(dataloader, desc='Encoding', disable=not verbose):
                if self.config['tag_field'] is not None:
                    tag = data.pop(self.config['tag_field'])
                    tag_list.append(tag)
                data = self._prepare_input(data)
                # take the average of the embeddings of all fields
                embs = []
                for field in self.config['fields']:
                    emb = self._encode_for_field(data=data, field=field)
                    embs.append(emb)
                temp_emb = torch.stack(embs, dim=0).mean(dim=0)
                if len(temp_emb.shape) == 1:
                    temp_emb = temp_emb[None]
                # whitening for each batch
                temp_emb = whitening_torch_final(temp_emb)
                emb_list.append(temp_emb.cpu().numpy())
        embs = np.concatenate(emb_list, 0)
        if len(tag_list) == 0:
            tag_list = list(range(len(embs)))
        else:
            tag_list = sum(tag_list, [])
        if return_dict:
            embs = dict(zip(tag_list, embs))
            return embs
        else:
            return (tag_list, embs)

    def _encode_for_field(self, data, field):
        inputs = {}
        for k, v in data.items():
            if field in k:
                tempkey = k.replace(field + '_', '')
                inputs[tempkey] = v
        inputs = batch_to_device(inputs, self.device)
        if self.layer_mode == 'last_first':
            # take the mean of the first- and last-layer embeddings
            inputs['return_hidden_states'] = True
            outputs = self.model(**inputs)
            emb = (outputs[1].mean(1) + outputs[-1].mean(1)) / 2
        else:
            # only take the last-layer embedding
            emb = self.model(**inputs)
        return emb

    def _prepare_input(self, inputs):
        fields = self.config['fields']
        output_dict = {}
        for fd in fields:
            if fd not in inputs:
                continue
            for k, v in inputs[fd].items():
                output_dict[f'{fd}_{k}'] = v
        return output_dict

    def _build_val_dataset(self, df):
        dataset = Dataset.from_pandas(df)
        dataset.set_transform(self._val_dataset_transform)
        return dataset

    def _build_val_collator(self):
        return TrialSearchCollator(
            bert_name=self.config['bert_name'],
            max_seq_length=self.config['max_seq_length'],
            fields=self.config['fields'],
            device=self.config['device'],
            tag_field=self.config['tag_field'],
        )

    def _build_val_dataloader(self, dataset, collate_fn):
        dataloader = DataLoader(
            dataset,
            batch_size=self.config['batch_size'],
            shuffle=False,
            num_workers=self.config['num_workers'],
            collate_fn=collate_fn,
            pin_memory=True,
        )
        return dataloader

    def _dataset_transform(self, examples):
        return_dict = {}
        if self.config['fields'] is not None:
            for k in self.config['fields']:
                return_dict[k] = examples[k]
        if self.config['tag_field'] is not None:
            return_dict[self.config['tag_field']] = examples[self.config['tag_field']]
        if len(return_dict) == 0:
            return_dict = examples
        return return_dict

    def _val_dataset_transform(self, examples):
        return self._dataset_transform(examples=examples)

    def _input_data_check(self, inputs):
        '''
        Check that the training / testing data fit the expected format.
        If the inputs are invalid, raise an error with a hint about the data problem.

        Parameters
        ----------
        inputs: dict
            inputs = {
                'x': pd.DataFrame,
                'fields': list[str],
                'tag': str,
            }
        '''
        # check overall input format
        assert 'x' in inputs, 'No input trial doc dataframe found in inputs.'
        df = inputs['x']
        if 'fields' in inputs:
            try:
                _ = df[inputs['fields']]
            except Exception:
                raise Exception('Cannot find the specified `fields` in the input dataframe.')
        if 'tag' in inputs:
            try:
                _ = df[inputs['tag']]
            except Exception:
                raise Exception('Cannot find the specified `tag` in the input dataframe.')
        # check data type
        try:
            _ = df.applymap(str)
        except Exception:
            raise Exception('Cannot transform the input dataframe to str type, please check the inputs.')
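
# A minimal sketch of the whitening step applied per batch in
# `_encode_for_dataloader` above. The real implementation is
# `whitening_torch_final` imported from `.base`; this illustrative version
# only shows the transform described in the WhiteningBERT paper cited in the
# class docstring: center the embeddings, then rotate and rescale them so
# their covariance becomes (approximately) the identity. The `eps` term is an
# added numerical-stability assumption, not necessarily in the original code.
def _whitening_sketch(embeddings: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Center the embeddings.
    mu = embeddings.mean(dim=0, keepdim=True)
    centered = embeddings - mu
    # Scatter matrix (d x d); its SVD gives the principal directions.
    cov = centered.t() @ centered
    u, s, _ = torch.linalg.svd(cov)
    # Rotate onto the principal axes and rescale each axis to unit variance.
    w = u @ torch.diag(1.0 / torch.sqrt(s + eps))
    return centered @ w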