gninatorch.metrics module
- gninatorch.metrics.setup_metrics(affinity: bool, flex: bool, pose_loss: Module, affinity_loss: Module, flexpose_loss: Module, roc_auc: bool, device: device) → Dict[str, Any]
Define metrics to be computed at the end of an epoch (evaluation).
- Parameters:
affinity (bool) – Flag for affinity prediction (in addition to ligand pose prediction)
flex (bool) – Flag for flexible residues pose prediction (in addition to ligand pose prediction)
pose_loss (nn.Module) – Pose loss
affinity_loss (nn.Module) – Affinity loss
flexpose_loss (nn.Module) – Flexible residues pose loss
roc_auc (bool) – Flag for computing ROC AUC
device (torch.device) – Device
- Returns:
Dictionary of PyTorch Ignite metrics
- Return type:
Dict[str, ignite.metrics.Metric]
Notes
The computation of the ROC AUC for pose prediction can be disabled with the roc_auc flag. This is useful when the computation is expected to fail because all poses belong to the same class (e.g. all poses are “good” poses). This situation arises when working with crystal structures, for which the pose is a “good” pose by definition.
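To see why a single-class evaluation set is a problem, the following library-free sketch computes the ROC AUC as the fraction of (positive, negative) pairs that are ranked correctly. The function name roc_auc and the toy labels and scores are illustrative assumptions, not part of gninatorch. With no negative poses there are no pairs to rank, so the metric is undefined.

```python
from typing import Optional, Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> Optional[float]:
    """Pairwise (Mann-Whitney) ROC AUC; returns None when it is undefined.

    Hypothetical helper for illustration only, not the gninatorch implementation.
    """
    pos = [s for l, s in zip(labels, scores) if l == 1]
    neg = [s for l, s in zip(labels, scores) if l == 0]
    if not pos or not neg:
        # Only one class present: there are no (positive, negative) pairs,
        # so the denominator len(pos) * len(neg) would be zero.
        return None
    # A pair counts 1 if the positive outranks the negative, 0.5 on a tie.
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in pos
        for n in neg
    )
    return wins / (len(pos) * len(neg))


# Docked poses: a mix of "good" (1) and "bad" (0) poses.
mixed = roc_auc([1, 0, 1, 0], [0.9, 0.2, 0.6, 0.7])

# Crystal structures: every pose is "good" by definition.
crystal = roc_auc([1, 1, 1, 1], [0.9, 0.8, 0.95, 0.7])

print(mixed, crystal)
```

Real implementations may raise an exception or emit a warning in the single-class case instead of returning None, which is why skipping the metric is the safer option for crystal-structure datasets.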
Loss functions need to be set up as metrics in order to be correctly accumulated over the whole epoch. Using evaluator.state.output to compute the loss does not work, since the output only contains the last batch (to avoid RAM saturation).
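The difference between an accumulated loss and the last-batch output can be sketched without any dependencies. The RunningLoss class below is a hypothetical stand-in that mimics the reset/update/compute pattern of a metric; the batch losses and sizes are made-up values chosen only to make the arithmetic easy to follow.

```python
class RunningLoss:
    """Hypothetical stand-in for a loss metric: sample-weighted running mean."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Called at the start of every epoch.
        self._total = 0.0
        self._count = 0

    def update(self, batch_loss: float, batch_size: int) -> None:
        # Weight each batch-averaged loss by its number of samples so that
        # a smaller final batch does not distort the epoch average.
        self._total += batch_loss * batch_size
        self._count += batch_size

    def compute(self) -> float:
        if self._count == 0:
            raise ValueError("compute() called before any update()")
        return self._total / self._count


# (mean loss of the batch, batch size); made-up numbers for illustration.
batches = [(0.5, 4), (1.0, 4), (3.0, 2)]

metric = RunningLoss()
last_output = None
for batch_loss, batch_size in batches:
    metric.update(batch_loss, batch_size)
    last_output = batch_loss  # analogous to keeping only the latest output

epoch_loss = metric.compute()
print(f"epoch loss: {epoch_loss:.2f}, last batch only: {last_output:.2f}")
```

Here the last batch alone reports a loss of 3.0, while the sample-weighted average over all ten samples is 1.2, which is the value an end-of-epoch evaluation should report.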