indiv_outcome.sequence.RAIM
- class pytrial.tasks.indiv_outcome.sequence.raim.RAIM(vocab_size, orders, mode=None, output_dim=None, max_visit=None, window_size=3, hidden_size=8, hidden_output_size=8, dropout=0.5, learning_rate=0.0001, weight_decay=0.0001, batch_size=64, epochs=10, num_worker=0, device='cuda:0', experiment_id='test')
Implement RAIM [1] for predictive modeling on longitudinal patient records.
- Parameters
vocab_size (list[int]) – A list of vocabulary sizes, one per event type, e.g., diagnosis, procedure, medication.
orders (list[str]) – The order in which the input event types are processed. Must have the same length as vocab_size.
mode (str) – Prediction target, one of ['binary', 'multiclass', 'multilabel', 'regression'].
output_dim (int) –
Output dimension of the model.
If binary classification, output_dim=1;
If multiclass/multilabel classification, output_dim=n_class;
If regression, output_dim=1.
max_visit (int) – The maximum number of visits for input event codes.
window_size (int) – The window size in the recurrent part of RAIM.
hidden_size (int) – Embedding size for encoding input event codes.
hidden_output_size (int) – The output size of the intermediate layers (distinct from output_dim).
dropout (float) – Dropout rate in the intermediate layers.
learning_rate (float) – Learning rate for gradient-based optimization; torch.optim.Adam is used by default.
weight_decay (float) – L2 regularization strength; must be a positive float. Smaller values mean weaker regularization.
batch_size (int) – Batch size for mini-batch optimization.
epochs (int) – Maximum number of training epochs (full passes over the training data).
num_worker (int) – Number of worker processes used for data loading during training.
device (str) – The device the model runs on, e.g., 'cuda:0' or 'cpu'.
experiment_id (str) – The prefix used when saving checkpoints during training.
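The constraints on mode, output_dim, vocab_size, and orders can be summarized in a small sketch. The helpers infer_output_dim and check_event_types are hypothetical, not part of PyTrial; they only restate the rules documented above, and the event-type names and vocabulary sizes in the usage lines are made-up illustrative values.

```python
from typing import Dict, List, Optional

# Hypothetical helpers (not part of PyTrial) that restate the documented
# rules for RAIM's constructor arguments.
VALID_MODES = ("binary", "multiclass", "multilabel", "regression")


def infer_output_dim(mode: str, n_class: Optional[int] = None) -> int:
    """Return output_dim following the documented rule for each mode."""
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")
    if mode in ("binary", "regression"):
        return 1  # single logit or single regression value
    if n_class is None or n_class < 1:
        raise ValueError(f"mode={mode!r} needs a positive n_class")
    return n_class  # one output per class or label


def check_event_types(vocab_size: List[int], orders: List[str]) -> Dict[str, int]:
    """Pair each event type in `orders` with its vocabulary size."""
    if len(vocab_size) != len(orders):
        raise ValueError("vocab_size and orders must have the same length")
    return dict(zip(orders, vocab_size))


print(infer_output_dim("multilabel", n_class=5))  # 5
print(check_event_types([100, 50, 80], ["diag", "prod", "med"]))
```

Pairing the two lists explicitly makes a length mismatch fail early instead of silently misaligning event types with vocabularies.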
Notes
- [1] Xu, Y., Biswal, S., Deshpande, S. R., Maher, K. O., & Sun, J. (2018, July). RAIM: Recurrent attentive and intensive model of multimodal patient monitoring data. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (pp. 2565-2573).
- fit(train_data, valid_data=None)
Train the model on sequential patient records.
- Parameters
train_data (SequencePatientBase) – A SequencePatientBase containing patient records, where 'v' holds the visit sequences of the different event types and 'y' holds the labels.
valid_data (SequencePatientBase) – A SequencePatientBase containing patient records used for early stopping.
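A common way to use a validation set for early stopping is a patience rule on the validation loss. The sketch below is a generic version of that technique, assuming a loss where lower is better and a hypothetical patience argument; PyTrial's actual stopping criterion is not documented here and may differ.

```python
from typing import List, Optional, Tuple


def early_stopping(valid_losses: List[float], patience: int = 3) -> Tuple[int, Optional[int]]:
    """Generic patience-based early stopping (an assumption, not PyTrial's code).

    Returns (best_epoch, stop_epoch). best_epoch is the 0-based epoch with the
    lowest validation loss seen before stopping; stop_epoch is the epoch at
    which `patience` epochs have passed without improvement, or None if
    training would run through every recorded epoch.
    """
    best_loss = float("inf")
    best_epoch = -1
    for epoch, loss in enumerate(valid_losses):
        if loss < best_loss:
            # New best: this is the checkpoint worth keeping.
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, None


print(early_stopping([0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.5], patience=3))  # (2, 5)
```

In this example training halts at epoch 5, so the later improvement at epoch 6 is never seen; the epochs argument still caps the number of epochs when stopping never triggers.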
- load_model(checkpoint)
Load a pretrained model from disk.
- Parameters
checkpoint (str) – The directory that stores the trained PyTorch model and its configuration.
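Keeping the configuration next to the weights lets the model be rebuilt with the same architecture (vocab_size, orders, hidden sizes, and so on) before the trained parameters are loaded. The sketch below shows only the configuration half of that round trip, using the standard library. The file name config.json and the helpers save_config and load_config are hypothetical; PyTrial's actual checkpoint layout may differ.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict

CONFIG_NAME = "config.json"  # hypothetical file name, not PyTrial's layout


def save_config(checkpoint_dir: str, config: Dict[str, Any]) -> None:
    """Write the constructor arguments as JSON inside the checkpoint directory."""
    path = Path(checkpoint_dir)
    path.mkdir(parents=True, exist_ok=True)
    (path / CONFIG_NAME).write_text(json.dumps(config, indent=2))


def load_config(checkpoint_dir: str) -> Dict[str, Any]:
    """Read the constructor arguments back so the model can be re-instantiated."""
    path = Path(checkpoint_dir) / CONFIG_NAME
    if not path.is_file():
        raise FileNotFoundError(f"no {CONFIG_NAME} in {checkpoint_dir}")
    return json.loads(path.read_text())


with tempfile.TemporaryDirectory() as tmp:
    cfg = {"vocab_size": [100, 50, 80], "orders": ["diag", "prod", "med"],
           "mode": "binary", "output_dim": 1}
    save_config(tmp, cfg)
    print(load_config(tmp) == cfg)  # True
```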