gninatorch.metrics module

gninatorch.metrics.setup_metrics(affinity: bool, flex: bool, pose_loss: Module, affinity_loss: Module, flexpose_loss: Module, roc_auc: bool, device: device) → Dict[str, Any]

Define the metrics to be computed at the end of an evaluation epoch.

Parameters:
  • affinity (bool) – Flag for affinity prediction (in addition to ligand pose prediction)

  • flex (bool) – Flag for flexible residues pose prediction (in addition to ligand pose prediction)

  • pose_loss (nn.Module) – Pose loss

  • affinity_loss (nn.Module) – Affinity loss

  • flexpose_loss (nn.Module) – Flexible residues pose loss

  • roc_auc (bool) – Flag for computing ROC AUC

  • device (torch.device) – Device

Returns:

Dictionary of PyTorch Ignite metrics

Return type:

Dict[str, ignite.metrics.Metric]
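As an illustration of how the boolean flags could shape the returned dictionary, here is a minimal pure-Python sketch. The function `sketch_metric_names` and the metric names (`pose_loss`, `pose_roc_auc`, `affinity_loss`, `flexpose_loss`) are hypothetical, the values are plain descriptions rather than Ignite metrics, and the loss modules and `device` are left out. It also assumes that a pose-loss metric is always present, since ligand pose prediction is always performed.

```python
from typing import Dict


def sketch_metric_names(affinity: bool, flex: bool, roc_auc: bool) -> Dict[str, str]:
    """Hypothetical: map metric names to a description of what each would track.

    The keys are illustrative only; the real keys used by gninatorch may differ.
    """
    # Assumption: ligand pose prediction is always on, so its loss is always tracked.
    metrics = {"pose_loss": "mean pose loss over the evaluation epoch"}
    if roc_auc:
        # Can be disabled when all poses share one label (e.g. crystal structures).
        metrics["pose_roc_auc"] = "ROC AUC of pose classification"
    if affinity:
        metrics["affinity_loss"] = "mean affinity loss over the evaluation epoch"
    if flex:
        metrics["flexpose_loss"] = "mean flexible-residue pose loss over the evaluation epoch"
    return metrics


print(sorted(sketch_metric_names(affinity=True, flex=False, roc_auc=False)))
```

Keeping the flags independent means each optional task contributes its own entries without affecting the others.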

Notes

The computation of the ROC AUC for pose prediction can be disabled. This is useful when the computation is expected to fail because all poses belong to the same class (e.g. all poses are “good” poses), in which case the ROC AUC is undefined. This situation arises when working with crystal structures, whose poses are “good” poses by definition.
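The reason a single-class set breaks the ROC AUC can be seen from its pairwise definition: the AUC is the probability that a randomly chosen positive is scored above a randomly chosen negative. The sketch below is a plain-Python illustration (the function name `roc_auc` is hypothetical, not part of gninatorch). With no negatives there are no pairs to compare, so it returns `None`, whereas libraries such as scikit-learn raise an error in that case.

```python
from typing import Optional, Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> Optional[float]:
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    Returns None when only one class is present, because the AUC is undefined.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return None  # no (positive, negative) pairs to compare
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a correct ranking
    return wins / (len(pos) * len(neg))


print(roc_auc([1, 0, 1, 0], [0.9, 0.2, 0.6, 0.7]))  # 0.75
print(roc_auc([1, 1, 1], [0.9, 0.4, 0.8]))  # None: all poses are "good"
```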

Loss functions need to be set up as metrics in order to be correctly accumulated over the whole evaluation epoch. Using evaluator.state.output to compute the loss does not work, since the output only contains the last batch (to avoid RAM saturation).
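The difference between an accumulated loss and the last batch's output can be shown with a small stand-in. The class `RunningMeanLoss` below is hypothetical; it only mirrors the reset/update/compute pattern of PyTorch Ignite metrics and uses plain floats instead of tensors. Each batch's mean loss is weighted by the batch size so that a smaller final batch is not over-counted.

```python
from typing import List, Optional, Tuple


class RunningMeanLoss:
    """Hypothetical loss metric that accumulates a per-sample mean over an epoch."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Called at the start of each evaluation epoch.
        self._sum = 0.0
        self._count = 0

    def update(self, batch_loss: float, batch_size: int) -> None:
        # batch_loss is the mean loss of the batch; weight it by the batch size.
        self._sum += batch_loss * batch_size
        self._count += batch_size

    def compute(self) -> float:
        if self._count == 0:
            raise ValueError("RunningMeanLoss needs at least one update")
        return self._sum / self._count


# (mean batch loss, batch size) for one evaluation epoch
batches: List[Tuple[float, int]] = [(0.5, 4), (0.25, 4), (0.125, 2)]
metric = RunningMeanLoss()
last_output: Optional[float] = None
for loss, n in batches:
    metric.update(loss, n)
    last_output = loss  # like evaluator.state.output: overwritten every batch

print(metric.compute())  # 0.325, the loss over all 10 samples
print(last_output)  # 0.125, only the last batch
```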