Source code for gninatorch.losses

import torch
import torch.nn as nn
from torch import Tensor


class ScaledNLLLoss(nn.Module):
    """
    Negative log-likelihood loss multiplied by a constant factor.

    Wraps :class:`torch.nn.NLLLoss` and multiplies its output by
    :code:`scale`.

    Parameters
    ----------
    scale: float
        Factor the NLL loss is multiplied by
    reduction: str
        Reduction method (mean or sum), forwarded to :class:`torch.nn.NLLLoss`
    """

    def __init__(self, scale: float = 1.0, reduction: str = "mean"):
        super().__init__()

        self.loss = nn.NLLLoss(reduction=reduction)
        self.scale = scale

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        Evaluate the wrapped NLL loss and apply the scaling factor.

        Parameters
        ----------
        input: Tensor
            Predicted values
        target: Tensor
            Target values

        Returns
        -------
        torch.Tensor
            Loss returned by the wrapped :class:`torch.nn.NLLLoss`, multiplied by
            :code:`scale`
        """
        unscaled = self.loss(input, target)

        return self.scale * unscaled
class _AffinityLossInputError(ValueError, AssertionError):
    """
    Invalid argument passed to :class:`AffinityLoss`.

    Derives from :class:`AssertionError` as well as :class:`ValueError` so that
    code written against the earlier ``assert``-based checks keeps working.
    """


class AffinityLoss(nn.Module):
    """
    GNINA affinity loss.

    Parameters
    ----------
    reduction: str
        Reduction method (mean or sum)
    delta: float
        Scaling factor of the pseudo-Huber loss (only used if
        :code:`pseudo_huber` is set)
    penalty: float
        Penalty factor, added to the difference computed for bad poses
    pseudo_huber: bool
        Use pseudo-huber loss as opposed to L2 loss
    scale: float
        Scaling factor for the loss

    Raises
    ------
    ValueError
        If :code:`reduction` is neither :code:`"mean"` nor :code:`"sum"`. The
        exception is also an :class:`AssertionError`.

    Notes
    -----
    Translated from the original custom Caffe layer. Not all functionality is
    implemented.

    https://github.com/gnina/gnina/blob/master/caffe/src/caffe/layers/affinity_loss_layer.cpp

    The :code:`scale` parameter is different from the original implementation. In the
    original Caffe implementation, the :code:`scale` parameter is used to scale the
    gradients in the backward pass. Here the scale parameter scales the loss function
    directly in the forward pass.

    Definition of pseudo-Huber loss:
    https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function
    """

    def __init__(
        self,
        reduction: str = "mean",
        delta: float = 1.0,
        penalty: float = 0.0,
        pseudo_huber: bool = False,
        scale: float = 1.0,
    ):
        super().__init__()

        # Explicit check instead of ``assert``: assertions are removed under
        # ``python -O``, and any unknown value would then silently be treated as
        # "sum" in forward()
        if reduction not in ("mean", "sum"):
            raise _AffinityLossInputError(
                f"reduction must be 'mean' or 'sum', got {reduction!r}"
            )

        self.delta: float = delta
        self.delta2: float = delta * delta
        self.penalty: float = penalty
        self.pseudo_huber: bool = pseudo_huber
        self.scale: float = scale
        self.reduction: str = reduction

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        Parameters
        ----------
        input: Tensor
            Predicted values
        target: Tensor
            Target values

        Returns
        -------
        torch.Tensor
            Loss

        Raises
        ------
        ValueError
            If :code:`input` and :code:`target` do not have the same shape. The
            exception is also an :class:`AssertionError`.

        Notes
        -----
        Binding affinity (pK) is positive for good poses and negative for bad poses (and
        zero if unknown). This allows to distinguish good poses from bad poses (to
        which a penalty is applied) without explicitly using the labels or the RMSD.
        """
        # Explicit check instead of ``assert``: without it (``python -O``),
        # mismatched shapes would be silently broadcast by torch.where below
        if input.size() != target.size():
            raise _AffinityLossInputError(
                "input and target must have the same shape, "
                f"got {tuple(input.size())} and {tuple(target.size())}"
            )

        # Normal euclidean distance for good poses (positive affinity label);
        # everything else starts at zero
        diff = torch.where(target > 0, input - target, torch.zeros_like(input))

        # Hinge-like distance for bad poses (negative affinity label): only
        # predictions above -target (i.e. above |target|) contribute
        diff = torch.where(
            torch.logical_and(target < 0, target > -input),
            input + target + self.penalty,
            diff,
        )

        if self.pseudo_huber:
            scaled_diff = diff / self.delta
            loss = self.delta2 * (torch.sqrt(1.0 + scaled_diff * scaled_diff) - 1.0)
        else:  # L2 loss
            loss = diff * diff

        if self.reduction == "mean":
            reduced_loss = torch.mean(loss)
        else:  # Check in __init__ ensures that reduction is "sum"
            reduced_loss = torch.sum(loss)

        return self.scale * reduced_loss