Source code for gninatorch.transforms

"""
PyTorch-Ignite output transformations.

Note
----
PyTorch-Ignite :code:`output_transform` arguments allow to transform the
:code:`Engine.state.output` for the intendend use (by `ignite.metrics` and
`ignite.handlers`).
"""

from typing import Dict, Tuple

import torch


[docs]def output_transform_select_log_pose( output: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Select pose :code:`log_softmax` output and labels from output dictionary. Parameters ---------- output: Dict[str, ignite.metrics.Metric] Engine output Returns ------- Tuple[torch.Tensor, torch.Tensor] Logarithm of the pose class probabilities (:code:`log_softmax`) and class label Notes ----- This function is used as :code:`output_transform` in :class:`ignite.metrics.metric.Metric` and allow to select pose results from the dictionary that the evaluator returns. The output is not activated, i.e. the :code:`log_softmax` output is returned unchanged """ # Return pose log class probabilities and true labels return output["pose_log"], output["labels"]
[docs]def output_transform_select_pose( output: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Select pose :code:`softmax` output and labels from output dictionary. Parameters ---------- output: Dict[str, ignite.metrics.Metric] Engine output Returns ------- Tuple[torch.Tensor, torch.Tensor] Class probabilities and class labels Notes ----- This function is used as :code:`output_transform` in :class:`ignite.metrics.metric.Metric` and allow to select pose results from the dictionary that the evaluator returns. The output is activated, i.e. the :code:`log_softmax` output is transformed into :code:`softmax`. """ # Return pose class probabilities and true labels # log_softmax is transformed into softmax to get the class probabilities return torch.exp(output["pose_log"]), output["labels"]
[docs]def output_transform_select_affinity( output: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Select predicted affinities output and experimental (target) affinities from output dictionary. Parameters ---------- output: Dict[str, ignite.metrics.Metric] Engine output Returns ------- Tuple[torch.Tensor, torch.Tensor] Predicted binding affinity and experimental (target) binding affinity Notes ----- This function is used as :code:`output_transform` in :class:`ignite.metrics.metric.Metric` and allow to select affinity predictions from the dictionary that the evaluator returns. """ # Return pose class probabilities and true labels return output["affinities_pred"], output["affinities"]
[docs]def output_transform_select_affinity_abs( output: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Select predicted affinities (in absolute value) and experimental (target) affinities from output dictionary. Parameters ---------- output: Dict[str, ignite.metrics.Metric] Engine output Returns ------- Tuple[torch.Tensor, torch.Tensor] Predicted binding affinity (absolute value) and experimental binding affinity Notes ----- This function is used as :code:`output_transform` in :class:`ignite.metrics.metric.Metric` and allow to select affinity predictions from the dictionary that the evaluator returns. Affinities can have negative values when they are associated to bad poses. The sign is used by :class:`AffinityLoss`, but in order to compute standard metrics the absolute value is needed, which is returned here. """ # Return pose class probabilities and true labels return output["affinities_pred"], torch.abs(output["affinities"])
[docs]def output_transform_ROC(output) -> Tuple[torch.Tensor, torch.Tensor]: """ Output transform for the ROC curve. Parameters ---------- output: Engine output Returns ------- Tuple[torch.Tensor, torch.Tensor] Positive class probability and associated labels. Notes ----- https://pytorch.org/ignite/generated/ignite.contrib.metrics.ROC_AUC.html#roc-auc """ # Select pose prediction pose, labels = output_transform_select_pose(output) # Return probability estimates of the positive class return pose[:, -1], labels
[docs]def output_transform_select_log_flex( output: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Select flexible residues pose :code:`log_softmax` output and labels from output dictionary. Parameters ---------- output: Dict[str, ignite.metrics.Metric] Engine output Returns ------- Tuple[torch.Tensor, torch.Tensor] Logarithm of the pose class probabilities (:code:`log_softmax`) and class label Notes ----- This function is used as :code:`output_transform` in :class:`ignite.metrics.metric.Metric` and allow to select pose results from the dictionary that the evaluator returns. The output is not activated, i.e. the :code:`log_softmax` output is returned unchanged """ # Return pose log class probabilities and true labels return output["flexpose_log"], output["flexlabels"]
[docs]def output_transform_select_flex( output: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Select flexible residues pose :code:`softmax` output and labels from output dictionary. Parameters ---------- output: Dict[str, ignite.metrics.Metric] Engine output Returns ------- Tuple[torch.Tensor, torch.Tensor] Class probabilities and class labels Notes ----- This function is used as :code:`output_transform` in :class:`ignite.metrics.metric.Metric` and allow to select pose results from the dictionary that the evaluator returns. The output is activated, i.e. the :code:`log_softmax` output is transformed into :code:`softmax`. """ # Return pose class probabilities and true labels # log_softmax is transformed into softmax to get the class probabilities return torch.exp(output["flexpose_log"]), output["flexlabels"]
[docs]def output_transform_ROC_flex(output) -> Tuple[torch.Tensor, torch.Tensor]: """ Output transform for the ROC curve (for flexible residues pose) Parameters ---------- output: Engine output Returns ------- Tuple[torch.Tensor, torch.Tensor] Positive class probability and associated labels. Notes ----- https://pytorch.org/ignite/generated/ignite.contrib.metrics.ROC_AUC.html#roc-auc """ # Select pose prediction flexpose, flexlabels = output_transform_select_flex(output) # Return probability estimates of the positive class return flexpose[:, -1], flexlabels