# Source code for gninatorch.models

"""
GNINA Caffe models translated to PyTorch.

Notes
-----
The PyTorch models try to follow the original Caffe models as much as possible. However,
some changes are necessary.

Notable differences:
* The :code:`MolDataLayer` is now separated from the model and the parameters are
controlled by CLI arguments in the training process.
* The model output for pose prediction corresponds to the log softmax of the last fully-
connected layer instead of the softmax.
"""

import math
from collections import OrderedDict, namedtuple
from typing import List, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn


def weights_and_biases_init(m: nn.Module) -> None:
    """
    Initialize the weights and biases of the model.

    Parameters
    ----------
    m : nn.Module
        Module (layer) to initialize

    Notes
    -----
    This function is used to initialize the weights of the model for both
    convolutional and linear layers. Weights are initialized using uniform Xavier
    initialization while biases are set to zero. Layers created with
    :code:`bias=False` have no bias, so only their weights are initialized.

    Modules that are neither :code:`nn.Conv3d` nor :code:`nn.Linear` are left
    untouched, so the function can be passed to :code:`nn.Module.apply`.

    https://github.com/gnina/libmolgrid/blob/e6d5f36f1ae03f643ca69cdec1625ac52e653f88/test/test_torch_cnn.py#L45-L48
    """
    if isinstance(m, (nn.Conv3d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight.data)
        # Layers built with bias=False expose m.bias as None
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)
class Default2017(nn.Module):
    """
    GNINA default2017 model architecture.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)

    Notes
    -----
    This architecture was translated from the following Caffe model:

        https://github.com/gnina/models/blob/master/crossdocked_paper/default2017.model

    The main difference is that the PyTorch implementation returns the log softmax.

    This base class only defines the convolutional feature extractor; the forward
    pass is implemented by the task-specific subclasses.
    """

    def __init__(self, input_dims: Tuple):
        super().__init__()

        assert (
            len(input_dims) == 4
        ), "Input dimensions must be (channels, depth, height, width)"

        self.input_dims = input_dims

        self.features = nn.Sequential(
            OrderedDict(
                [
                    # unit1
                    ("unit1_pool", nn.MaxPool3d(kernel_size=2, stride=2)),
                    (
                        "unit1_conv1",
                        nn.Conv3d(
                            in_channels=input_dims[0],
                            out_channels=32,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit1_relu1", nn.ReLU()),
                    # unit2
                    ("unit2_pool", nn.MaxPool3d(kernel_size=2, stride=2)),
                    (
                        "unit2_conv1",
                        nn.Conv3d(
                            in_channels=32,
                            out_channels=64,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit2_relu1", nn.ReLU()),
                    # unit3
                    ("unit3_pool", nn.MaxPool3d(kernel_size=2, stride=2)),
                    (
                        "unit3_conv1",
                        nn.Conv3d(
                            in_channels=64,
                            out_channels=128,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit3_relu1", nn.ReLU()),
                ]
            )
        )

        # Three MaxPool3d layers (kernel_size=2, stride=2) reduce every spatial
        # dimension n to floor(n / 8); the padded 3x3x3 convolutions keep it.
        # The parentheses are required: // and * share precedence and evaluate
        # left to right, so the unparenthesized form is wrong whenever a
        # dimension is not a multiple of 8.
        self.features_out_size = (
            (input_dims[1] // 8) * (input_dims[2] // 8) * (input_dims[3] // 8) * 128
        )

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Raises
        ------
        NotImplementedError
            The forward pass is implemented in derived classes.
        """
        raise NotImplementedError
class Default2017Pose(Default2017):
    """
    GNINA default2017 model architecture for pose prediction.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)

    Notes
    -----
    This architecture was translated from the following Caffe model:

        https://github.com/gnina/models/blob/master/crossdocked_paper/default2017.model

    The main difference is that the PyTorch implementation returns the log softmax
    of the final linear layer instead of feeding it to a :code:`SoftmaxWithLoss`
    layer.
    """

    def __init__(self, input_dims: Tuple):
        super().__init__(input_dims)

        # Two-class pose head; wrapped in a named nn.Sequential so parameters
        # are registered under "pose.pose_output.*"
        pose_output = nn.Linear(in_features=self.features_out_size, out_features=2)
        self.pose = nn.Sequential(OrderedDict(pose_output=pose_output))

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Log probabilities for ligand pose

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        flat = self.features(x).view(-1, self.features_out_size)
        return F.log_softmax(self.pose(flat), dim=1)
class Default2017Affinity(Default2017Pose):
    """
    GNINA default2017 model architecture for pose and affinity prediction.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)

    Notes
    -----
    This architecture was translated from the following Caffe model:

        https://github.com/gnina/models/blob/master/crossdocked_paper/default2017.model

    The main difference is that the PyTorch implementation returns the log softmax
    of the final linear layer instead of feeding it to a :code:`SoftmaxWithLoss`
    layer.
    """

    def __init__(self, input_dims: Tuple):
        super().__init__(input_dims)

        # Single-output regression head for binding affinity
        affinity_output = nn.Linear(
            in_features=self.features_out_size, out_features=1
        )
        self.affinity = nn.Sequential(OrderedDict(affinity_output=affinity_output))

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Log probabilities for ligand pose and affinity prediction

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        flat = self.features(x).view(-1, self.features_out_size)

        pose_log = F.log_softmax(self.pose(flat), dim=1)

        # Drop the trailing singleton dimension so the prediction has shape
        # (batch_size,), matching the target tensor
        return pose_log, self.affinity(flat).squeeze(-1)
class Default2017Flex(Default2017):
    """
    GNINA default2017 model architecture for multi-task pose prediction (ligand and
    flexible residues).

    Poses are annotated based on both ligand RMSD and flexible residues RMSD (w.r.t.
    the cognate receptor in the case of cross-docking).

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)
    """

    def __init__(self, input_dims: Tuple):
        super().__init__(input_dims)

        # Ligand pose head is built first, then the flexible residues head
        lig_pose_output = nn.Linear(
            in_features=self.features_out_size, out_features=2
        )
        self.lig_pose = nn.Sequential(OrderedDict(lig_pose_output=lig_pose_output))

        flex_pose_output = nn.Linear(
            in_features=self.features_out_size, out_features=2
        )
        self.flex_pose = nn.Sequential(
            OrderedDict(flex_pose_output=flex_pose_output)
        )

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Log probabilities for ligand pose and flexible residues pose prediction

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        flat = self.features(x).view(-1, self.features_out_size)

        lig_pose_log, flex_pose_log = (
            F.log_softmax(head(flat), dim=1) for head in (self.lig_pose, self.flex_pose)
        )
        return lig_pose_log, flex_pose_log
class Default2018(nn.Module):
    """
    GNINA default2018 model architecture.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)

    Notes
    -----
    This architecture was translated from the following Caffe model:

        https://github.com/gnina/models/blob/master/crossdocked_paper/default2018.model

    The main difference is that the PyTorch implementation returns the log softmax.

    This base class only defines the convolutional feature extractor; the forward
    pass is implemented by the task-specific subclasses.
    """

    def __init__(self, input_dims: Tuple):
        super().__init__()

        assert (
            len(input_dims) == 4
        ), "Input dimensions must be (channels, depth, height, width)"

        self.input_dims = input_dims

        self.features = nn.Sequential(
            OrderedDict(
                [
                    # unit1
                    ("unit1_pool", nn.AvgPool3d(kernel_size=2, stride=2)),
                    (
                        "unit1_conv",
                        nn.Conv3d(
                            in_channels=input_dims[0],
                            out_channels=32,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit1_func", nn.ReLU()),
                    # unit2
                    (
                        "unit2_conv",
                        nn.Conv3d(
                            in_channels=32,
                            out_channels=32,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                        ),
                    ),
                    ("unit2_func", nn.ReLU()),
                    # unit3
                    ("unit3_pool", nn.AvgPool3d(kernel_size=2, stride=2)),
                    (
                        "unit3_conv",
                        nn.Conv3d(
                            in_channels=32,
                            out_channels=64,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit3_func", nn.ReLU()),
                    # unit4
                    (
                        "unit4_conv",
                        nn.Conv3d(
                            in_channels=64,
                            out_channels=64,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                        ),
                    ),
                    ("unit4_func", nn.ReLU()),
                    # unit5
                    ("unit5_pool", nn.AvgPool3d(kernel_size=2, stride=2)),
                    (
                        "unit5_conv",
                        nn.Conv3d(
                            in_channels=64,
                            out_channels=128,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit5_func", nn.ReLU()),
                ]
            )
        )

        # Three AvgPool3d layers (kernel_size=2, stride=2) reduce every spatial
        # dimension n to floor(n / 8); the padded convolutions keep it.
        # The parentheses are required: // and * share precedence and evaluate
        # left to right, so the unparenthesized form is wrong whenever a
        # dimension is not a multiple of 8.
        self.features_out_size = (
            (input_dims[1] // 8) * (input_dims[2] // 8) * (input_dims[3] // 8) * 128
        )

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Raises
        ------
        NotImplementedError
            The forward pass is implemented in derived classes.
        """
        raise NotImplementedError
class Default2018Pose(Default2018):
    """
    GNINA default2018 model architecture for pose prediction.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)

    Notes
    -----
    This architecture was translated from the following Caffe model:

        https://github.com/gnina/models/blob/master/crossdocked_paper/default2018.model

    The main difference is that the PyTorch implementation returns the log softmax.
    """

    def __init__(self, input_dims: Tuple):
        super().__init__(input_dims)

        # Two-class pose head, registered as "pose.pose_output.*"
        head = OrderedDict()
        head["pose_output"] = nn.Linear(
            in_features=self.features_out_size, out_features=2
        )
        self.pose = nn.Sequential(head)

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Log probabilities for ligand pose

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        features = self.features(x)
        logits = self.pose(features.view(-1, self.features_out_size))
        return F.log_softmax(logits, dim=1)
class Default2018Affinity(Default2018Pose):
    """
    GNINA default2018 model architecture for pose and affinity prediction.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)

    Notes
    -----
    This architecture was translated from the following Caffe model:

        https://github.com/gnina/models/blob/master/crossdocked_paper/default2018.model

    The main difference is that the PyTorch implementation returns the log softmax.
    """

    def __init__(self, input_dims: Tuple):
        super().__init__(input_dims)

        # Single-output regression head for binding affinity
        head = OrderedDict()
        head["affinity_output"] = nn.Linear(
            in_features=self.features_out_size, out_features=1
        )
        self.affinity = nn.Sequential(head)

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Log probabilities for ligand pose and affinity prediction

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        flat = self.features(x).view(-1, self.features_out_size)

        pose_log = F.log_softmax(self.pose(flat), dim=1)
        affinity = self.affinity(flat)

        # (batch_size, 1) -> (batch_size,) to match the shape of the target tensor
        return pose_log, affinity.squeeze(-1)
class Default2018Flex(Default2018):
    """
    GNINA default2018 model architecture for multi-task pose prediction (ligand and
    flexible residues).

    Poses are annotated based on both ligand RMSD and flexible residues RMSD (w.r.t.
    the cognate receptor in the case of cross-docking).

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)
    """

    def __init__(self, input_dims: Tuple):
        super().__init__(input_dims)

        # Ligand pose head (created first)
        lig_head = OrderedDict()
        lig_head["lig_pose_output"] = nn.Linear(
            in_features=self.features_out_size, out_features=2
        )
        self.lig_pose = nn.Sequential(lig_head)

        # Flexible residues pose head
        flex_head = OrderedDict()
        flex_head["flex_pose_output"] = nn.Linear(
            in_features=self.features_out_size, out_features=2
        )
        self.flex_pose = nn.Sequential(flex_head)

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Log probabilities for ligand pose and flexible residues pose prediction

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        flat = self.features(x).view(-1, self.features_out_size)
        lig_logits = self.lig_pose(flat)
        flex_logits = self.flex_pose(flat)
        return F.log_softmax(lig_logits, dim=1), F.log_softmax(flex_logits, dim=1)
class DenseBlock(nn.Module):
    """
    DenseBlock for Dense model.

    Parameters
    ----------
    in_features: int
        Input features for the first layer
    num_block_features: int
        Number of output features (channels) for the convolutional layers
    num_block_convs: int
        Number of convolutions
    tag: Union[int, str]
        Tag identifying the DenseBlock (used in the layer names)

    Notes
    -----
    Each convolution is preceded by batch normalization and followed by a ReLU.
    The input of every convolution is the concatenation (along channels) of the
    block input and of the outputs of all previous convolutions. The total number
    of output features is therefore
    :code:`in_features + num_block_convs * num_block_features`.
    """

    def __init__(
        self,
        in_features: int,
        num_block_features: int = 16,
        num_block_convs: int = 4,
        tag: Union[int, str] = "",
    ) -> None:
        super().__init__()

        self.in_features = in_features
        self.num_block_features = num_block_features
        self.num_block_convs = num_block_convs

        prefix = f"data_enc_level{tag}"
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for idx in range(num_block_convs):
            # Layer idx sees the block input plus all idx previous outputs
            channels = in_features + idx * num_block_features

            # affine=True has the same effect as the "Scale" layer in Caffe
            layers[f"{prefix}_batchnorm_conv{idx}"] = nn.BatchNorm3d(
                channels, affine=True
            )
            layers[f"{prefix}_conv{idx}"] = nn.Conv3d(
                in_channels=channels,
                out_channels=num_block_features,
                kernel_size=3,
                padding=1,
            )
            layers[f"{prefix}_conv{idx}_relu"] = nn.ReLU()

        self.blocks = nn.Sequential(layers)

    def out_features(self) -> int:
        """Return the number of output channels of the block."""
        return self.in_features + self.num_block_convs * self.num_block_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Concatenation (along channels) of the input and all convolution outputs
        """
        # TODO: Make more efficient by keeping concatenated outputs
        collected = [x]
        for layer in self.blocks:
            x = layer(x)
            if not isinstance(layer, nn.ReLU):
                continue
            # End of a BatchNorm -> Conv -> ReLU triplet: the next triplet takes
            # every output produced so far, concatenated on the channel axis
            collected.append(x)
            x = torch.cat(collected, dim=1)
        return x
class Dense(nn.Module):
    """
    GNINA Dense model architecture.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)
    num_blocks: int
        Number of dense blocks
    num_block_features: int
        Number of features in dense block convolutions
    num_block_convs: int
        Number of convolutions in dense block
    affinity: bool
        Accepted for interface compatibility; not used by this class

    Notes
    -----
    Original implementation by Andrew McNutt available here:

        https://github.com/gnina/models/blob/master/pytorch/dense_model.py

    The main difference is that the original implementation returns the raw output
    of the last linear layer while here the output is the log softmax of the last
    linear.
    """

    def __init__(
        self,
        input_dims: Tuple,
        num_blocks: int = 3,
        num_block_features: int = 16,
        num_block_convs: int = 4,
        affinity: bool = True,
    ) -> None:
        super().__init__()

        assert (
            len(input_dims) == 4
        ), "Input dimensions must be (channels, depth, height, width)"

        self.input_dims = input_dims

        # Layers are registered in forward-pass order; names define the
        # state_dict keys
        features: OrderedDict[str, nn.Module] = OrderedDict()
        features["data_enc_init_pool"] = nn.MaxPool3d(kernel_size=2, stride=2)
        features["data_enc_init_conv"] = nn.Conv3d(
            in_channels=input_dims[0],
            out_channels=32,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        features["data_enc_init_conv_relu"] = nn.ReLU()

        channels: int = 32
        for level in range(num_blocks - 1):
            block = DenseBlock(
                channels,
                num_block_features=num_block_features,
                num_block_convs=num_block_convs,
                tag=level,
            )
            features[f"dense_block_{level}"] = block
            channels = block.out_features()

            # 1x1 bottleneck convolution keeps the number of channels
            features[f"data_enc_level{level}_bottleneck"] = nn.Conv3d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=1,
                padding=0,
            )
            features[f"data_enc_level{level}_bottleneck_relu"] = nn.ReLU()
            features[f"data_enc_level{level + 1}_pool"] = nn.MaxPool3d(
                kernel_size=2, stride=2
            )

        last = num_blocks - 1
        final_block = DenseBlock(
            channels,
            num_block_features=num_block_features,
            num_block_convs=num_block_convs,
            tag=last,
        )
        features[f"dense_block_{last}"] = final_block

        # Final number of channels
        self.features_out_size = final_block.out_features()

        # Spatial dimensions before global pooling: the initial pool plus one
        # pool per intermediate level give num_blocks halvings in total
        scale = 2**num_blocks
        depth, height, width = (dim // scale for dim in input_dims[1:4])

        # Global MAX pooling reduces each channel to a single value
        features[f"data_enc_level{last}_global_pool"] = nn.MaxPool3d(
            kernel_size=(depth, height, width)
        )

        self.features = nn.Sequential(features)

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Raises
        ------
        NotImplementedError

        Notes
        -----
        The forward pass needs to be implemented in derived classes.
        """
        raise NotImplementedError
class DensePose(Dense):
    """
    GNINA Dense model architecture for pose prediction.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)
    num_blocks: int
        Number of dense blocks
    num_block_features: int
        Number of features in dense block convolutions
    num_block_convs: int
        Number of convolutions in dense block

    Notes
    -----
    Original implementation by Andrew McNutt available here:

        https://github.com/gnina/models/blob/master/pytorch/dense_model.py

    The main difference is that the original implementation returns the raw output
    of the last linear layer while here the output is the log softmax of the last
    linear.
    """

    def __init__(
        self,
        input_dims: Tuple,
        num_blocks: int = 3,
        num_block_features: int = 16,
        num_block_convs: int = 4,
    ) -> None:
        super().__init__(input_dims, num_blocks, num_block_features, num_block_convs)

        # Two-class pose head, registered as "pose.pose_output.*"
        pose_output = nn.Linear(in_features=self.features_out_size, out_features=2)
        self.pose = nn.Sequential(OrderedDict(pose_output=pose_output))

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Log probabilities for ligand pose

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        # Global max pooling left one value per channel: flatten to channels
        pooled = self.features(x).view(-1, self.features_out_size)
        return F.log_softmax(self.pose(pooled), dim=1)
class DenseAffinity(DensePose):
    """
    GNINA Dense model architecture for binding affinity prediction.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)
    num_blocks: int
        Number of dense blocks
    num_block_features: int
        Number of features in dense block convolutions
    num_block_convs: int
        Number of convolutions in dense block

    Notes
    -----
    Original implementation by Andrew McNutt available here:

        https://github.com/gnina/models/blob/master/pytorch/dense_model.py

    The main difference is that the original implementation returns the raw output
    of the last linear layer while here the output is the log softmax of the last
    linear.
    """

    def __init__(
        self,
        input_dims: Tuple,
        num_blocks: int = 3,
        num_block_features: int = 16,
        num_block_convs: int = 4,
    ) -> None:
        super().__init__(input_dims, num_blocks, num_block_features, num_block_convs)

        # Single-output regression head for binding affinity
        affinity_output = nn.Linear(
            in_features=self.features_out_size, out_features=1
        )
        self.affinity = nn.Sequential(OrderedDict(affinity_output=affinity_output))

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Log probabilities for ligand pose and affinity prediction

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        # Global max pooling left one value per channel: flatten to channels
        pooled = self.features(x).view(-1, self.features_out_size)

        pose_log = F.log_softmax(self.pose(pooled), dim=1)

        # (batch_size, 1) -> (batch_size,) to match the shape of the target tensor
        affinity = self.affinity(pooled).squeeze(-1)
        return pose_log, affinity
class DenseFlex(Dense):
    """
    GNINA dense model architecture for multi-task pose prediction (ligand and
    flexible residues).

    Poses are annotated based on both ligand RMSD and flexible residues RMSD (w.r.t.
    the cognate receptor in the case of cross-docking).

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)
    num_blocks: int
        Number of dense blocks
    num_block_features: int
        Number of features in dense block convolutions
    num_block_convs: int
        Number of convolutions in dense block

    Notes
    -----
    Original implementation by Andrew McNutt available here:

        https://github.com/gnina/models/blob/master/pytorch/dense_model.py

    The main difference is that the original implementation returns the raw output
    of the last linear layer while here the output is the log softmax of the last
    linear.
    """

    def __init__(
        self,
        input_dims: Tuple,
        num_blocks: int = 3,
        num_block_features: int = 16,
        num_block_convs: int = 4,
    ) -> None:
        super().__init__(input_dims, num_blocks, num_block_features, num_block_convs)

        # Ligand pose head (created first), then flexible residues pose head
        lig_pose_output = nn.Linear(
            in_features=self.features_out_size, out_features=2
        )
        self.lig_pose = nn.Sequential(OrderedDict(lig_pose_output=lig_pose_output))

        flex_pose_output = nn.Linear(
            in_features=self.features_out_size, out_features=2
        )
        self.flex_pose = nn.Sequential(
            OrderedDict(flex_pose_output=flex_pose_output)
        )

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Log probabilities for ligand pose and flexible residues pose prediction

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        # Global max pooling left one value per channel: flatten to channels
        pooled = self.features(x).view(-1, self.features_out_size)

        lig_pose_log, flex_pose_log = (
            F.log_softmax(head(pooled), dim=1)
            for head in (self.lig_pose, self.flex_pose)
        )
        return lig_pose_log, flex_pose_log
class HiResPose(nn.Module):
    """
    GNINA HiResPose model architecture.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)

    Notes
    -----
    This architecture was translated from the following Caffe model:

        https://github.com/gnina/models/blob/master/crossdocked_paper/hires_pose.model

    The main difference is that the PyTorch implementation returns the log softmax.

    This model is implemented only for multi-task pose and affinity prediction.
    """

    def __init__(self, input_dims: Tuple):
        super().__init__()

        # Same validation as the other architectures in this module
        assert (
            len(input_dims) == 4
        ), "Input dimensions must be (channels, depth, height, width)"

        self.input_dims = input_dims

        self.features = nn.Sequential(
            OrderedDict(
                [
                    # unit1
                    (
                        "unit1_conv",
                        nn.Conv3d(
                            in_channels=input_dims[0],
                            out_channels=32,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit1_func", nn.ReLU()),
                    # unit2
                    ("unit2_pool", nn.MaxPool3d(kernel_size=2, stride=2)),
                    (
                        "unit2_conv",
                        nn.Conv3d(
                            in_channels=32,
                            out_channels=64,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit2_func", nn.ReLU()),
                    # unit3
                    ("unit3_pool", nn.MaxPool3d(kernel_size=2, stride=2)),
                    (
                        "unit3_conv",
                        nn.Conv3d(
                            in_channels=64,
                            out_channels=128,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit3_func", nn.ReLU()),
                ]
            )
        )

        # Two MaxPool3d layers (kernel_size=2, stride=2) reduce every spatial
        # dimension n to floor(n / 4). The parentheses are required: // and *
        # share precedence and evaluate left to right.
        self.features_out_size = (
            (input_dims[1] // 4) * (input_dims[2] // 4) * (input_dims[3] // 4) * 128
        )

        # Linear layer for pose prediction
        self.pose = nn.Sequential(
            OrderedDict(
                [
                    (
                        "pose_output",
                        nn.Linear(in_features=self.features_out_size, out_features=2),
                    )
                ]
            )
        )

        # Linear layer for binding affinity prediction
        self.affinity = nn.Sequential(
            OrderedDict(
                [
                    (
                        "affinity_output",
                        nn.Linear(in_features=self.features_out_size, out_features=1),
                    )
                ]
            )
        )

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Log probabilities for ligand pose and affinity prediction

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        x = self.features(x)

        # Flatten channels and remaining spatial dimensions (no global pooling
        # in this architecture)
        x = x.view(-1, self.features_out_size)

        pose_raw = self.pose(x)
        pose_log = F.log_softmax(pose_raw, dim=1)

        affinity = self.affinity(x)

        # Squeeze last (dummy) dimension of affinity prediction
        # This allows to match the shape (batch_size,) of the target tensor
        return pose_log, affinity.squeeze(-1)
class HiResAffinity(nn.Module):
    """
    GNINA HiResAffinity model architecture.

    Parameters
    ----------
    input_dims: tuple
        Model input dimensions (channels, depth, height, width)

    Notes
    -----
    This architecture was translated from the following Caffe model:

        https://github.com/gnina/models/blob/master/crossdocked_paper/hires_affinity.model

    The main difference is that the PyTorch implementation returns the log softmax.

    This model is implemented only for multi-task pose and affinity prediction.
    """

    def __init__(self, input_dims: Tuple):
        super().__init__()

        # Same validation as the other architectures in this module
        assert (
            len(input_dims) == 4
        ), "Input dimensions must be (channels, depth, height, width)"

        self.input_dims = input_dims

        self.features = nn.Sequential(
            OrderedDict(
                [
                    # unit1
                    (
                        "unit1_conv",
                        nn.Conv3d(
                            in_channels=input_dims[0],
                            out_channels=32,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit1_func", nn.ReLU()),
                    # unit2
                    (
                        "unit2_conv",
                        nn.Conv3d(
                            in_channels=32,
                            out_channels=64,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                        ),
                    ),
                    ("unit2_func", nn.ReLU()),
                    # unit3
                    ("unit3_pool", nn.AvgPool3d(kernel_size=8, stride=8)),
                    (
                        "unit3_conv",
                        nn.Conv3d(
                            in_channels=64,
                            out_channels=128,
                            kernel_size=5,
                            stride=1,
                            padding=2,
                        ),
                    ),
                    ("unit3_func", nn.ELU(alpha=1.0)),
                    # unit5 (following original naming convention)
                    ("unit5_pool", nn.MaxPool3d(kernel_size=4, stride=4)),
                ]
            )
        )

        # AvgPool3d(8, 8) followed by MaxPool3d(4, 4) reduces every spatial
        # dimension n to floor(floor(n / 8) / 4) == floor(n / 32); the padded
        # convolutions keep the size. The parentheses are required: // and *
        # share precedence and evaluate left to right.
        self.features_out_size = (
            (input_dims[1] // (8 * 4))
            * (input_dims[2] // (8 * 4))
            * (input_dims[3] // (8 * 4))
            * 128
        )

        # Linear layer for pose prediction
        self.pose = nn.Sequential(
            OrderedDict(
                [
                    (
                        "pose_output",
                        nn.Linear(in_features=self.features_out_size, out_features=2),
                    )
                ]
            )
        )

        # Linear layer for binding affinity prediction
        self.affinity = nn.Sequential(
            OrderedDict(
                [
                    (
                        "affinity_output",
                        nn.Linear(in_features=self.features_out_size, out_features=1),
                    )
                ]
            )
        )

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Log probabilities for ligand pose and affinity prediction

        Notes
        -----
        The pose score is the log softmax of the output of the last linear layer.
        """
        x = self.features(x)

        # Flatten channels and remaining spatial dimensions (a single voxel per
        # channel only when every input spatial dimension is below 64)
        x = x.view(-1, self.features_out_size)

        pose_raw = self.pose(x)
        pose_log = F.log_softmax(pose_raw, dim=1)

        affinity = self.affinity(x)

        # Squeeze last (dummy) dimension of affinity prediction
        # This allows to match the shape (batch_size,) of the target tensor
        return pose_log, affinity.squeeze(-1)
# Lookup key for a model variant:
# (architecture name, affinity prediction, flexible residues pose prediction)
Model = namedtuple("Model", ["model", "affinity", "flex"])

# Map every supported model variant to the class implementing it
models_dict = {
    Model(model="default2017", affinity=False, flex=False): Default2017Pose,
    Model(model="default2017", affinity=True, flex=False): Default2017Affinity,
    Model(model="default2017", affinity=False, flex=True): Default2017Flex,
    Model(model="default2018", affinity=False, flex=False): Default2018Pose,
    Model(model="default2018", affinity=True, flex=False): Default2018Affinity,
    Model(model="default2018", affinity=False, flex=True): Default2018Flex,
    Model(model="dense", affinity=False, flex=False): DensePose,
    Model(model="dense", affinity=True, flex=False): DenseAffinity,
    Model(model="dense", affinity=False, flex=True): DenseFlex,
    Model(model="hires_pose", affinity=True, flex=False): HiResPose,
    Model(model="hires_affinity", affinity=True, flex=False): HiResAffinity,
}
class GNINAModelEnsemble(nn.Module):
    """
    Ensemble of GNINA models.

    Parameters
    ----------
    models: List[nn.Module]
        List of models to use in the ensemble

    Notes
    -----
    Assume models perform only pose AND affinity prediction.

    Modules are stored in :code:`nn.ModuleList` so that they are properly registered.
    """

    def __init__(self, models: List[nn.Module]):
        super().__init__()

        # Check that all models allow both pose and affinity predictions
        # These are the only models supported by GNINA so far
        for m in models:
            assert isinstance(m, (Default2017Affinity, Default2018Affinity, DenseAffinity))

        # nn.ModuleList registers the different modules, so that things like
        # .to(device) apply to all of them
        self.models = nn.ModuleList(models)

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Logarithm of the pose score, affinity prediction (average) and affinity
            variance

        Notes
        -----
        For pose prediction, the average has to be performed on the scores, not
        their logarithm (returned by the model). To be consistent with everywhere
        else (where the logarithm of the prediction is returned), the logarithm of
        the averaged scores is returned. It is computed as
        :code:`logsumexp(log_scores) - log(n_models)`, which equals
        :code:`log(mean(exp(log_scores)))` but does not underflow to :code:`-inf`
        when the scores are very small.
        """
        predictions = [model(x) for model in self.models]

        # Transpose the list of multi-task predictions into one sequence per task:
        # [(log_pose_1, affinity_1), (log_pose_2, affinity_2), ...] =>
        # ((log_pose_1, log_pose_2, ...), (affinity_1, affinity_2, ...))
        log_pose_all, affinity_all = zip(*predictions)

        log_pose_stacked = torch.stack(log_pose_all)
        affinity_stacked = torch.stack(affinity_all)

        # Numerically stable log of the mean probability over the ensemble
        log_pose_avg = torch.logsumexp(log_pose_stacked, dim=0) - math.log(
            len(log_pose_all)
        )

        affinity_avg = affinity_stacked.mean(dim=0)
        # Population variance (divide by the number of models)
        affinity_var = affinity_stacked.var(dim=0, unbiased=False)

        return log_pose_avg, affinity_avg, affinity_var