indiv_outcome.sequence.Dipole

class pytrial.tasks.indiv_outcome.sequence.dipole.Dipole(vocab_size, orders, mode, output_dim=None, max_visit=None, attention_type='location_based', attention_dim=8, emb_size=16, hidden_size=8, hidden_output_size=8, dropout=0.5, learning_rate=0.0001, weight_decay=0.0001, batch_size=64, epochs=10, num_worker=0, device='cuda:0', experiment_id='test')[source]

Implement Dipole for predictive modeling of longitudinal patient records [1].

Parameters
  • vocab_size (list[int]) – A list of vocabulary sizes for the different types of events, e.g., diagnosis, procedure, and medication codes.

  • orders (list[str]) – A list specifying the order in which the input event types are processed. Should have the same length as vocab_size.

  • mode (str) – Prediction target, one of ['binary', 'multiclass', 'multilabel', 'regression'].

  • output_dim (int) –

    Output dimension of the model.

    • If binary classification, output_dim=1;

    • If multiclass/multilabel classification, output_dim=n_class;

    • If regression, output_dim=1.

  • max_visit (int) – The maximum number of visits for input event codes.

  • attention_type ({'general', 'concatenation_based', 'location_based'}) –

    Which attention mechanism is applied to derive a context vector that captures relevant information to help predict the target (the three score functions are sketched after this parameter list).

    • 'location_based': Location-based attention. Calculates the attention weights solely from the hidden states.

    • 'general': General attention. A simple way to capture the relationship between two hidden states.

    • 'concatenation_based': Concatenation-based attention. Concatenates two hidden states, then uses a multi-layer perceptron (MLP) to calculate the context vector.

  • attention_dim (int) – The latent dimensionality used to compute the attention weights; only used by the concatenation_based attention mechanism.

  • emb_size (int) – Embedding size for encoding input event codes.

  • hidden_size (int, optional (default = 8)) – The number of features of the hidden state h.

  • hidden_output_size (int, optional (default = 8)) – The number of features of the mixed hidden output, i.e., the layer that combines the attention context vector with the hidden state.

  • learning_rate (float) – Learning rate for gradient-based optimization. torch.optim.Adam is used by default.

  • weight_decay (float) – Regularization strength for the L2 norm; must be a positive float. Smaller values specify weaker regularization.

  • batch_size (int) – Batch size used for mini-batch optimization.

  • epochs (int) – Maximum number of training epochs.

  • num_worker (int) – Number of workers used for data loading during training.

  • device (str) – The model device.
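
For reference, the three attention score functions can be written following the definitions in [1] (a paraphrase of the paper's notation; the library's exact implementation may differ). Given RNN hidden states h_1, ..., h_t, the unnormalized score e_{ti} between the current state h_t and an earlier state h_i is

    e_{ti} = w_\alpha^\top h_i + b_\alpha                  (location_based)
    e_{ti} = h_t^\top W_\alpha h_i                         (general)
    e_{ti} = v_\alpha^\top \tanh(W_\alpha [h_t ; h_i])     (concatenation_based)

and the attention weights \alpha_{ti} = \mathrm{softmax}_i(e_{ti}) combine the hidden states into a context vector for prediction.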

Notes

[1] Ma, F., Chitta, R., Zhou, J., You, Q., Sun, T., & Gao, J. (2017, August). Dipole: Diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (pp. 1903-1911).
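
Examples

A minimal usage sketch. The vocabulary sizes, order names, and hyperparameters below are illustrative only and must match your own data:

>>> from pytrial.tasks.indiv_outcome.sequence.dipole import Dipole
>>> model = Dipole(
...     vocab_size=[100, 50, 30],        # e.g., diagnosis, procedure, medication vocabularies
...     orders=['diag', 'prod', 'med'],  # hypothetical event-type names
...     mode='binary',
...     output_dim=1,
...     max_visit=20,
...     attention_type='location_based',
...     device='cpu',
... )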

fit(train_data, valid_data)[source]

Train model with sequential patient records.

Parameters
  • train_data (SequencePatientBase) – A SequencePatientBase that contains patient records, where 'v' corresponds to the visit sequences of different events and 'y' corresponds to the labels.

  • valid_data (SequencePatientBase) – A SequencePatientBase that contains patient records used for early stopping of the model.
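
A sketch of a typical call, assuming train_data and valid_data are SequencePatientBase datasets prepared beforehand with PyTrial's data utilities:

>>> model.fit(train_data=train_data, valid_data=valid_data)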

load_model(checkpoint)[source]

Load a pretrained model from the disk.

Parameters

checkpoint (str) – The input directory that stores the trained pytorch model and configuration.
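
For example, to restore a model previously written by save_model (the checkpoint directory below is illustrative):

>>> model.load_model('./checkpoints/dipole')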

predict(test_data)[source]

Predict patient outcomes using longitudinal trial patient sequences.

Parameters

test_data (SequencePatient) – A SequencePatient that contains patient records, where 'v' corresponds to the visit sequences of different events.
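
A sketch of predicting on a held-out set after fitting (the shape and type of the returned predictions depend on mode):

>>> ypred = model.predict(test_data)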

save_model(output_dir)[source]

Save the pretrained model to the disk.

Parameters

output_dir (str) – The output directory that stores the trained pytorch model and configuration.
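
For example (the output directory below is illustrative):

>>> model.save_model('./checkpoints/dipole')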