# Source code for pytrial.tasks.trial_search.models.trial2vec

'''
Implement Trial2Vec model for clinical trial similarity search.
'''

import pdb
import os
import math
from collections import defaultdict
import random

from torch.cuda.amp import autocast
import pandas as pd
import numpy as np
import torch
import transformers
from torch.utils.data import DataLoader
from torch import nn
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from datasets import Dataset
from tqdm import tqdm

from pytrial.utils.trainer import Trainer
from pytrial.utils.check import install_package, make_dir_if_not_exist
from pytrial.utils.check import check_checkpoint_file, check_model_dir, check_model_config_file
from .base import TrialSearchBase
from ..losses import InfoNCELoss
from ..data import TrialSearchCollator, batch_to_device
from ..metrics import precision, recall, ndcg

# Download URL of the zipped pretrained Trial2Vec checkpoint; fetched and
# extracted by `Trial2Vec._download_pretrained` (used by `from_pretrained`).
PRETRAINED_TRIAL2VEC_URL = 'https://storage.googleapis.com/pytrial/trial2vec_pretrained.zip'

class BuildModel(nn.Module):
    '''
    Trial2Vec encoder network.

    A single pretrained transformer (``base_encoder``) embeds every trial
    component separately. Attribute embeddings are fused with context
    embeddings through multi-head attention (context as query, attributes as
    key/value) and projected to the trial-level ``global_embs``; every encoded
    component is also projected to a ``local_embs`` entry.

    Parameters
    ----------
    model_name: str
        Name or path passed to ``transformers.AutoModel.from_pretrained``.

    emb_dim: int
        Output size of the global and local projection heads.

    fields: list[str], optional
        Attribute fields. ``None`` is treated as "no attribute fields".

    ctx_fields: list[str], optional
        Context fields. ``None`` is treated as "no context fields".

    device: str or torch.device, optional
        Device the tokenized inputs are moved to (via ``batch_to_device``)
        before they reach the encoder.
    '''
    def __init__(self,
        model_name,
        emb_dim,
        fields=None,
        ctx_fields=None,
        device=None,
        ) -> None:
        super().__init__()
        self.device = device
        self.base_encoder = AutoModel.from_pretrained(model_name)
        # width of the encoder's pooled output (768 for BERT-base checkpoints);
        # read from the config instead of hard-coding it
        hidden_size = getattr(self.base_encoder.config, 'hidden_size', 768)
        self.global_proj_head = nn.Linear(hidden_size, emb_dim, bias=False)
        self.local_proj_head = nn.Linear(hidden_size, emb_dim, bias=False)
        self.multihead_att = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=6,
            batch_first=False, # [seq, batch, feature]
        )
        # Per-instance config. It used to be a class attribute, so every
        # BuildModel shared (and overwrote) one and the same dict.
        self.config = {
            'fields': fields,
            'ctx_fields': ctx_fields,
        }

    def forward(self, inputs, return_local_emb=True):
        '''
        (1) Encode every attribute and context component present in ``inputs``.
        (2) Fuse them: context embeddings attend over the attribute embeddings
            and the attention outputs are averaged over the context components;
            without context, the attribute embeddings are averaged directly.
        (3) Project the fused vector with the global head (each component is
            projected with the local head while encoding).

        Parameters
        ----------
        inputs: dict
            Flat dict of tokenized components keyed ``{field}_{tokenizer_key}``,
            e.g. ``title_input_ids`` and ``title_attention_mask``.

        return_local_emb: bool
            Whether to include the per-component embeddings in the output.

        Returns
        -------
        dict
            ``global_embs``: Tensor of trial-level embeddings (not normalized).
            ``local_embs``: dict[str, Tensor] of component embeddings, only
            present when ``return_local_emb`` is True.

        Raises
        ------
        ValueError
            If none of the attribute ``fields`` is present in ``inputs``.
        '''
        local_embs = {}

        fields = self.config['fields'] if self.config['fields'] is not None else []
        ctx_fields = self.config['ctx_fields'] if self.config['ctx_fields'] is not None else []

        # stacked pooled outputs: (num_components, batch, hidden_size)
        attr_embs = self._encode_fields(fields, inputs, local_embs)
        ctx_embs = self._encode_fields(ctx_fields, inputs, local_embs)

        if len(attr_embs) == 0:
            raise ValueError('None of the attribute `fields` were found in the model inputs.')

        if len(ctx_embs) > 0:
            # context components act as queries over the attribute components
            fused_embs, _ = self.multihead_att(ctx_embs, attr_embs, attr_embs) # num_ctx, bs, hidden
            fused_embs = fused_embs.mean(0)
        else:
            # no context: plain average of the attribute components
            fused_embs = attr_embs.mean(0) # bs, hidden

        global_embs = self.global_proj_head(fused_embs)

        return_dict = {
            'global_embs': global_embs,
        }
        if return_local_emb:
            return_dict['local_embs'] = local_embs
        return return_dict

    def _encode_fields(self, fields, inputs, local_embs):
        '''
        Encode each field of ``fields`` that is present in ``inputs``.

        The locally projected embedding of every encoded field is written into
        ``local_embs`` in place. Returns the stacked *unprojected* pooled
        outputs, shape (num_found_fields, batch, hidden_size), or an empty list
        when none of ``fields`` is present.
        '''
        # All field names that may own a key, so that a short name such as
        # `title` does not claim keys of `title_short` (or `official_title`).
        known_fields = list(fields)
        for group in (self.config['fields'], self.config['ctx_fields']):
            if group is not None:
                known_fields.extend(group)

        per_field = {}
        for key, value in inputs.items():
            match = self._split_input_key(key, known_fields)
            if match is not None:
                field_name, sub_key = match
                per_field.setdefault(field_name, {})[sub_key] = value

        emb_list = []
        for fd in fields:
            input_dict = per_field.get(fd)
            if not input_dict:
                continue
            input_dict = batch_to_device(input_dict, self.device)
            res = self.base_encoder(**input_dict, return_dict=True)
            pooled = res['pooler_output']
            emb_list.append(pooled)
            # local embeddings are left unnormalized
            local_embs[fd] = self.local_proj_head(pooled)

        if len(emb_list) > 0:
            return torch.stack(emb_list)
        return emb_list

    @staticmethod
    def _split_input_key(key, field_names):
        '''
        Split a flattened input key ``{field}_{tokenizer_key}``.

        The longest field name that is a ``{field}_`` prefix of ``key`` wins.
        Returns ``(field, tokenizer_key)``, or ``None`` if no field matches.
        '''
        best = None
        for name in field_names:
            prefix = name + '_'
            if key.startswith(prefix) and len(key) > len(prefix):
                if best is None or len(name) > len(best):
                    best = name
        if best is None:
            return None
        return best, key[len(best) + 1:]

class LocalMatchCollator(TrialSearchCollator):
    '''
    Collate raw trial records into tokenized inputs, one entry per field.

    The returned dict maps each selected field name to the tokenizer output for
    that field. When ``is_train`` is True, each batch uses a random subset of
    the attribute and context fields, and every text column is perturbed with
    one randomly chosen EDA operation (synonym replacement, random swap or
    random deletion). When ``tag_field`` is given, the raw tags are returned
    under that key as well.

    Parameters
    ----------
    bert_name: str
        Tokenizer name, forwarded to ``TrialSearchCollator``.

    max_seq_length: int
        Maximum number of tokens per field, forwarded to ``TrialSearchCollator``.

    fields: list[str] or None
        Attribute fields to tokenize. ``None`` is treated as no fields.

    ctx_fields: list[str] or None
        Context fields to tokenize. ``None`` is treated as no fields.

    tag_field: str, optional
        Column holding the trial identifiers.

    is_train: bool
        Enable field sub-sampling and EDA augmentation (installs and loads
        ``nltk`` / ``textaugment`` on construction).

    device: str
        Forwarded to ``TrialSearchCollator``.
    '''
    def __init__(self,
        bert_name,
        max_seq_length,
        fields,
        ctx_fields,
        tag_field=None,
        is_train=True,
        device='cuda:0',
        ) -> None:
        super().__init__(
            bert_name=bert_name,
            max_seq_length=max_seq_length,
            fields=fields,
            device=device,
            tag_field=tag_field,
        )
        self.fields = fields
        self.ctx_fields = ctx_fields
        self.tag_field = tag_field
        self.is_train = is_train
        self.device = device
        if is_train:
            print('Trigger training for Trial2Vec, will load `nltk` and `textaugment`.')
            install_package('nltk')
            install_package('textaugment')
            import nltk
            from textaugment import EDA
            nltk.download('stopwords')
            nltk.download('omw-1.4')
            nltk.download('wordnet')
            self.eda = EDA()

    def __call__(self, features):
        '''Tokenize a list of raw records (dicts) into a field -> tokenizer-output dict.'''
        return_dict = defaultdict(list)
        batch_df = pd.DataFrame(features)
        batch_df.fillna('', inplace=True)

        fields = self._as_list(self.fields)
        ctx_fields = self._as_list(self.ctx_fields)

        if self.is_train:
            fields = self._random_sample(fields)
            ctx_fields = self._random_sample(ctx_fields)

        return_dict.update(self._batch_tokenize(batch_df=batch_df, fields=fields))
        return_dict.update(self._batch_tokenize(batch_df=batch_df, fields=ctx_fields))

        if self.tag_field is not None:
            return_dict[self.tag_field] = batch_df[self.tag_field].tolist()
        return return_dict

    def _batch_tokenize(self, batch_df, fields):
        '''Tokenize each column in ``fields`` (EDA-augmented when training); returns {field: tokenizer output}.'''
        return_dict = {}
        for field in fields:
            if self.is_train:
                texts = self._eda_augment(batch_df[field])
            else:
                texts = batch_df[field].tolist()
            tokenized = self.tokenizer(texts, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt')
            return_dict[field] = tokenized
        return return_dict

    def _eda_augment(self, text):
        '''Apply one randomly chosen EDA operation to every text longer than one character.'''
        eda_aug = random.choice([self.eda.synonym_replacement, self.eda.random_swap, self.eda.random_deletion])
        new_text = text.apply(lambda x: eda_aug(x) if len(x)>1 else x)
        # some EDA operations return a list of augmented sentences; keep the first
        new_text = new_text.apply(lambda x: x[0] if isinstance(x,list) else x).tolist()
        return new_text

    def _random_sample(self, fields, n=2):
        '''
        With probability 0.5 keep a random subset of ``fields``, otherwise keep all of them.

        The subset size is drawn from Poisson(``n``) and clipped to at least 1
        (when it exceeds the number of fields, every field is kept).
        Always returns a list; ``None`` yields an empty list.
        '''
        candidates = np.array(self._as_list(fields))
        if random.random() > 0.5:
            num_select = max(int(np.random.poisson(n, 1)[0]), 1)
            # shuffle before slicing; without it the "sample" was always the leading fields
            np.random.shuffle(candidates)
            return candidates[:num_select].tolist()
        return candidates.tolist()

    @staticmethod
    def _as_list(fields):
        '''Return ``fields`` as a list, mapping ``None`` to an empty list.'''
        return [] if fields is None else list(fields)

class LocalMatchLoss(InfoNCELoss):
    '''
    Contrastive loss aligning each trial's global embedding with one of its own
    local (attribute) embeddings.

    At every step one attribute field that is present in the model output is
    picked at random. The global embeddings are scored against that field's
    local embeddings, and the InfoNCE loss is averaged over both directions
    (trial -> attribute and attribute -> trial).
    '''
    def __init__(self, model, fields, logit_scale_init_value):
        super().__init__(model=model, logit_scale_init_value=logit_scale_init_value)
        self.fields = fields

    def forward(self, inputs):
        model_out = self.model(inputs)
        available = model_out['local_embs']
        candidates = [name for name in self.fields if name in available]
        chosen = np.random.choice(candidates, 1)[0]

        trial_logits = self.compute_logits(model_out['global_embs'], available[chosen])
        attr_logits = trial_logits.t()

        trial_loss = self.compute_loss(trial_logits)
        attr_loss = self.compute_loss(attr_logits)
        return {'loss_value': (trial_loss + attr_loss) / 2}

class GlobalMatchCollator(LocalMatchCollator):
    '''
    Build (anchor, positive, negative) views of a batch for the global contrastive objective.

    - ``anchor`` and ``pos`` tokenize the same attribute fields; when
      ``is_train`` is set, each tokenization draws its own EDA augmentation.
    - ``neg`` tokenizes the same attribute fields after the texts of a random
      subset of them have been shuffled across the rows of the batch.
    - With probability 0.5 the context fields are dropped for the whole batch;
      otherwise all three views share one tokenized copy of the context fields.
    '''
    def __init__(self,
        bert_name,
        max_seq_length,
        fields,
        ctx_fields,
        tag_field=None,
        is_train=True,
        device='cuda:0',
        ) -> None:
        super().__init__(
            bert_name=bert_name,
            max_seq_length=max_seq_length,
            fields=fields,
            ctx_fields=ctx_fields,
            tag_field=tag_field,
            is_train=is_train,
            device=device,
        )

    def __call__(self, features):
        batch_df = pd.DataFrame(features)
        batch_df.fillna('', inplace=True)

        attr_fields, context_fields = self.fields, self.ctx_fields
        if self.is_train:
            attr_fields = self._random_sample(attr_fields)
            context_fields = self._random_sample(context_fields)

        # coin flip: keep only the attribute components, or share the context ones
        drop_context = random.random() > 0.5
        if drop_context:
            shared_ctx = {}
        else:
            shared_ctx = self._batch_tokenize(batch_df=batch_df, fields=context_fields)

        out = {}
        # anchor and positive: same fields, independently augmented
        for view in ('anchor', 'pos'):
            out[view] = self._batch_tokenize(batch_df=batch_df, fields=attr_fields)
            out[view].update(shared_ctx)

        # negative: shuffle the texts of a random subset of fields across rows
        shuffled = self._random_sample(attr_fields)
        batch_df[shuffled] = batch_df[shuffled].sample(frac=1).reset_index(drop=True)
        out['neg'] = self._batch_tokenize(batch_df=batch_df, fields=attr_fields)
        out['neg'].update(shared_ctx)

        if self.tag_field is not None:
            out[self.tag_field] = batch_df[self.tag_field].tolist()
        return out

class GlobalMatchLoss(InfoNCELoss):
    '''
    Contrastive loss that matches each anchor trial to its positive view.

    Positive and negative global embeddings are concatenated into one candidate
    matrix (positives first, then negatives). The anchors are scored against
    every candidate and the loss is computed with ``compute_loss``.
    '''
    def __init__(self, model, fields, logit_scale_init_value):
        super().__init__(model=model, logit_scale_init_value=logit_scale_init_value)
        self.fields = fields

    def forward(self, inputs):
        # the model is run in the order positive, negative, anchor
        pos_embs = self.model(inputs['pos'], return_local_emb=False)['global_embs']
        neg_embs = self.model(inputs['neg'], return_local_emb=False)['global_embs']
        anchor_embs = self.model(inputs['anchor'], return_local_emb=False)['global_embs']

        candidates = torch.cat([pos_embs, neg_embs], dim=0)
        logits = self.compute_logits(anchor_embs, candidates)
        return {'loss_value': self.compute_loss(logits)}

class Trial2VecTrainer(Trainer):
    '''
    Subclass the original trainer and provide Trial2Vec-specific input
    preparation, training step and evaluation.
    '''
    def get_test_dataloader(self, test_data):
        '''Keep the raw validation dict ({'x': ..., 'y': ...}); evaluation does not batch it.'''
        self.test_dataloader = test_data
        return self.test_dataloader

    def prepare_input(self, inputs):
        '''Flatten a collated batch through the Trial2Vec model's `_prepare_input`.'''
        return self.model._prepare_input(inputs)

    def train_one_iteration(self,
        max_grad_norm=None,
        warmup_steps=None,
        use_amp=None,
        scaler=None,
        train_loss_dict=None):
        '''
        Run one optimization step for every training objective.

        Parameters
        ----------
        max_grad_norm: float, optional
            Gradient-clipping threshold; clipping is skipped when None.

        warmup_steps: int, optional
            Without AMP, the scheduler is only stepped when this is > 0.

        use_amp: bool, optional
            Run the forward pass under `autocast` and scale the loss with `scaler`.

        scaler: torch.cuda.amp.GradScaler, optional
            Required when `use_amp` is True.

        train_loss_dict: dict[int, list], optional
            Receives the loss value of each objective; skipped when None.
        '''
        num_train_objectives = len(self.train_dataloader)
        for train_idx in range(num_train_objectives):
            data_iterator = self.data_iterators[train_idx]
            loss_model = self.loss_models[train_idx]
            loss_model.zero_grad()
            loss_model.train()
            optimizer = self.optimizers[train_idx]
            scheduler = self.schedulers[train_idx]

            # restart the dataloader of this objective once it is exhausted
            try:
                data = next(data_iterator)
            except StopIteration:
                data_iterator = iter(self.train_dataloader[train_idx])
                self.data_iterators[train_idx] = data_iterator
                data = next(data_iterator)

            if 'anchor' in data:
                # global match batch: prepare the anchor / pos / neg views separately
                data = {k: self.prepare_input(v) for k, v in data.items()}
            else:
                # local match batch
                data = self.prepare_input(data)

            if use_amp:
                with autocast():
                    loss_return = loss_model(data)
                loss_value = loss_return['loss_value']
                scale_before_step = scaler.get_scale()
                scaler.scale(loss_value).backward()
                scaler.unscale_(optimizer)
                self._clip_gradients(loss_model, max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                # a changed scale means the optimizer step was skipped (inf/NaN grads)
                skip_scheduler = scaler.get_scale() != scale_before_step
            else:
                loss_value = loss_model(data)['loss_value']
                loss_value.backward()
                self._clip_gradients(loss_model, max_grad_norm)
                optimizer.step()
                skip_scheduler = warmup_steps is None or warmup_steps <= 0

            if train_loss_dict is not None:
                train_loss_dict[train_idx].append(loss_value.item())

            if not skip_scheduler:
                scheduler.step()

    @staticmethod
    def _clip_gradients(loss_model, max_grad_norm):
        '''Clip the gradients of `loss_model` to `max_grad_norm`; no-op when it is None.'''
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)

    def evaluate(self):
        '''Re-encode the validation trial documents, then score the labeled validation pairs.'''
        val_doc_embs = self.model.encode(self.model.val_doc_inputs, return_dict=True)
        self.model.update_emb(val_doc_embs)
        return_dict = self.model.evaluate(test_data={'x':self.test_dataloader['x'],'y':self.test_dataloader['y'],})
        return return_dict

class Trial2Vec(TrialSearchBase):
    '''
    Implement the Trial2Vec model for trial document similarity search [1]_.

    Parameters
    ----------
    fields: list[str]
        A list of fields of documents used as the `attribute` fields by Trial2Vec model.

    ctx_fields: list[str]
        A list of fields of documents used as the `context` fields by Trial2Vec model.

    tag_field: str
        The tag indicating trial documents, default to be 'nct_id'.

    bert_name: str (default='emilyalsentzer/Bio_ClinicalBERT')
        The base transformer-based encoder. Please find model names from the model hub of transformers (https://huggingface.co/models).

    emb_dim: int, optional (default=128)
        Dimensionality of the embedding vectors.

    logit_scale_init_value: float, optional (default=0.07)
        The logit scale or the temperature.

    max_seq_length: int (default=128)
        The maximum length of input tokens for the base encoder.

    epochs: int, optional (default=10)
        Number of iterations (epochs) over the corpus.

    batch_size: int, optional (default=64)
        Number of samples in each training batch.

    learning_rate: float, optional (default=2e-5)
        The learning rate.

    weight_decay: float, optional (default=1e-4)
        Weight decay applied for regularization.

    warmup_ratio: float (default=0)
        Fraction of training steps used for warmup. If set to 0, no warmup is used.

    evaluation_steps: int (default=10)
        Number of iterations between printing the training loss and running evaluation (if validation data is given).

    num_workers: int, optional (default=0)
        Number of DataLoader worker processes.

    device: str, torch.device or list (default='cuda:0')
        The device to put the model on. A list of CUDA device ids wraps the model in ``nn.DataParallel``.

    use_amp: bool (default=False)
        Whether or not use mixed precision training.

    experiment_id: str, optional (default='test')
        The name of current experiment.

    Notes
    -----
    .. [1] Wang, Z., & Sun, J. (2022). Trial2Vec: Zero-Shot Clinical Trial Document Similarity Search using Self-Supervision. Findings of EMNLP 2022.
    '''
    # Class-level defaults kept for backward compatibility; __init__ gives every
    # instance its own copies so instances never share stored embeddings.
    trial_embs = {}
    val_doc_inputs = None

    def __init__(self,
        fields=None,
        ctx_fields=None,
        tag_field='nct_id',
        bert_name='emilyalsentzer/Bio_ClinicalBERT',
        emb_dim=128,
        logit_scale_init_value=0.07,
        max_seq_length=128,
        epochs=10,
        batch_size=64,
        learning_rate=2e-5,
        weight_decay=1e-4,
        warmup_ratio=0,
        evaluation_steps=10,
        num_workers=0,
        device='cuda:0',
        use_amp=False,
        experiment_id='test',
        ) -> None:
        super().__init__(experiment_id=experiment_id)
        self.config = {
            'max_seq_length':max_seq_length,
            'logit_scale_init_value':logit_scale_init_value,
            'bert_name':bert_name,
            'emb_dim':emb_dim,
            'epochs':epochs,
            'batch_size':batch_size,
            'learning_rate':learning_rate,
            'weight_decay':weight_decay,
            'evaluation_steps':evaluation_steps,
            'num_workers':num_workers,
            'device':device,
            'use_amp':use_amp,
            'warmup_ratio':warmup_ratio,
            'tag_field':tag_field,
            'fields':fields,
            'ctx_fields':ctx_fields,
        }
        # per-instance storage (the class attribute dict would be shared)
        self.trial_embs = {}
        self.val_doc_inputs = None
        self.tokenizer = AutoTokenizer.from_pretrained(bert_name)
        if num_workers > 0:
            # DataLoader workers are separate processes; disable tokenizer
            # parallelism when collating inside them
            os.environ['TOKENIZERS_PARALLELISM'] = '0'
        self.use_amp = use_amp
        self.device = self._resolve_device(device)
        self._build_model()

    def fit(self, train_data, valid_data=None):
        '''
        Train the trial2vec model to get document embeddings for trial search.

        Parameters
        ----------
        train_data: dict
            train_data = {
                'x': pd.DataFrame,
                'fields': list[str],
                'ctx_fields': list[str],
                'tag': str,
            }

            Training corpus for the model.

            - `x`: a dataframe of trial documents.
            - `fields`: optional, the fields of documents to use for training as key attributes. If neither given here nor at construction, all columns of `x` except the tag and context fields are used.
            - `ctx_fields`: optional, the fields of documents which belong to context components. If not given, the model only learns from `fields`.
            - `tag`: optional, the field in `x` that serves as unique identifiers, typically the `nct_id` of each trial. If not given, the `tag_field` set at construction is used (integer tags when that is None).

        valid_data: dict, optional
            valid_data = {'x': pd.DataFrame, 'y': pd.DataFrame or np.ndarray}
            Labeled validation data (same format as `evaluate`) used for identifying the best checkpoint during the training.
        '''
        self._input_data_check(train_data)
        self._fit(train_data, val_data=valid_data)

    def encode(self,
        inputs,
        batch_size=None,
        num_workers=None,
        return_dict=True,
        verbose=True,
        ):
        '''
        Encode input documents and output the L2-normalized document embeddings.

        Parameters
        ----------
        inputs: dict
            inputs = {
                'x': pd.DataFrame,
                'fields': list[str],
                'ctx_fields': list[str],
                'tag': str,
            }
            Same format as `train_data` in `fit`. Only `x` is read; the other
            keys are validated but the fields/tag configured at training (or
            loading) time are always used.

        batch_size: int, optional
            The batch size when encoding trials (stored in the model config).

        num_workers: int, optional
            Number of DataLoader workers (stored in the model config).

        return_dict: bool
            If True, return dict[tag, np.ndarray]. Else, return the tuple
            (tags, embs) with `embs` in the same order as the input documents.

        verbose: bool
            Whether to show a progress bar.

        Returns
        -------
        embs: dict[np.ndarray] or tuple(list, np.ndarray)
            Encoded trial-level embeddings keyed by tag.
        '''
        self._input_data_check(inputs)
        if batch_size is not None:
            self.config['batch_size'] = batch_size
        if num_workers is not None:
            # was written to the misspelled key 'num_worker' and silently ignored
            self.config['num_workers'] = num_workers

        dataloader = self.get_val_dataloader(inputs)

        tag_field = self.config['tag_field']
        emb_list, tag_batches = [], []
        self.model.eval()
        with torch.no_grad():
            for data in tqdm(dataloader, desc='Encoding', disable=not verbose):
                if tag_field is not None:
                    tag_batches.append(data.pop(tag_field))
                data = self._prepare_input(data)
                outputs = self.model(data)
                temp_emb = outputs['global_embs'].cpu()
                # L2-normalize so dot products in predict/evaluate are cosine similarities
                temp_emb = temp_emb / temp_emb.norm(dim=-1, keepdim=True)
                temp_emb = temp_emb.numpy()
                if len(temp_emb.shape) == 1:
                    temp_emb = temp_emb[None]
                emb_list.append(temp_emb)

        embs = np.concatenate(emb_list, 0)
        if tag_batches:
            tag_list = [tag for batch in tag_batches for tag in batch]
        else:
            tag_list = list(range(len(embs)))

        if return_dict:
            return dict(zip(tag_list, embs))
        return (tag_list, embs)

    def predict(self,
        test_data,
        top_k=10,
        return_df=True,
        skip_pretrained=False,
        ):
        '''
        Retrieve the top-k most similar stored trials for each input document.

        The query trial itself (matched by tag) is excluded from its own results.

        Parameters
        ----------
        test_data: dict
            test_data = {
                'x': pd.DataFrame,
                'fields': list[str],
                'ctx_fields': list[str],
                'tag': str,
            }
            Same format as `train_data` in `fit`; only `x` is read.

        top_k: int
            Number of retrieved candidates (reduced if fewer stored trials are available).

        return_df: bool
            Whether or not return dataframes for the computed similarity ranking.
            - If set True, return (rank, sim);
            - else, return rank_list=[[(doc1,sim1),(doc2,sim2)], [(doc1,sim1),...]].

        skip_pretrained: bool
            If True, trials already in `self.trial_embs` are not re-encoded;
            their embeddings are looked up instead.

        Returns
        -------
        rank: pd.DataFrame
            The top ranked tags for each input trial.

        sim: pd.DataFrame
            The corresponding similarities.

        rank_list: list[list[tuple]]
            A list of (tag, similarity) tuples per input trial.

        Raises
        ------
        ValueError
            If no trial embeddings are stored.
        '''
        self._input_data_check(test_data)
        tag_field = self.config['tag_field']
        if len(self.trial_embs) == 0:
            raise ValueError('No stored trial embeddings to search over; call `fit`, `load_model` or `update_emb` first.')
        # materialize dict views: recent NumPy rejects non-sequence inputs to np.stack
        all_tags = np.array(list(self.trial_embs.keys()))
        all_embs = np.stack(list(self.trial_embs.values()))

        df = test_data['x']
        tags, embs = [], []
        if skip_pretrained:
            tags, embs = self._encode_by_lookup(df)
            to_encode_test_trial = df[~df[tag_field].isin(all_tags)]
        else:
            to_encode_test_trial = df

        if len(to_encode_test_trial) > 0:
            enc_tags, enc_embs = self.encode({'x': to_encode_test_trial}, return_dict=False)
            tags = list(tags) + list(enc_tags)
            embs = np.concatenate([embs, enc_embs], 0) if len(embs) > 0 else enc_embs

        # put the query embeddings back into the row order of the input dataframe
        emb_by_tag = dict(zip(tags, embs))
        tags = df[tag_field].tolist()
        embs = np.stack([emb_by_tag[tag] for tag in tags], 0)

        sim = embs.dot(all_embs.T)

        # Exclude each query's own stored entry by tag. Blindly dropping the
        # first-ranked hit discarded the true best match for queries that are
        # not part of the stored corpus.
        masked_sim = sim.copy()
        tag_to_col = {tag: col for col, tag in enumerate(all_tags.tolist())}
        num_self_hits = 0
        for row_idx, tag in enumerate(tags):
            col = tag_to_col.get(tag)
            if col is not None:
                masked_sim[row_idx, col] = -np.inf
                num_self_hits += 1
        rank = np.argsort(masked_sim, 1)[:, ::-1] # descending similarity
        # never return a masked (self) entry, even if top_k exceeds the corpus
        num_valid = len(all_tags) - (1 if num_self_hits > 0 else 0)
        top_k = min(top_k, num_valid)
        rank = rank[:, :top_k]

        if return_df:
            tag_list, sim_list = [], []
            for i, row in enumerate(rank):
                tag_list.append(all_tags[row])
                sim_list.append(sim[i][row])
            columns = [f'rank_{i}' for i in range(top_k)]
            rank_df = pd.DataFrame(np.array(tag_list), columns=columns, index=tags)
            sim_df = pd.DataFrame(np.array(sim_list), columns=columns, index=tags)
            return rank_df, sim_df

        rank_list = []
        for i, row in enumerate(rank):
            rank_list.append(list(zip(all_tags[row], sim[i][row])))
        return rank_list

    def evaluate(self, test_data):
        '''
        Evaluate the ranking of candidate trials for each target trial.

        Parameters
        ----------
        test_data: dict
            test_data = {
                'x': pd.DataFrame,
                'y': pd.DataFrame or np.ndarray,
            }
            The labeled dataset: row i of `y` holds the labels of the
            candidates in row i of `x`.

        Returns
        -------
        results: dict[float]
            precision@k, recall@k and ndcg@k for k in (1, 2, 5).

        Raises
        ------
        ValueError
            If a target trial has no stored embedding.

        Notes
        -----
        x =

        | target_trial | trial1 | trial2 | trial3 |
        | nct01        | nct02  | nct03  | nct04  |

        y =

        | label1 | label2 | label3 |
        | 0      | 0      | 1      |

        Rows whose labels are all zero are skipped.
        '''
        test_df = test_data['x']
        label_df = test_data['y']
        label_array = label_df.to_numpy() if hasattr(label_df, 'to_numpy') else np.asarray(label_df)

        ranked_label_list = []
        # iterate by position: labels are aligned positionally, not by index label
        for pos, (_, row) in enumerate(test_df.iterrows()):
            target_trial = row['target_trial']
            if target_trial not in self.trial_embs:
                raise ValueError(f'The embeddings of trial {target_trial} are not found.')
            target_emb = self[target_trial]
            candidate_embs = np.stack([self[tag] for tag in row.values[1:]])
            sim = target_emb[None].dot(candidate_embs.T)[0]
            labels = label_array[pos]
            if labels.sum() == 0:
                continue
            ranked_label_list.append(labels[np.argsort(sim)[::-1]])

        ranked_label_list = np.array(ranked_label_list)
        return_dict = {}
        for k in [1,2,5]:
            return_dict[f'precision@{k}'] = precision(ranked_label_list, k)
            return_dict[f'recall@{k}'] = recall(ranked_label_list, k)
            return_dict[f'ndcg@{k}'] = ndcg(ranked_label_list, k)
        return return_dict

    def load_model(self, checkpoint):
        '''
        Load model weights and the pre-encoded trial embeddings from a checkpoint.

        Parameters
        ----------
        checkpoint: str
            The input dir that stores the pretrained model.
            If a directory, the only checkpoint file `*.pth.tar` will be loaded.
            If a filepath, will load from this file.
        '''
        checkpoint_filename = check_checkpoint_file(checkpoint)
        config_filename = check_model_config_file(checkpoint)
        state_dict = self._load_checkpoint_file(checkpoint_filename)
        # the checkpoint stores the unwrapped model (see save_model)
        model = self._unwrap_model(self.model)
        if config_filename is not None:
            config = self._load_model_config(config_filename)
            self.config.update(config)
            model.config.update({'fields':config['fields'], 'ctx_fields':config['ctx_fields']})
        model.load_state_dict(state_dict['model'])
        self.trial_embs = state_dict['emb']

    def save_model(self, output_dir):
        '''Save the model config, the unwrapped model weights and the stored trial embeddings to `output_dir`.'''
        make_dir_if_not_exist(output_dir)
        self._save_model_config(model_config=self.config, output_dir=output_dir)
        model = self._unwrap_model(self.model)
        self._save_checkpoint({'model':model.state_dict(),'emb':self.trial_embs}, output_dir=output_dir)

    def get_train_dataloader(self, inputs):
        '''Build one training DataLoader per objective (local match, global match).'''
        df = inputs['x']
        dataset_list = self._build_dataset(df=df)
        collator_list = self._build_collator()
        dataloader_list = self._build_dataloader(dataset_list, collator_list)
        return dataloader_list

    def get_val_dataloader(self, inputs):
        '''Build an unshuffled, non-augmented DataLoader over `inputs['x']`.'''
        df = inputs['x']
        dataset = self._build_val_dataset(df=df)
        collate_fn = self._build_val_collator()
        dataloader = self._build_val_dataloader(dataset=dataset, collate_fn=collate_fn)
        return dataloader

    def update_emb(self, emb_dict):
        '''
        Update trial embeds: add or modify.

        Parameters
        ----------
        emb_dict: dict[np.ndarray]
            The tags and corresponding trial embeddings to add or overwrite.
        '''
        self.trial_embs.update(emb_dict)

    def from_pretrained(self, input_dir=None):
        '''
        Load the pretrained Trial2Vec model, downloading it first if `input_dir` does not exist.

        Parameters
        ----------
        input_dir: str, optional
            Directory holding (or receiving) the pretrained files.
            Defaults to './trial_search/pretrained_trial2vec'.
        '''
        if input_dir is None:
            input_dir = './trial_search/pretrained_trial2vec'
        # resolve the default first: re-using an existing default dir used to
        # crash in os.makedirs
        if not os.path.exists(input_dir):
            os.makedirs(input_dir)
            print(f'Download pretrained Trial2Vec model, save to {input_dir}.')
            self._download_pretrained(output_dir=input_dir)
        print('Load pretrained Trial2Vec model from', input_dir)
        self.load_model(input_dir)

    def _prepare_input(self, inputs):
        '''Flatten {field: tokenizer output} into {f'{field}_{key}': tensor} for context then attribute fields.'''
        ctx_fields = self.config['ctx_fields'] if self.config['ctx_fields'] is not None else []
        fields = self.config['fields'] if self.config['fields'] is not None else []
        output_dict = {}
        for fd in list(ctx_fields) + list(fields):
            if fd not in inputs:
                continue
            for k, v in inputs[fd].items():
                output_dict[f'{fd}_{k}'] = v
        return output_dict

    def __getitem__(self, tag):
        '''Return the stored embedding of trial `tag`.'''
        return self.trial_embs[tag]

    def _build_model(self):
        model = BuildModel(
            model_name=self.config['bert_name'],
            emb_dim=self.config['emb_dim'],
            ctx_fields=self.config['ctx_fields'],
            fields=self.config['fields'],
            device=self.device,
        )
        self.model = self._wrap_model(model, self.config['device'])

    def _fit(self, train_data, val_data=None):
        if 'fields' in train_data:
            self.config['fields'] = train_data['fields']
        if 'ctx_fields' in train_data:
            self.config['ctx_fields'] = train_data['ctx_fields']
        if 'tag' in train_data:
            self.config['tag_field'] = train_data['tag']
        self._fill_default_fields(train_data['x'])
        self._unwrap_model(self.model).config.update({'fields': self.config['fields'], 'ctx_fields':self.config['ctx_fields']})

        if val_data is not None:
            self._build_val_docs(train_data, val_data=val_data)

        loss_models = self._build_loss_model()
        dataloader_list = self.get_train_dataloader(train_data)

        # kick off training
        train_objectives = list(zip(dataloader_list, loss_models))
        trainer = Trial2VecTrainer(
            model=self,
            train_objectives=train_objectives,
            test_data=val_data,
            test_metric='ndcg@5',
        )
        trainer.train(**self.config)

        # encode all training trial docs after training
        self.update_emb(self.encode(train_data, return_dict=True))

    def _fill_default_fields(self, df):
        '''Apply the documented defaults: no context fields, and all remaining columns of `df` as attribute fields.'''
        if self.config['ctx_fields'] is None:
            self.config['ctx_fields'] = []
        if self.config['fields'] is None:
            excluded = set(self.config['ctx_fields'])
            if self.config['tag_field'] is not None:
                excluded.add(self.config['tag_field'])
            self.config['fields'] = [col for col in df.columns if col not in excluded]

    def _build_collator(self, is_train=True):
        collator_list = [
            LocalMatchCollator(
                bert_name=self.config['bert_name'],
                max_seq_length=self.config['max_seq_length'],
                fields=self.config['fields'],
                ctx_fields=self.config['ctx_fields'],
                device=self.config['device'],
                is_train=is_train,
                tag_field=None,
            ),
            GlobalMatchCollator(
                bert_name=self.config['bert_name'],
                max_seq_length=self.config['max_seq_length'],
                fields=self.config['fields'],
                ctx_fields=self.config['ctx_fields'],
                device=self.config['device'],
                is_train=is_train,
                tag_field=None,
            ),
        ]
        return collator_list

    def _build_dataloader(self, dataset_list, collator_list):
        '''
        Build one shuffled DataLoader per training objective of Trial2Vec:
        (1) local match: a component vs. the whole trial;
        (2) global match: a trial vs. its views with replaced components.
        '''
        dataloader_list = []
        for i, dataset in enumerate(dataset_list):
            dataloader = DataLoader(
                dataset,
                batch_size=self.config['batch_size'],
                shuffle=True,
                num_workers=self.config['num_workers'],
                collate_fn=collator_list[i],
                pin_memory=True,
            )
            dataloader_list.append(dataloader)
        return dataloader_list

    def _build_dataset(self, df):
        dataset_list = []
        # local contrastive: match individual components to the global trial embs
        dataset = Dataset.from_pandas(df)
        dataset.set_transform(self._dataset_transform)
        dataset_list.append(dataset)

        # global contrastive: discriminate the global trial emb when replacing attrs
        dataset = Dataset.from_pandas(df)
        dataset.set_transform(self._dataset_transform)
        dataset_list.append(dataset)
        return dataset_list

    def _dataset_transform(self, examples):
        '''Keep only the configured field, context and tag columns (all columns if none is configured).'''
        return_dict = {}
        if self.config['fields'] is not None:
            for k in self.config['fields']:
                return_dict[k] = examples[k]
        if self.config['ctx_fields'] is not None:
            for k in self.config['ctx_fields']:
                return_dict[k] = examples[k]
        if self.config['tag_field'] is not None:
            return_dict[self.config['tag_field']] = examples[self.config['tag_field']]
        if len(return_dict) == 0:
            return_dict = examples
        return return_dict

    def _build_loss_model(self):
        '''Build the two loss models (local match, global match) for training trial2vec from scratch.'''
        loss_models = [
            LocalMatchLoss(model=self.model, fields=self.config['fields'], logit_scale_init_value=self.config['logit_scale_init_value']),
            GlobalMatchLoss(model=self.model, fields=self.config['fields'], logit_scale_init_value=self.config['logit_scale_init_value']),
        ]
        return loss_models

    def _val_dataset_transform(self, examples):
        return self._dataset_transform(examples=examples)

    def _build_val_dataset(self, df):
        dataset = Dataset.from_pandas(df)
        dataset.set_transform(self._val_dataset_transform)
        return dataset

    def _build_val_dataloader(self, dataset, collate_fn):
        dataloader = DataLoader(
            dataset,
            batch_size=self.config['batch_size'],
            shuffle=False,
            num_workers=self.config['num_workers'],
            collate_fn=collate_fn,
            pin_memory=True,
        )
        return dataloader

    def _build_val_collator(self):
        return LocalMatchCollator(
            bert_name=self.config['bert_name'],
            max_seq_length=self.config['max_seq_length'],
            fields=self.config['fields'],
            ctx_fields=self.config['ctx_fields'],
            device=self.config['device'],
            is_train=False,
            tag_field=self.config['tag_field'],
        )

    def _build_val_docs(self, train_data, val_data):
        # gather every trial referenced by the validation set (targets and
        # candidates) and look up its document in the training dataframe
        df_tr = train_data['x']
        df_va = val_data['x']
        tag_field = self.config['tag_field']
        tag_field = tag_field if tag_field is not None else 'tag'
        df_va = pd.DataFrame({tag_field: df_va.to_numpy().flatten()}).drop_duplicates().reset_index(drop=True)
        df_va = df_va.merge(df_tr, on=tag_field, how='inner')
        self.val_doc_inputs = {
            'x': df_va,
        }

    def _encode_by_lookup(self, df):
        '''Return (tags, embs) for the rows of `df` whose tag already has a stored embedding.'''
        tag_field = self.config['tag_field']
        tags, embs = [], []
        for tag in df[tag_field]:
            if tag in self.trial_embs:
                tags.append(tag)
                embs.append(self.trial_embs[tag])
        embs = np.array(embs)
        return tags, embs

    def _input_data_check(self, inputs):
        '''
        Check that training / testing inputs have the expected format.

        Parameters
        ----------
        inputs: {
            'x': pd.DataFrame,
            'fields': list[str],
            'ctx_fields': list[str],
            'tag': str,
        }
        '''
        assert 'x' in inputs, 'No input trial doc dataframe found in inputs.'
        df = inputs['x']
        for key in ('fields', 'ctx_fields', 'tag'):
            if key not in inputs:
                continue
            try:
                _ = df[inputs[key]]
            except Exception as err:
                raise Exception(f'Cannot find the specified `{key}` in inputs dataframe.') from err

        # check data type; astype(str) replaces the deprecated DataFrame.applymap
        try:
            _ = df.astype(str)
        except Exception as err:
            raise Exception('Cannot transform the input dataframe to str type, please check the inputs.') from err

    @staticmethod
    def _resolve_device(device):
        '''Return the device (as a string) that input batches are moved to.'''
        if isinstance(device, (list, tuple)):
            # DataParallel keeps the module and gathers outputs on the first listed device
            if len(device) == 0:
                return 'cuda:0'
            first = device[0]
            if isinstance(first, int):
                return str(torch.device('cuda', first))
            return str(torch.device(first))
        if isinstance(device, torch.device):
            return str(device)
        if isinstance(device, str):
            return device
        return 'cuda:0'

    @staticmethod
    def _load_checkpoint_file(checkpoint_filename):
        '''
        Load a checkpoint written by `save_model`, mapping all tensors to CPU.

        The checkpoint also holds numpy arrays (the trial embeddings), so it is
        unpickled with ``weights_only=False`` where PyTorch supports that flag.
        This executes pickled code: only load checkpoints from trusted sources.
        '''
        import inspect
        load_kwargs = {'map_location': 'cpu'}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        return torch.load(checkpoint_filename, **load_kwargs)

    def _wrap_model(self, model, device):
        '''Move `model` to `device`; a list of device ids wraps it in nn.DataParallel.'''
        if isinstance(device, (list, tuple)):
            model = nn.DataParallel(model, device_ids=list(device))
            model.to(f'cuda:{model.device_ids[0]}')
        elif device is None:
            model.to(torch.device('cuda'))
        else:
            # honour the requested device (e.g. 'cuda:1', 'cpu', torch.device)
            model.to(torch.device(device))
        return model

    def _unwrap_model(self, model):
        if isinstance(model, nn.DataParallel):
            return model.module
        return model

    def _download_pretrained(self, output_dir):
        '''Download the pretrained Trial2Vec archive into `output_dir` and extract it there.'''
        import wget
        import zipfile
        filename = wget.download(url=PRETRAINED_TRIAL2VEC_URL, out=output_dir)
        with zipfile.ZipFile(filename, 'r') as zipf:
            zipf.extractall(output_dir)