trial_outcome.SPOT
- class pytrial.tasks.trial_outcome.spot.SPOT(num_topics=50, n_trial_projector=2, n_timestemp_projector=2, n_rnn_layer=1, criteria_column='criteria', batch_size=1, n_trial_per_batch=None, learning_rate=0.0001, weight_decay=0.0001, epochs=10, evaluation_steps=50, warmup_ratio=0, device='cuda:0', seed=42, output_dir='./checkpoints/spot')[source]
Implements sequential predictive modeling of clinical trial outcome with meta-learning (SPOT) [1].
- Parameters
num_topics (int) – Number of topics to discover.
n_trial_projector (int) – Number of layers in the trial projector.
n_timestemp_projector (int) – Number of layers in the timestamp projector.
n_rnn_layer (int) – Number of layers in the RNN that encodes ontology and SMILES strings.
criteria_column (str) – The column name of the criteria in the input data.
batch_size (int) – Batch size for training.
n_trial_per_batch (int) – Number of trials in each batch. If None, use all trials in the topic in each batch.
learning_rate (float) – Learning rate for training.
weight_decay (float) – Regularization strength for the L2 norm; must be a positive float.
epochs (int) – Number of training epochs.
evaluation_steps (int) – Number of training steps between evaluations of the model.
warmup_ratio (float) – Warmup ratio for learning rate scheduler.
device (str) – Device to use for training and inference.
seed (int) – Random seed.
output_dir (str) – Directory for saving model checkpoints.
Notes
- [1]
Wang, Z., Xiao, C., & Sun, J. (2023). SPOT: Sequential Predictive Modeling of Clinical Trial Outcome with Meta-Learning. arXiv preprint arXiv:2304.05352.
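The constructor takes many keyword arguments, so a small helper that merges overrides into the documented defaults can catch misspelled names early (note that `n_timestemp_projector` is spelled as in the signature). The sketch below is illustrative only: `SPOT_DEFAULTS`, `make_spot_kwargs`, and `build_spot` are hypothetical names that are not part of pytrial, and the default values are copied from the class signature above.

```python
# Defaults copied from the documented class signature.
SPOT_DEFAULTS = {
    "num_topics": 50,
    "n_trial_projector": 2,
    "n_timestemp_projector": 2,  # spelled exactly as in the signature
    "n_rnn_layer": 1,
    "criteria_column": "criteria",
    "batch_size": 1,
    "n_trial_per_batch": None,
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "epochs": 10,
    "evaluation_steps": 50,
    "warmup_ratio": 0,
    "device": "cuda:0",
    "seed": 42,
    "output_dir": "./checkpoints/spot",
}


def make_spot_kwargs(**overrides):
    """Hypothetical helper: merge overrides into the documented defaults.

    Raises TypeError for any name that is not in the documented signature.
    """
    unknown = set(overrides) - set(SPOT_DEFAULTS)
    if unknown:
        raise TypeError(f"unknown SPOT argument(s): {sorted(unknown)}")
    return {**SPOT_DEFAULTS, **overrides}


def build_spot(**overrides):
    """Hypothetical helper: instantiate SPOT (requires pytrial to be installed)."""
    # Import path taken from the class signature above; imported lazily so
    # that make_spot_kwargs can be used without pytrial.
    from pytrial.tasks.trial_outcome.spot import SPOT

    return SPOT(**make_spot_kwargs(**overrides))
```

For example, `build_spot(device="cpu", epochs=3)` would request a CPU model trained for three epochs, while a misspelled keyword raises `TypeError` before pytrial is imported.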
- add_trials(seqtrain, df)[source]
Add new trials to an existing SequenceTrial object. This is used to append the validation and test trials.
- Parameters
seqtrain (model_utils.data_structure.SequenceTrial) – The sequence of clinical trials represented by SequenceTrial object.
df (pd.DataFrame) – The dataframe of new trials that should be added to the input seqtrain.
- fit(train_data, valid_data=None)[source]
Train the SPOT model. Training runs in two or three steps, depending on whether the optional first step is needed:
1. Encode the criteria with a pretrained BERT to get the criteria embedding (optional).
2. Discover topics from the training data and transform the data into a SequenceTrial object.
3. Train the SPOT model.
- Parameters
train_data (TrialOutcomeDataset) – Training data, should be a TrialOutcomeDataset object.
valid_data (TrialOutcomeDataset) – Validation data, should be a TrialOutcomeDataset object.
- load_model(input_dir=None)[source]
Load the model from a directory.
- Parameters
input_dir (str or None) – The directory to load the model from.
- predict(data, target_trials=None)[source]
Predict the outcome of a clinical trial.
- Parameters
data (TrialOutcomeDataset) – A TrialOutcomeDataset object containing the data to predict.
target_trials (list) – A list of trial ids to predict. If None, all trials in data will be predicted.
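A typical workflow chains `fit` and `predict`. The sketch below uses only the two signatures documented above; `fit_and_predict` is a hypothetical wrapper, `model` is assumed to be a SPOT instance, and the datasets are assumed to be `TrialOutcomeDataset` objects prepared elsewhere. The return format of `predict` is not described in this reference, so the result is passed through unchanged.

```python
def fit_and_predict(model, train_data, test_data, valid_data=None, target_trials=None):
    """Hypothetical wrapper: train a model, then predict on held-out data.

    model         -- assumed to expose fit() and predict() as documented above
    train_data    -- TrialOutcomeDataset used for training
    test_data     -- TrialOutcomeDataset containing the trials to predict
    valid_data    -- optional TrialOutcomeDataset used for validation
    target_trials -- optional list of trial ids; None predicts all trials
    """
    model.fit(train_data, valid_data=valid_data)
    # Returned as-is because the output format is not documented here.
    return model.predict(test_data, target_trials=target_trials)
```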
- save_model(output_dir=None)[source]
Save the model to a directory.
- Parameters
output_dir (str or None) – The directory to save the model to. If None, use the default directory ‘./checkpoints’.
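`save_model` and `load_model` can be paired to move a trained model into a fresh instance. The following is a minimal sketch under stated assumptions: `checkpoint_roundtrip` is a hypothetical helper, both arguments are assumed to be SPOT instances built with the same constructor arguments, and the directory is created up front because this reference does not say whether `save_model` creates it.

```python
from pathlib import Path


def checkpoint_roundtrip(trained_model, fresh_model, directory="./checkpoints/spot"):
    """Hypothetical helper: save one model and load it into another.

    Both models are assumed to expose save_model(output_dir=...) and
    load_model(input_dir=...) as documented above.
    """
    directory = str(directory)
    # Creating the directory is a precaution, not documented behavior.
    Path(directory).mkdir(parents=True, exist_ok=True)
    trained_model.save_model(output_dir=directory)
    fresh_model.load_model(input_dir=directory)
    return fresh_model
```

Passing the directory explicitly to both calls avoids relying on either method's default location.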
- transform_topic(df, df_val=None, df_test=None)[source]
Transform the data into a SequenceTrial object.
- Parameters
df (pandas.DataFrame) – A dataframe containing the training data.
df_val (pandas.DataFrame) – A dataframe containing the validation data.
df_test (pandas.DataFrame) – A dataframe containing the test data.