import numbers
import os
import pdb
import warnings
from collections import defaultdict
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from ..utils.trial_utils import ClinicalTrials
from ..utils.tabular_utils import read_csv_to_df
from ..data.vocab_data import Vocab
from ..utils.tabular_utils import HyperTransformer
class TrialDatasetBase(Dataset):
    '''
    The basic trial datasets loader.

    On construction the eligibility-criteria text in ``criteria_column`` is
    split into inclusion / exclusion sentence lists, and every distinct
    sentence is indexed in a per-dataset vocabulary. The following columns
    are added to ``data`` in place: ``inclusion_criteria``,
    ``exclusion_criteria``, ``inclusion_criteria_index`` and
    ``exclusion_criteria_index``.

    Parameters
    ----------
    data: pd.DataFrame
        Contain the trial document in tabular format.

    criteria_column: str
        The column name of eligibility criteria in the dataframe.
        Non-string cells (e.g. missing values read as NaN) are treated as
        empty criteria.
    '''
    inc_ec_embedding = None # inclusion criteria embedding
    inc_vocab = None # inclusion criteria vocab
    exc_ec_embedding = None # exclusion criteria embedding
    exc_vocab = None # exclusion criteria vocab

    def __init__(self, data, criteria_column='criteria'):
        self.df = data
        self._process_ec(criteria_column=criteria_column)
        self._collect_cleaned_sentence_set()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        '''
        Return row ``index`` as a one-row DataFrame.

        Negative indices count from the end, as for Python sequences.
        '''
        if index < 0:
            # ``iloc[-1:0]`` would be an empty slice, so normalise first
            index += len(self.df)
        return self.df.iloc[index:index+1]

    def get_ec_sentence_embedding(self):
        '''
        Lazily compute and return the criterion-level embeddings.

        On the first call the sentences stored in ``inc_vocab`` and
        ``exc_vocab`` are encoded with a BERT model (on ``cuda:0`` when CUDA
        is available, otherwise on CPU); the results are cached on the
        instance and reused by later calls.

        Returns
        -------
        inc_ec_embedding, exc_ec_embedding
            The encoder outputs for the inclusion and exclusion vocabularies,
            moved to CPU. Presumably row ``i`` corresponds to vocab index
            ``i`` -- TODO confirm against ``Vocab.words``.
        '''
        if self.inc_ec_embedding is None or self.exc_ec_embedding is None:
            self._get_ec_emb()
        return self.inc_ec_embedding, self.exc_ec_embedding

    def _process_ec(self, criteria_column):
        # split each criteria text into an (inclusion, exclusion) pair of sentence lists
        split = self.df[criteria_column].apply(self._split_protocol)
        self.df['inclusion_criteria'] = split.apply(lambda pair: pair[0])
        self.df['exclusion_criteria'] = split.apply(lambda pair: pair[1])

    def _get_ec_emb(self):
        # create EC embedding with indexed ECs; the BERT wrapper is imported
        # here, so it is only loaded when embeddings are actually requested
        from pytrial.model_utils.bert import BERT
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        bert_model = BERT(device=device)
        self.inc_ec_embedding = bert_model.encode(self.inc_vocab.words, batch_size=64).cpu()
        self.exc_ec_embedding = bert_model.encode(self.exc_vocab.words, batch_size=64).cpu()

    def _collect_cleaned_sentence_set(self):
        '''
        Build ``inc_vocab`` / ``exc_vocab`` over all criterion sentences and
        store, per trial, the de-duplicated vocab indices of its sentences in
        ``inclusion_criteria_index`` / ``exclusion_criteria_index``.
        '''
        self.inc_vocab = Vocab()
        self.exc_vocab = Vocab()
        self.inc_vocab.add_sentence(['[PAD]']) # 0 belongs to the pad token
        self.exc_vocab.add_sentence(['[PAD]']) # 0 belongs to the pad token
        inc_index_set, exc_index_set = [], []
        for inc, exc in zip(self.df['inclusion_criteria'], self.df['exclusion_criteria']):
            row_inc_set, row_exc_set = [], []
            for sent in inc:
                self.inc_vocab.add_sentence(sent)
                row_inc_set.append(self.inc_vocab.word2idx[sent])
            for sent in exc:
                self.exc_vocab.add_sentence(sent)
                row_exc_set.append(self.exc_vocab.word2idx[sent])
            inc_index_set.append(list(set(row_inc_set)))
            exc_index_set.append(list(set(row_exc_set)))
        self.df['inclusion_criteria_index'] = inc_index_set
        self.df['exclusion_criteria_index'] = exc_index_set

    def _clean_protocol(self, protocol):
        '''
        Lower-case ``protocol``, split it on ``'\\n'``, strip every line and
        drop the blank ones.

        Non-string input (e.g. a missing value read by pandas as NaN) has no
        criteria, so an empty list is returned instead of raising.
        '''
        if not isinstance(protocol, str):
            return []
        return [line.strip() for line in protocol.lower().split('\n') if line.strip()]

    def _split_protocol(self, protocol):
        '''
        Split a criteria text into ``(inclusion_criteria, exclusion_criteria)``.

        The first line containing "inclusion" and the first line containing
        "exclusion" act as section headers (and are kept in their section).
        If the headers are not found in that order, or no line follows the
        exclusion header, all cleaned lines are returned as inclusion
        criteria and the exclusion list is empty.

        Returns
        -------
        tuple of (list[str], list[str])
        '''
        lines = self._clean_protocol(protocol)
        n_lines = len(lines)
        inclusion_idx = next((i for i, s in enumerate(lines) if "inclusion" in s), n_lines)
        exclusion_idx = next((i for i, s in enumerate(lines) if "exclusion" in s), n_lines)
        # This guard guarantees both slices are non-empty, so no further
        # sanity check is needed (the old ``exit()`` branch was unreachable).
        if inclusion_idx < exclusion_idx < n_lines - 1:
            return lines[inclusion_idx:exclusion_idx], lines[exclusion_idx:] ## list, list
        return lines, []
class TrialDataset(Dataset):
    '''
    Basic trial datasets loader.

    Parameters
    ----------
    input_dir: str
        The path to the trial dataset in tabular form (.csv).
        If a directory is given, the code will automatically pick the only '.csv' file under this dir.

    Raises
    ------
    ValueError
        If ``input_dir`` is not given.
    FileNotFoundError
        If ``input_dir`` is neither an existing file nor an existing directory.
    Exception
        If ``input_dir`` is a directory holding zero or more than one '.csv' file.
    '''
    def __init__(self, input_dir=None) -> None:
        if input_dir is None:
            raise ValueError('`input_dir` must be the path to a .csv file or to a directory containing exactly one .csv file.')
        if os.path.isfile(input_dir):
            csv_path = input_dir
        elif os.path.isdir(input_dir):
            csv_names = [name for name in os.listdir(input_dir) if name.endswith('.csv')]
            if len(csv_names) > 1:
                raise Exception(f'`input_dir` {input_dir} is given where more than one csv files are found under this path.')
            if len(csv_names) == 0:
                raise Exception(f'`input_dir` {input_dir} is given where no csv file is found under this path.')
            csv_path = os.path.join(input_dir, csv_names[0])
        else:
            # previously fell through silently, leaving ``self.df`` undefined
            raise FileNotFoundError(f'`input_dir` {input_dir} does not exist.')
        self.df = read_csv_to_df(csv_path, index_col=0)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        '''
        Return row ``index`` as a one-row DataFrame.

        Negative indices count from the end, as for Python sequences.
        '''
        if index < 0:
            # ``iloc[-1:0]`` would be an empty slice, so normalise first
            index += len(self.df)
        return self.df.iloc[index:index+1]
class TrialDataCollator:
    '''The basic trial data collator.

    Concatenates the examples of a batch row-wise and replaces missing values
    with the string ``'none'``. ``examples`` is expected to be a list of
    DataFrames, such as the one-row frames returned by
    ``TrialDataset.__getitem__``.

    Subclass it and override the `__init__` & `__call__` function if need operations inside this step.

    Returns
    -------
    batch_df: pd.DataFrame
        A dataframe contains multiple fields for each trial.
    '''
    def __init__(self) -> None:
        # subclass to add tokenizer
        # subclass to add feature preprocessor
        pass

    def __call__(self, examples):
        # ``axis`` must be a keyword: pandas >= 2.0 made it keyword-only, so the
        # former positional call ``pd.concat(examples, 0)`` raises TypeError there.
        batch_df = pd.concat(examples, axis=0)
        return batch_df.fillna('none')
class TrialOutcomeDatasetBase(TrialDatasetBase):
    '''
    Basic trial outcome datasets loader.

    Unlike its parent, construction performs no criteria parsing: the frame
    is stored as ``self.data`` and each item is the tuple of the values of
    the first five configured columns in one row.

    Parameters
    ----------
    data: pd.DataFrame
        Contain the trial document in tabular format.

    columns: list[str], optional
        Column names read for each item, in order; only the first five are
        used. Defaults to ``['nctid', 'label', 'smiless', 'icdcodes', 'criteria']``.
    '''
    columns = ['nctid', 'label', 'smiless', 'icdcodes', 'criteria']

    def __init__(self, data, columns=None) -> None:
        self.data = data
        if columns is None:
            return
        self.columns = columns

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data.iloc[index]
        # always five fields; fewer configured columns raises IndexError
        return tuple(record[self.columns[pos]] for pos in range(5))
class TrialDatasetStructured(Dataset):
    '''
    Dataset class for structured trial features. Subclass it if additional properties and functions
    are required to add for specific tasks. We make use of `rdt`: https://docs.sdv.dev/rdt for transform
    and reverse transform of the tabular data.

    Parameters
    ----------
    df: pd.DataFrame
        The input trial tabular format records.

    metadata: dict
        Contains the meta setups of the input data. It should contain the following keys:

        (1) `sdtypes`: dict, the data types of each column in the input data. The keys are the column
        names and the values are the data types. The data types can be one of the following:
        'numerical', 'categorical', 'datetime', 'boolean'.

        (2) `transformers`: dict, the transformers to be used for each column. The keys are the column
        names and the values are rdt transformer instances (or None), chosen from
        https://docs.sdv.dev/rdt/transformers-glossary/browse-transformers.

        metadata = {
            'sdtypes': {
                'column1': 'numerical',
                'column2': 'boolean',
                'column3': 'datetime',
                'column4': 'categorical',
                'column5': 'categorical',
            },
            'transformers': {
                'column1': rdt.transformers.FloatFormatter(missing_value_replacement='mean'),
                'column2': rdt.transformers.BinaryEncoder(missing_value_replacement='mode'),
                'column3': rdt.transformers.UnixTimestampEncoder(missing_value_replacement='mean'),
                'column4': rdt.transformers.FrequencyEncoder(),
                'column5': None, # will not do anything to this column if no transformer is specified.
            }
        }

        It is recommended to provide the metadata of the input tabular data.

        - If no metadata is given, the dataset will automatically detect the `dtypes` of columns and build the corresponding `transformers`.
        - If only `sdtypes` are given, the dataset will detect if there are missing `sdtypes` given and build the `transformers` and dtype automatically.

        The passed dict is not modified; the resolved configuration is
        stored in ``self.metadata``.

    transform: bool(default=True)
        Whether or not transform raw self.df by hypertransformer.
        If set False, :code:`self.df` will keep as the same as the passed one.

    Examples
    --------
    >>> from pytrial.data.trial_data import TrialDatasetStructured
    >>> df = pd.read_csv('tabular_trial.csv', index_col=0)
    >>> # `transform=True` replaces dataset.df with its transformed version
    >>> dataset = TrialDatasetStructured(df, transform=True)
    >>> # transform raw dataframe to numerical tables
    >>> df_transformed = dataset.transform(df)
    >>> # make back transform to the original df
    >>> df_raw = dataset.reverse_transform(df_transformed)
    '''
    def __init__(self, df, metadata=None, transform=True):
        self.df = df
        self.metadata = metadata

        # initialize hypertransformer
        self.ht = HyperTransformer()

        if transform:
            if metadata is None:
                warnings.warn('No metadata provided. Metadata will be automatically '
                            'detected from your data. This process may not be accurate. '
                            'We recommend writing metadata to ensure correct data handling.')
                self.ht.detect_initial_config(df)
                self.metadata = self.ht.get_config()
                self.ht.fit(df)
            else:
                # parse the metadata and update hypertransformer's config
                self._parse_metadata()

            # replace data with the transformed one
            self.df = self.transform(df)

    def __getitem__(self, index):
        '''
        Indexing the dataframe stored in tabular trial dataset.

        Parameters
        ----------
        index: int or list[int]
            Retrieve the corresponding rows in the dataset. An integer
            (Python or NumPy) yields a one-row DataFrame and may be negative
            (counted from the end); a list yields the selected rows.

        Raises
        ------
        TypeError
            If ``index`` is neither an integer nor a list.
        '''
        # TODO: support better indexing
        if isinstance(index, numbers.Integral):
            index = int(index)
            if index < 0:
                # ``iloc[-1:0]`` would be an empty slice, so normalise first
                index += len(self.df)
            return self.df.iloc[index:index+1]
        if isinstance(index, list):
            return self.df.iloc[index]
        # previously fell through and returned None silently
        raise TypeError(f'index must be an int or a list of ints, got {type(index).__name__}')

    def __len__(self):
        return len(self.df)

    def __repr__(self):
        return f'<pytrial.data.trial_data.TrialDatasetStructured object> Tabular trial data with {self.df.shape[0]} samples {self.df.shape[1]} features, call `.df` to yield the pd.DataFrame data: \n' + repr(self.df)

    def transform(self, df=None):
        '''
        Transform the input df or the self.df by hypertransformer.
        If transform=True in `__init__`, then you do not need to call this function
        to transform self.df because it was transformed already.

        Parameters
        ----------
        df: pd.DataFrame
            The dataframe to be transformed by self.ht; defaults to self.df.
        '''
        return self.ht.transform(self.df if df is None else df)

    def reverse_transform(self, df=None):
        '''
        Reverse the input dataframe back to the original format. Return the self.df in the original
        format if `df=None`.

        Parameters
        ----------
        df: pd.DataFrame
            The dataframe to be transformed back to the original format by self.ht.
        '''
        return self.ht.reverse_transform(self.df if df is None else df)

    def _parse_metadata(self):
        '''
        Parse the passed metadata, cope with the following scenarios:
        (1) only `sdtypes` are given;
        (2) only `transformers` are given;
        (3) only partial `sdtypes` are given;
        (4) only partial `transformers` are given.

        The hypertransformer config is first auto-detected from ``self.df``,
        then overridden by the user's ``sdtypes`` and ``transformers``, and
        finally fitted on ``self.df``. ``self.metadata`` is rebound to a new
        dict merging the user's metadata with the resulting config.
        '''
        # parse metadata dict for building the hypertransformer
        metadata = self.metadata
        self.ht.detect_initial_config(self.df, verbose=False)
        # Apply sdtypes before transformers: in rdt, changing a column's
        # sdtype may reset its transformer to that sdtype's default, which
        # would discard a user transformer applied earlier.
        if 'sdtypes' in metadata:
            self.ht.update_sdtypes(metadata['sdtypes'])
        if 'transformers' in metadata:
            self.ht.update_transformers(metadata['transformers'])
        self.ht.fit(self.df)
        # build a new dict rather than ``metadata.update(...)`` so the dict
        # passed in by the caller is left untouched
        resolved = dict(metadata)
        resolved.update(self.ht.get_config())
        self.metadata = resolved
def test():
    '''
    Smoke demo: load the single CSV under ./datasets/AACT-ClinicalTrial/,
    wrap it in an unshuffled DataLoader with batch size 10 and print the
    first collated batch.
    '''
    dataset = TrialDataset('./datasets/AACT-ClinicalTrial/')
    collate = TrialDataCollator()
    loader = DataLoader(
        dataset,
        batch_size=10,
        shuffle=False,
        collate_fn=collate,
    )
    first_batch = next(iter(loader))
    print(first_batch)


if __name__ == '__main__':
    test()