utils.trainer
- class pytrial.utils.trainer.Trainer(model: torch.nn.modules.module.Module, train_objectives: List[Tuple[torch.utils.data.dataloader.DataLoader, torch.nn.modules.module.Module]], test_data=None, test_metric=None, less_is_better=False, load_best_at_end=True, n_gpus=1, output_dir='./checkpoints/', **kwargs)
Bases: object
A general trainer used to train deep learning models.
- Parameters
model (nn.Module) – The model to be trained.
train_objectives (list[tuple[DataLoader, nn.Module]]) – The defined pairs of dataloaders and the loss models.
test_data ((optional) dict or Dataset) – Passed to the implemented get_test_dataloader function, which receives it as input and returns the test dataloader. The expected type therefore depends on that implementation.
test_metric ((optional) str) – Name of the test metric used to select the best checkpoint during training. Only used when test_data is given. Must be a key of the metric dict returned by evaluate.
less_is_better ((optional) bool) – Whether a smaller value of test_metric is better. Ignored unless both test_data and test_metric are given.
load_best_at_end (bool) – Whether to load the best checkpoint at the end of training.
n_gpus (int) – Number of GPUs used for training. If larger than 1, parallel training is used.
output_dir (str) – Directory under which intermediate model checkpoints are saved during training.
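Together, test_metric, less_is_better and load_best_at_end decide which checkpoint counts as the best one. The following is a minimal pure-Python sketch of that selection rule, assuming the trainer compares the test_metric value from each evaluate() call with the best value seen so far. The helpers is_better and track_best, and the metric names "auc" and "loss", are hypothetical and not part of PyTrial.

```python
from typing import Dict, List, Optional, Tuple


def is_better(current: float, best: Optional[float], less_is_better: bool) -> bool:
    """Return True if `current` improves on `best` in the chosen direction."""
    if best is None:  # the first evaluation is always the best so far
        return True
    return current < best if less_is_better else current > best


def track_best(
    history: List[Tuple[int, Dict[str, float]]],
    test_metric: str,
    less_is_better: bool = False,
) -> Tuple[Optional[int], Optional[float]]:
    """Return (step, value) of the best evaluation in `history`.

    Each history entry is (step, metrics), where `metrics` plays the role of
    the dict returned by a task-specific `evaluate()`.
    """
    best_step, best_value = None, None
    for step, metrics in history:
        value = metrics[test_metric]  # test_metric must be a key of the dict
        if is_better(value, best_value, less_is_better):
            best_step, best_value = step, value
    return best_step, best_value


# Hypothetical metric dicts from three evaluations.
history = [
    (10, {"auc": 0.71, "loss": 0.52}),
    (20, {"auc": 0.75, "loss": 0.49}),
    (30, {"auc": 0.74, "loss": 0.47}),
]
print(track_best(history, "auc"))                        # (20, 0.75)
print(track_best(history, "loss", less_is_better=True))  # (30, 0.47)
```

Under this assumption, load_best_at_end=True would correspond to reloading the checkpoint saved at the returned step once training finishes. Ties keep the earlier checkpoint because the comparison is strict.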
Examples
>>> from pytrial.utils.trainer import Trainer
>>> trainer = Trainer(
...     model=model,
...     train_objectives=[(dataloader1, loss_model1), (dataloader2, loss_model2)],
... )
>>> trainer.train(
...     epochs=10,
... )
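Each entry of train_objectives pairs a dataloader with the loss model that consumes its batches. The sketch below shows one plausible way a trainer can interleave several objectives within an epoch. It assumes a round-robin loop similar to the one in the sentence-transformers library, where the shortest dataloader sets the epoch length and every objective contributes one batch per step. This is an assumption, not PyTrial's code. The function round_robin_epoch is hypothetical, plain lists of batches stand in for DataLoaders, and `sum` stands in for a loss model, so no model is actually updated.

```python
from typing import Callable, List, Sequence, Tuple


def round_robin_epoch(
    train_objectives: Sequence[Tuple[Sequence, Callable]],
) -> List[Tuple[int, int, float]]:
    """Run one epoch that alternates between (dataloader, loss_model) pairs.

    Assumption: the epoch length is set by the shortest dataloader, and at each
    step every objective contributes one batch. Returns a log of
    (step, objective_index, loss) tuples instead of updating a real model.
    """
    steps_per_epoch = min(len(loader) for loader, _ in train_objectives)
    iterators = [iter(loader) for loader, _ in train_objectives]
    log = []
    for step in range(steps_per_epoch):
        for idx, (_, loss_model) in enumerate(train_objectives):
            batch = next(iterators[idx])
            loss = loss_model(batch)
            # A real trainer would call loss.backward() and step the optimizer here.
            log.append((step, idx, loss))
    return log


# Plain lists of batches stand in for DataLoaders; `sum` stands in for a loss model.
objectives = [([[1, 2], [3, 4], [5, 6]], sum), ([[10], [20]], sum)]
print(round_robin_epoch(objectives))
# [(0, 0, 3), (0, 1, 10), (1, 0, 7), (1, 1, 20)]
```

With this design the third batch of the first objective is never used in the epoch. Loops that instead restart exhausted iterators can run for the length of the longest dataloader.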
- evaluate()
Must be implemented by task-specific subclasses.
- Returns
A dict of computed evaluation metrics.
- Return type
dict
- evaluated = False
- prepare_input(data)
May need to be overridden when the input data does not follow the standard dict structure.
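As an illustration of such an override, the sketch below normalizes a batch that arrives as a (features, labels) pair into a dict. The key names "x" and "y" are hypothetical placeholders, because the exact "standard dict structure" depends on the model being trained. In a Trainer subclass, the same logic would live in a `prepare_input(self, data)` method.

```python
from typing import Any, Dict


def prepare_input(data: Any) -> Dict[str, Any]:
    """Normalize a batch into a dict with hypothetical keys 'x' and 'y'.

    Batches that are already dicts are returned unchanged; two-element
    tuples or lists are treated as (features, labels).
    """
    if isinstance(data, dict):
        return data
    if isinstance(data, (tuple, list)) and len(data) == 2:
        features, labels = data
        return {"x": features, "y": labels}
    raise TypeError(f"Unsupported batch type: {type(data).__name__}")


print(prepare_input(([1.0, 2.0], [0, 1])))  # {'x': [1.0, 2.0], 'y': [0, 1]}
```

Moving tensors to the training device, if needed, is not shown here.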
- train(epochs=10, learning_rate=2e-05, weight_decay=0.0001, warmup_ratio=0, scheduler='warmupcosine', evaluation_steps=10, max_grad_norm=0.5, use_amp=False, **kwargs)
Start training with the provided loss models and training dataloaders.
- Parameters
epochs (int (default=10)) – Number of passes (epochs) over the training data.
learning_rate (float (default=2e-5)) – The learning rate.
weight_decay (float (default=1e-4)) – Weight decay applied for regularization.
warmup_ratio (float (default=0)) – Fraction of the total training steps used for learning-rate warmup. If 0, no warmup is applied.
scheduler ({'constantlr','warmupconstant','warmuplinear','warmupcosine','warmupcosinewithhardrestarts'} (default='warmupcosine')) – Learning rate scheduler used together with warmup. Ignored if warmup_ratio <= 0.
evaluation_steps (int (default=10)) – Interval, in training steps, at which the training loss is printed and, if test_data is given, evaluation is run.
max_grad_norm (float (default=0.5)) – Maximum gradient norm for gradient clipping, which helps avoid exploding gradients and NaN losses.
use_amp (bool (default=False)) – Whether to use automatic mixed precision (AMP) training.
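The scheduler names match those used by the sentence-transformers library, which maps them to the Hugging Face transformers schedules: constant, constant with warmup, linear, cosine, and cosine with hard restarts. This sketch assumes PyTrial follows the same formulas. It also assumes that warmup_steps = int(warmup_ratio * total_steps) and that, as the parameter description states, the scheduler is ignored (constant rate) when warmup_ratio <= 0. The function lr_multiplier and its num_cycles argument are hypothetical. The sketch returns the factor applied to the base learning rate at each step.

```python
import math
from typing import Callable

SCHEDULERS = {
    "constantlr",
    "warmupconstant",
    "warmuplinear",
    "warmupcosine",
    "warmupcosinewithhardrestarts",
}


def lr_multiplier(
    scheduler: str,
    total_steps: int,
    warmup_ratio: float = 0.0,
    num_cycles: int = 1,
) -> Callable[[int], float]:
    """Return f(step), the factor applied to the base learning rate at `step`.

    Assumptions: warmup_steps = int(warmup_ratio * total_steps), and the
    scheduler is ignored (constant rate) when warmup_ratio <= 0.
    `num_cycles` only affects the hard-restarts schedule.
    """
    scheduler = scheduler.lower()
    if scheduler not in SCHEDULERS:
        raise ValueError(f"Unknown scheduler: {scheduler}")
    if warmup_ratio <= 0 or scheduler == "constantlr":
        return lambda step: 1.0

    warmup_steps = int(warmup_ratio * total_steps)
    decay_steps = max(1, total_steps - warmup_steps)

    def f(step: int) -> float:
        if step < warmup_steps:
            # Linear warmup from 0 up to the base learning rate.
            return step / max(1, warmup_steps)
        # Fraction of the post-warmup phase completed so far.
        progress = (step - warmup_steps) / decay_steps
        if scheduler == "warmupconstant":
            return 1.0
        if scheduler == "warmuplinear":
            return max(0.0, 1.0 - progress)
        if scheduler == "warmupcosine":
            # Half a cosine wave: decays from 1 to 0 over the decay phase.
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        # warmupcosinewithhardrestarts: num_cycles cosine decays, each one
        # restarting at the full rate; 0 once the decay phase is over.
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0))))

    return f


f = lr_multiplier("warmuplinear", total_steps=100, warmup_ratio=0.25)
print([f(s) for s in (0, 10, 25, 100)])  # [0.0, 0.4, 1.0, 0.0]
```

With PyTorch, a function like this can be passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR, which multiplies each parameter group's initial learning rate by the returned factor.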