# Source code for pytrial.tasks.indiv_outcome.tabular.mlp

import pdb
import os

import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader

from pytrial.data.patient_data import TabularPatientBase
from pytrial.utils.check import (
    check_checkpoint_file, check_model_dir, check_model_config_file, make_dir_if_not_exist
)

from .base import TabularIndivBase, IndivTabDataset

class BuildModel(nn.Module):
    '''
    Feed-forward network made of linear layers separated by ReLU activations.

    Parameters
    ----------
    input_dim: int
        Size of the last dimension of the input features.

    output_dim: int
        Size of the last dimension of the output.

    hidden_dim: int
        Width of every intermediate linear layer. Unused when `num_layer` is 1.

    num_layer: int
        Number of linear layers, must be >= 1. With `num_layer=1` the network
        is a single linear map from `input_dim` to `output_dim`.

    Raises
    ------
    ValueError
        If `num_layer` is smaller than 1.
    '''
    def __init__(self,
        input_dim,
        output_dim,
        hidden_dim=128,
        num_layer=2,
        ) -> None:
        super().__init__()
        # Without this guard a non-positive depth silently produced a single
        # Linear(input_dim, hidden_dim), i.e. a model with the wrong output width.
        if num_layer < 1:
            raise ValueError(
                'num_layer should be a positive integer, got {}'.format(num_layer)
            )

        if num_layer == 1:
            layers = [nn.Linear(input_dim, output_dim)]
        else:
            # input projection, (num_layer - 2) hidden blocks, output projection
            layers = [nn.Linear(input_dim, hidden_dim)]
            for _ in range(num_layer - 2):
                layers.append(nn.ReLU())
                layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Linear(hidden_dim, output_dim))

        # Kept as a ModuleList named `mlp` so parameter names stay
        # `mlp.<index>.weight/bias`, matching previously saved models.
        self.mlp = nn.ModuleList(layers)

    def forward(self, inputs):
        '''
        Run the network on a batch of features.

        Parameters
        ----------
        inputs: dict or torch.Tensor
            Either the feature tensor itself or a dict holding it under key `'x'`.

        Returns
        -------
        torch.Tensor
            Output of the last linear layer; no final activation is applied.

        Raises
        ------
        ValueError
            If `inputs` is neither a dict nor a torch.Tensor.
        '''
        if isinstance(inputs, dict):
            h = inputs['x']
        elif isinstance(inputs, torch.Tensor):
            h = inputs
        else:
            raise ValueError('inputs should be dict or torch.Tensor')

        # Cast to float32 so integer / double inputs match the layer weights.
        h = h.float()
        for layer in self.mlp:
            h = layer(h)
        return h

class MLP(TabularIndivBase):
    '''
    Implement multi-layer perceptron model for tabular individual outcome
    prediction in clinical trials.

    Parameters
    ----------
    input_dim: int
        Dimension of the input features.

    output_dim: int
        Dimension of the outputs. When doing classification, it equals to number of classes.

    mode: str
        The task's objectives, in `binary`, `multiclass`, `multilabel`, or `regression`.
        Matched case-insensitively.

    hidden_dim: int
        Hidden dimensions of neural networks.

    num_layer: int
        Number of linear layers of the network (1 means a single linear map
        from inputs to outputs).

    learning_rate: float
        Learning rate used during optimization.

    weight_decay: float
        Regularization strength for l2 norm.
        Smaller values specify weaker regularization.

    batch_size: int
        Batch size used for training and for batched prediction.

    epochs: int
        Maximum number of training epochs.

    num_worker: int
        Number of workers used to do dataloading.

    device: str
        Target device to train the model, as `cuda:0` or `cpu`.

    experiment_id: str, optional (default='test')
        The name of current experiment. Decide the saved model checkpoint name.
    '''
    def __init__(self,
        input_dim,
        output_dim,
        mode,
        hidden_dim=128,
        num_layer=2,
        learning_rate=1e-4,
        weight_decay=1e-4,
        batch_size=64,
        epochs=10,
        num_worker=0,
        device='cuda:0',
        experiment_id='test'):
        super().__init__(experiment_id)
        mode = mode.lower()
        assert mode in ['binary', 'multiclass', 'regression', 'multilabel']
        self.config = {
            'input_dim':input_dim,
            'output_dim':output_dim,
            'hidden_dim':hidden_dim,
            'num_layer':num_layer,
            'learning_rate':learning_rate,
            'batch_size':batch_size,
            'weight_decay':weight_decay,
            'epochs':epochs,
            'num_worker':num_worker,
            'experiment_id':experiment_id,
            'model_name': 'MLP',
            'device':device,
            'mode':mode,
            }
        self._save_config(self.config)
        self.device = device

    def fit(self, train_data, valid_data=None):
        '''Train the MLP model to predict patient outcome with tabular input data.

        A fresh network is built on every call, so calling `fit` again
        restarts from newly initialised weights.

        Parameters
        ----------
        train_data: dict
            {
            'x': TabularPatientBase or pd.DataFrame,
            'y': pd.Series or np.ndarray
            }

            - 'x' contain all patient features;
            - 'y' contain labels for each row.

        valid_data: same as `train_data`.
            Validation data forwarded to the training routine.
        '''
        self._input_data_check(train_data)
        self._build_model()

        x_feat, y = self._parse_input_data(train_data)
        train_data = {'x':x_feat, 'y':y}

        if valid_data is not None:
            x_feat_va, y_va = self._parse_input_data(valid_data)
            valid_data = {'x':x_feat_va, 'y':y_va}

        self._fit_model(train_data=train_data, valid_data=valid_data)

    def predict(self, test_data):
        '''
        Make prediction based on the learned model.

        Parameters
        ----------
        test_data: Dict or TabularPatientBase or pd.DataFrame or torch.Tensor
            {'x': TabularPatientBase or pd.DataFrame}

            'x' contain all patient features.

        Returns
        -------
        ypred: np.ndarray or torch.Tensor
            Prediction for each patient.

            - When `test_data` is a torch.Tensor, the raw output of the
              network is returned directly as a torch.Tensor.
            - Otherwise the result of the batched prediction routine is
              returned (documented upstream as shape (n, ) for binary and
              (n, n_class) for multiclass classification).
        '''
        # Tensor shortcut: bypasses parsing/batching and calls the network as-is.
        if isinstance(test_data, torch.Tensor):
            return self.model(test_data)

        x_feat, y = self._parse_input_data(test_data)
        test_data = {'x':x_feat, 'y': y}
        dataset = self._build_dataset(test_data)
        dataloader = DataLoader(dataset,
            batch_size=self.config['batch_size'],
            shuffle=False,  # keep predictions aligned with the input row order
            num_workers=self.config['num_worker'],
            pin_memory=True,
            )
        ypred = self._predict_on_dataloader(dataloader)
        return ypred

    def save_model(self, output_dir=None):
        '''
        Save the learned MLP model and its config to the disk.

        Parameters
        ----------
        output_dir: str or None
            The dir to save the learned model.
            If set None, will save model to `self.checkout_dir`.
        '''
        if output_dir is not None:
            make_dir_if_not_exist(output_dir)
        else:
            output_dir = self.checkout_dir

        self._save_config(self.config, output_dir=output_dir)
        # The whole module object (not only its state dict) is stored under 'model'.
        self._save_checkpoint({'model':self.model}, output_dir=output_dir)

    def load_model(self, checkpoint):
        '''
        Load the model and, when available, its config from the given checkpoint.

        Parameters
        ----------
        checkpoint: str
            The input dir that stores the pretrained model.

            - If a directory, the only checkpoint file `*.pth.tar` will be loaded.
            - If a filepath, will load from this file.
        '''
        checkpoint_filename = check_checkpoint_file(checkpoint)
        config_filename = check_model_config_file(checkpoint)
        # NOTE(review): torch.load unpickles a full module object here; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(checkpoint_filename)
        if config_filename is not None:
            config = self._load_config(config_filename)
            self.config.update(config)
        self.model = state_dict['model']

    def _build_model(self):
        '''Instantiate the network from `self.config` and move it to the configured device.'''
        self.model = BuildModel(
            input_dim=self.config['input_dim'],
            output_dim=self.config['output_dim'],
            hidden_dim=self.config['hidden_dim'],
            # Previously omitted, so the configured depth was silently ignored
            # and the network always used BuildModel's default of 2 layers.
            num_layer=self.config['num_layer'],
            )
        self.model.to(self.config['device'])