indiv_outcome.tabular.base

class pytrial.tasks.indiv_outcome.tabular.base.TabularIndivBase(experiment_id='test')

Abstract base class for all individual outcome prediction models that use tabular patient data.

Parameters

experiment_id (str, optional (default = 'test')) – The name of the current experiment.
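The documented interface can be mirrored with Python's `abc` module. The sketch below is a hypothetical stand-in, not PyTrial's source: `TabularIndivBaseSketch` and `NoOpModel` are illustrative names, and the sketch reproduces only the constructor and the four abstract method signatures listed on this page. A subclass must override all four before it can be instantiated.

```python
from abc import ABC, abstractmethod
from typing import Any


class TabularIndivBaseSketch(ABC):
    """Hypothetical stand-in mirroring the documented TabularIndivBase interface."""

    def __init__(self, experiment_id: str = "test") -> None:
        # Name of the current experiment, as in the documented constructor.
        self.experiment_id = experiment_id

    @abstractmethod
    def fit(self, train_data: Any, valid_data: Any) -> None: ...

    @abstractmethod
    def predict(self, test_data: Any) -> Any: ...

    @abstractmethod
    def save_model(self, output_dir: str) -> None: ...

    @abstractmethod
    def load_model(self, checkpoint: str) -> None: ...


class NoOpModel(TabularIndivBaseSketch):
    """Trivial concrete subclass: every abstract method is overridden."""

    def fit(self, train_data, valid_data):
        pass

    def predict(self, test_data):
        return [0 for _ in test_data]

    def save_model(self, output_dir):
        pass

    def load_model(self, checkpoint):
        pass


try:
    TabularIndivBaseSketch()  # fails: abstract methods are not implemented
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)

model = NoOpModel(experiment_id="demo")
print(model.experiment_id, model.predict([[1.0], [2.0]]))  # demo [0, 0]
```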

eval(mode=False)

Set the model in evaluation mode. Works similarly to model.eval() in PyTorch.

Parameters

mode (bool, optional (default = False)) – Training-mode flag. False (the default) puts the model in evaluation mode; True puts it in training mode.
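Read together with the train() description on this page, `mode` behaves as a training flag in both methods: False means evaluation, True means training. A minimal sketch of that contract, assuming the flag is stored in a hypothetical `training` attribute and that both methods return `self` for chaining (neither detail is stated in this reference):

```python
class ModeFlagSketch:
    """Hypothetical illustration of the documented mode semantics."""

    def __init__(self) -> None:
        self.training = True  # assumption: models start in training mode

    def train(self, mode: bool = True) -> "ModeFlagSketch":
        # True -> training mode, False -> evaluation mode.
        self.training = mode
        return self

    def eval(self, mode: bool = False) -> "ModeFlagSketch":
        # Same flag: the default False selects evaluation mode.
        self.training = mode
        return self


m = ModeFlagSketch()
print(m.eval().training)           # False: evaluation mode
print(m.train().training)          # True: training mode
print(m.eval(mode=True).training)  # True: passing True to eval() selects training mode
```

The last line shows why the parameter is easy to misread: calling eval(mode=True) leaves the model in training mode.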

abstract fit(train_data, valid_data)

Fit the model on the training data. Must be implemented by subclasses.

Parameters
  • train_data (Any) – Training data.

  • valid_data (Any) – Validation data.
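Because `train_data` and `valid_data` are typed as Any, each subclass chooses its own data format. The sketch below assumes a hypothetical `(feature_rows, labels)` tuple and implements fit() as a majority-class baseline: it learns from the training split and uses the validation split only to score the result. `MajorityClassModel` is an illustrative name, not a PyTrial model.

```python
from collections import Counter
from typing import Sequence, Tuple

# Hypothetical data format: (feature rows, binary labels).
Dataset = Tuple[Sequence[Sequence[float]], Sequence[int]]


class MajorityClassModel:
    """Hypothetical baseline showing one way a subclass could implement fit()."""

    def __init__(self, experiment_id: str = "test") -> None:
        self.experiment_id = experiment_id
        self.majority_label = None
        self.valid_accuracy = None

    def fit(self, train_data: Dataset, valid_data: Dataset) -> None:
        _, y_train = train_data
        # "Training": remember the most frequent label in the training split.
        self.majority_label = Counter(y_train).most_common(1)[0][0]
        # Score the fitted model on the validation split.
        _, y_valid = valid_data
        hits = sum(1 for y in y_valid if y == self.majority_label)
        self.valid_accuracy = hits / len(y_valid)


train = ([[0.1], [0.4], [0.9]], [1, 1, 0])
valid = ([[0.2], [0.8]], [1, 0])
model = MajorityClassModel()
model.fit(train, valid)
print(model.majority_label, model.valid_accuracy)  # 1 0.5
```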

abstract load_model(checkpoint)

Load a pretrained model from disk. Must be implemented by subclasses.

Parameters

checkpoint (str) – The path to the checkpoint file.
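A subclass is free to choose its checkpoint format. As one possibility, assuming the model state is a plain dictionary saved with Python's `pickle` module, load_model() can restore it from the checkpoint path. `PickleLoadSketch` and the file name `model.pkl` are hypothetical.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


class PickleLoadSketch:
    """Hypothetical subclass fragment: load_model() restores pickled state."""

    def __init__(self, experiment_id: str = "test") -> None:
        self.experiment_id = experiment_id
        self.state: Dict[str, Any] = {}

    def load_model(self, checkpoint: str) -> None:
        # `checkpoint` is a path to a file written earlier by save_model().
        # Only unpickle files from trusted sources: pickle can execute code.
        with open(checkpoint, "rb") as f:
            self.state = pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")  # hypothetical file name
    with open(path, "wb") as f:
        pickle.dump({"weights": [0.5, -1.0]}, f)
    model = PickleLoadSketch()
    model.load_model(path)
    print(model.state)  # {'weights': [0.5, -1.0]}
```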

abstract predict(test_data)

Make predictions on the testing data. Must be implemented by subclasses.

Parameters

test_data (Any) – Testing data.
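The return type of predict() is left to the subclass. The sketch below assumes `test_data` is a list of numeric feature rows and returns one probability per row through a logistic link. The weights are fixed by hand here, standing in for values an earlier fit() would learn. `LinearRiskSketch` is a hypothetical name.

```python
import math
from typing import List, Sequence


class LinearRiskSketch:
    """Hypothetical subclass fragment: predict() scores each patient row."""

    def __init__(self, weights: Sequence[float], bias: float = 0.0) -> None:
        # Assumed to come from an earlier fit(); fixed here for illustration.
        self.weights = list(weights)
        self.bias = bias

    def predict(self, test_data: Sequence[Sequence[float]]) -> List[float]:
        probs = []
        for row in test_data:
            z = self.bias + sum(w * x for w, x in zip(self.weights, row))
            probs.append(1.0 / (1.0 + math.exp(-z)))  # logistic link
        return probs


model = LinearRiskSketch(weights=[1.0, -1.0])
print(model.predict([[0.0, 0.0], [2.0, 2.0]]))  # [0.5, 0.5]
```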

abstract save_model(output_dir)

Save the model to disk. Must be implemented by subclasses.

Parameters

output_dir (str) – The path to the output directory.
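Since save_model() receives a directory rather than a file path, the subclass decides the file name inside it. A minimal sketch, assuming the state is pickled to a hypothetical `<experiment_id>.pkl` file. The sketch also returns the written path for convenience; that return value is an assumption, not part of the documented signature.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


class PickleSaveSketch:
    """Hypothetical subclass fragment: save_model() pickles model state."""

    def __init__(self, experiment_id: str = "test") -> None:
        self.experiment_id = experiment_id
        self.state: Dict[str, Any] = {"weights": [0.5, -1.0]}

    def save_model(self, output_dir: str) -> str:
        # Create the directory (and any parents) if needed.
        os.makedirs(output_dir, exist_ok=True)
        # Hypothetical naming scheme: one checkpoint file per experiment.
        path = os.path.join(output_dir, f"{self.experiment_id}.pkl")
        with open(path, "wb") as f:
            pickle.dump(self.state, f)
        return path  # assumption: returned only for convenience


with tempfile.TemporaryDirectory() as tmp:
    saved = PickleSaveSketch(experiment_id="demo").save_model(os.path.join(tmp, "ckpt"))
    print(os.path.basename(saved))  # demo.pkl
```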

train(mode=True)

Set the model in training mode. Works similarly to model.train() in PyTorch.

Parameters

mode (bool, optional (default = True)) – Whether to set the model in training mode. True (the default) puts the model in training mode; False puts it in evaluation mode.