lenskit.graphs.lightgcn.LightGCNTrainer
- class lenskit.graphs.lightgcn.LightGCNTrainer(scorer, data, options)
Bases: lenskit.training.ModelTrainer
Protocol implemented by iterative trainers for models. Models that implement UsesTrainer will return an object implementing this protocol from their create_trainer() method.
This protocol only defines the core aspects of training a model. Trainers should also implement ParameterContainer to allow training to be checkpointed and resumed. It is also a good idea for the trainer to be pickleable, but the parameter container interface is the primary mechanism for checkpointing.
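The training lifecycle described above consists of repeated train_epoch() calls followed by a single finalize() call. It can be sketched in plain Python. TrainerLike, ToyTrainer, and run_training are hypothetical names for illustration only; they are not LensKit APIs. LensKit's own training driver, its use of TrainingOptions, and ParameterContainer checkpointing are not shown.

```python
from __future__ import annotations

from typing import Protocol, runtime_checkable


@runtime_checkable
class TrainerLike(Protocol):
    """Hypothetical stand-in mirroring the two trainer methods documented on this page."""

    def train_epoch(self) -> dict[str, float] | None: ...

    def finalize(self) -> None: ...


class ToyTrainer:
    """Toy trainer that 'learns' one number by moving halfway toward a target each epoch."""

    def __init__(self, target: float) -> None:
        self.target = target
        self.value = 0.0
        self.finished = False

    def train_epoch(self) -> dict[str, float]:
        self.value += (self.target - self.value) / 2
        # The model stays usable after every epoch; report the remaining error as a metric.
        return {"error": abs(self.target - self.value)}

    def finalize(self) -> None:
        # Called once after the last epoch; a real trainer would release temporary state here.
        self.finished = True


def run_training(trainer: TrainerLike, epochs: int) -> list[dict[str, float] | None]:
    """Run a fixed number of epochs, collect per-epoch metrics, then finalize."""
    history: list[dict[str, float] | None] = []
    for _ in range(epochs):
        history.append(trainer.train_epoch())
    trainer.finalize()
    return history
```

For example, `run_training(ToyTrainer(8.0), 3)` returns the per-epoch metrics `[{"error": 4.0}, {"error": 2.0}, {"error": 1.0}]`. Keeping the epoch loop outside the trainer is what allows a driver to stop early or checkpoint between epochs.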
- Stability:
- Full (see Stability Levels).
- Parameters:
scorer (LightGCNScorer)
data (lenskit.data.Dataset)
options (lenskit.training.TrainingOptions)
- scorer: LightGCNScorer
- data: lenskit.data.Dataset
- options: lenskit.training.TrainingOptions
- model: torch_geometric.nn.LightGCN
- edges: torch.Tensor
- optimizer: torch.optim.Optimizer
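The model attribute is a torch_geometric LightGCN, which propagates embeddings over a graph given as an edge index of shape [2, num_edges]. The sketch below shows one common way to lay out a user-item interaction graph in that format, using plain Python lists. The node layout (users first, item i at node n_users + i) and the choice to add both edge directions are assumptions for illustration. This page does not confirm how LensKit actually builds its edges tensor, and bipartite_edge_index is a hypothetical helper.

```python
from __future__ import annotations

from collections.abc import Iterable


def bipartite_edge_index(
    pairs: Iterable[tuple[int, int]], n_users: int
) -> list[list[int]]:
    """Build a [2, num_edges] edge list for a user-item graph (hypothetical layout).

    Users are nodes 0..n_users-1, and item i is node n_users + i.
    Each interaction is added in both directions, so messages flow user->item and item->user.
    """
    src: list[int] = []
    dst: list[int] = []
    for user, item in pairs:
        item_node = n_users + item
        # user -> item edge
        src.append(user)
        dst.append(item_node)
        # item -> user edge (reverse direction)
        src.append(item_node)
        dst.append(user)
    return [src, dst]
```

Consider the interactions (0, 0), (1, 1), and (1, 0) with two users. The result holds six directed edges over four nodes. Wrapping it as `torch.tensor(..., dtype=torch.long)` would give the kind of [2, E] tensor that PyG's LightGCN accepts as edge_index, with num_nodes equal to n_users + n_items.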
- train_epoch()
Perform one epoch of the training process, optionally returning metrics on the training behavior. After each training iteration, the model must be usable.
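One epoch of an edge-based trainer can be sketched as a single shuffled pass over the training edges in minibatches, with the per-batch losses averaged into the returned metrics. This is a simplified illustration, not LensKit's implementation. The documented batch_loss takes (mb_edges, scores) tensors. Here a hypothetical batch_loss callable receives only the minibatch edges, and the scoring, backward pass, and optimizer step are indicated in comments rather than performed.

```python
from __future__ import annotations

import random
from collections.abc import Callable, Sequence


def train_epoch_sketch(
    edges: Sequence[tuple[int, int]],
    batch_size: int,
    batch_loss: Callable[[list[tuple[int, int]]], float],
    rng: random.Random,
) -> dict[str, float]:
    """Hypothetical epoch: one pass over all training edges in shuffled minibatches."""
    if not edges:
        raise ValueError("no training edges")
    # Visit the edges in a fresh random order each epoch.
    order = list(range(len(edges)))
    rng.shuffle(order)
    losses: list[float] = []
    for start in range(0, len(order), batch_size):
        mb_edges = [edges[k] for k in order[start : start + batch_size]]
        # A real trainer would score mb_edges with the model, compute the loss,
        # call loss.backward(), and step the optimizer here. This sketch only
        # records the value returned by the (hypothetical) batch_loss callable.
        losses.append(batch_loss(mb_edges))
    # Return summary metrics, as train_epoch is allowed to do.
    return {"loss": sum(losses) / len(losses)}
```

Because the optimizer step happens inside each minibatch, the model is in a consistent, usable state whenever the function returns.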
- finalize()
Finish the training process, cleaning up any unneeded data structures and doing any finalization steps to the model.
The default implementation does nothing.
- abstractmethod batch_loss(mb_edges, scores)
- Parameters:
mb_edges (torch.Tensor)
scores (torch.Tensor)
- Return type: