import argparse
import os
from collections import OrderedDict
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

import gninatorch
from gninatorch import dataloaders, models, setup, utils
def _rename(key: str) -> str:
"""
Rename GNINA layer to PyTorch layer.
Parameters
----------
key: str
GNINA layer name (in loaded state dict)
Returns
-------
str
PyTorch layer name
Raises
------
RuntimeError
if layer name is unknown
Notes
-----
The PyTorch CNN layers are named similarly to the original Caffe layers. However,
the layer name is prepended with "features.". PyTorch fully connected layers are
called differently.
The Default2017 model has slight different naming convention than Default2018 and
dense models.
"""
# Fix dense model names
if "dense_block" in key:
names = key.split(".")
return f"features.{names[0]}.blocks.{'.'.join(names[1:])}"
# Fix non-dense model names (and first data_enc layer)
elif "conv" in key or "data_enc" in key:
return f"features.{key}"
# Fix default2017 model names
elif "output_fc." in key:
return key.replace("output_fc", "pose.pose_output")
elif "output_fc_aff." in key:
return key.replace("output_fc_aff", "affinity.affinity_output")
# Fix default2018 and dense models
elif "pose_output" in key:
return f"pose.{key}"
elif "affinity_output" in key:
return f"affinity.{key}"
else: # This should never happen
raise RuntimeError(f"Unknown layer name: {key}")
def _load_weights(weights_file: str) -> OrderedDict:
    """
    Load weights from file.

    Parameters
    ----------
    weights_file: str
        Path to weights file

    Returns
    -------
    OrderedDict
        Dictionary of weights (renamed according to PyTorch layer names)

    Notes
    -----
    Tensors are always deserialised onto the CPU, so the file can be read on hosts
    without a GPU regardless of the device the weights were saved from.

    :code:`torch.load` relies on pickle: only load weights files from a trusted
    source.
    """
    # map_location="cpu" avoids failures (and needless GPU allocations) when the
    # checkpoint stores CUDA tensors but no GPU is available at load time
    weights = torch.load(weights_file, map_location="cpu")

    # Rename Caffe layers according to PyTorch names defined in gninatorch.models
    return OrderedDict((_rename(key), value) for key, value in weights.items())
def _load_gnina_model_file(
    weights_file: str, num_voxels: int
) -> Union[
    models.Default2017Affinity, models.Default2018Affinity, models.DenseAffinity
]:
    """
    Load GNINA model from file.

    Parameters
    ----------
    weights_file: str
        Path to weights file
    num_voxels: int
        Number of voxels per grid dimension

    Returns
    -------
    Union[models.Default2017Affinity, models.Default2018Affinity, models.DenseAffinity]
        Model with the weights loaded from :code:`weights_file`

    Raises
    ------
    ValueError
        if model name is unknown

    Notes
    -----
    The architecture is inferred from the file name (not from the directories
    leading to it), which has to contain "default2017", "default2018" or "dense".

    All GNINA default models perform both pose prediction and binding affinity
    prediction.
    """
    # Only look at the file name: a parent directory that happens to contain
    # "dense" or "default2017" must not influence the choice of architecture
    model_name = os.path.basename(weights_file)

    if "default2017" in model_name:
        # 35 input channels for the legacy default2017 typing (ligmap.old ligand)
        # NOTE(review): the ligand/receptor channel split is not visible here
        model_class, num_channels = models.Default2017Affinity, 35
    elif "default2018" in model_name:
        # 28 channels:
        #   14 for the ligand (completelig) and 14 for the protein (completerec)
        model_class, num_channels = models.Default2018Affinity, 28
    elif "dense" in model_name:
        # 28 channels:
        #   14 for the ligand (completelig) and 14 for the protein (completerec)
        model_class, num_channels = models.DenseAffinity, 28
    else:
        raise ValueError(f"Unknown model name: {weights_file}")

    model = model_class(input_dims=(num_channels, num_voxels, num_voxels, num_voxels))
    model.load_state_dict(_load_weights(weights_file))

    return model
def load_gnina_model(
    gnina_model: str, dimension: float = 23.5, resolution: float = 0.5
):
    """
    Load a single pre-trained GNINA model by name.

    Parameters
    ----------
    gnina_model: str
        GNINA model name; the weights are read from ``weights/<gnina_model>.pt``
        next to this module
    dimension: float
        Grid dimension (in Angstrom)
    resolution: float
        Grid resolution (in Angstrom)

    Returns
    -------
    Model built by :code:`_load_gnina_model_file` for the requested weights

    Notes
    -----
    The number of voxels per grid dimension is
    ``round(dimension / resolution) + 1`` (48 with the default values).
    """
    # From https://github.com/gnina/libmolgrid/include/libmolgrid/grid_maker.h
    voxels_per_side = round(dimension / resolution) + 1

    # Weights are stored in the "weights" folder alongside this module
    module_dir = os.path.dirname(os.path.abspath(__file__))
    weights_path = os.path.join(module_dir, "weights", f"{gnina_model}.pt")

    return _load_gnina_model_file(weights_path, voxels_per_side)
def load_gnina_models(
    model_names: Iterable[str], dimension: float = 23.5, resolution: float = 0.5
):
    """
    Load several GNINA models and combine them into an ensemble.

    Parameters
    ----------
    model_names: Iterable[str]
        Names of the GNINA models to load
    dimension: float
        Grid dimension (in Angstrom), forwarded to :code:`load_gnina_model`
    resolution: float
        Grid resolution (in Angstrom), forwarded to :code:`load_gnina_model`

    Returns
    -------
    models.GNINAModelEnsemble
        Ensemble wrapping the loaded models, in the order given
    """
    ensemble_members = [
        load_gnina_model(name, dimension=dimension, resolution=resolution)
        for name in model_names
    ]

    return models.GNINAModelEnsemble(ensemble_members)
def options(args: Optional[List[str]] = None):
    """
    Build the command line parser and parse the arguments.

    Parameters
    ----------
    args: Optional[List[str]]
        List of command line arguments (:code:`None` lets :code:`argparse` read
        them from :code:`sys.argv`)

    Returns
    -------
    argparse.Namespace
        Parsed arguments

    Notes
    -----
    The :code:`--no_cache` flag is stored as :code:`cache_structures`, which is
    :code:`True` unless the flag is given.
    """
    # For every model family the accepted names are the family itself, its
    # "_ensemble" alias and four numbered variants; "default" comes last
    cnn_choices: List[str] = []
    for family in (
        "crossdock_default2018",
        "general_default2018",
        "redock_default2018",
        "dense",
    ):
        cnn_choices.extend(f"{family}{tag}" for tag in ("", "_ensemble"))
        cnn_choices.extend(f"{family}_{i}" for i in range(1, 5))
    cnn_choices.append("default")

    parser = argparse.ArgumentParser(description=" GNINA scoring function")

    parser.add_argument("input", type=str, help="Input file for inference")

    # TODO: Default2017 model needs different ligand types
    parser.add_argument(
        "--cnn",
        type=str,
        default="default",
        choices=cnn_choices,
        help="Pre-trained CNN Model",
    )
    parser.add_argument(
        "-d",
        "--data_root",
        type=str,
        default="",
        help="Root folder for relative paths in train files",
    )
    parser.add_argument("-g", "--gpu", type=str, default="cuda:0", help="Device name")

    # Grid definition
    parser.add_argument("--dimension", type=float, default=23.5, help="Grid dimension")
    parser.add_argument("--resolution", type=float, default=0.5, help="Grid resolution")

    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")

    # Optional molecular caches
    parser.add_argument(
        "--ligmolcache",
        type=str,
        default="",
        help=".molcache2 file for ligands",
    )
    parser.add_argument(
        "--recmolcache",
        type=str,
        default="",
        help=".molcache2 file for receptors",
    )
    parser.add_argument(
        "--no_cache",
        dest="cache_structures",
        action="store_false",
        help="Disable structure caching",
    )

    parsed = parser.parse_args(args)
    return parsed
def setup_gnina_model(
    cnn: str = "default", dimension: float = 23.5, resolution: float = 0.5
) -> Tuple[nn.Module, bool]:
    """
    Load model or ensemble of models.

    Parameters
    ----------
    cnn: str
        CNN model name
    dimension: float
        Grid dimension
    resolution: float
        Grid resolution

    Returns
    -------
    Tuple[nn.Module, bool]
        Model (or ensemble of models) in evaluation mode, and a flag that is
        :code:`True` when the returned model is an ensemble

    Notes
    -----
    Mimics GNINA CLI. The model is returned in evaluation mode. This is essential to
    use the dense model correctly (due to the :code:`nn.BatchNorm` layers).

    Three cases are distinguished:

    * ``"default"`` loads a fixed ensemble of five models;
    * a name containing ``"ensemble"`` loads the base model (name without the
      ``_ensemble`` suffix) together with its variants ``_1`` to ``_4``;
    * any other name loads a single model.
    """
    if cnn == "default":
        # GNINA default model
        # See McNutt et al. J Cheminform (2021) 13:43 for details
        names = [
            "dense",
            "general_default2018_3",
            "dense_3",
            "crossdock_default2018",
            "redock_default2018_2",
        ]
        ensemble = True
        model = load_gnina_models(names, dimension, resolution)
    elif "ensemble" in cnn:
        name = cnn.replace("_ensemble", "")
        names = [name] + [f"{name}_{i}" for i in range(1, 5)]

        # Load model as an ensemble
        ensemble = True
        model = load_gnina_models(names, dimension, resolution)
    else:
        ensemble = False
        model = load_gnina_model(cnn, dimension, resolution)

    # Put model in evaluation mode
    # This is essential to have the BatchNorm layers in the correct state
    model.eval()

    return model, ensemble
def main(args):
    """
    Run inference with GNINA pre-trained models.

    Parameters
    ----------
    args: Namespace
        Parsed command line arguments

    Notes
    -----
    Models are used in evaluation mode, which is essential for the dense models since
    they use batch normalisation.

    The forward passes run under :code:`torch.no_grad()`: no gradients are needed
    for inference, so no autograd graph is recorded.

    For every example the CNN score and CNN affinity are printed, followed by the
    affinity variance when an ensemble of models is used.
    """
    model, ensemble = setup_gnina_model(args.cnn, args.dimension, args.resolution)
    model.eval()  # Ensure models are in evaluation mode!

    device = utils.set_device(args.gpu)
    model.to(device)

    example_provider = setup.setup_example_provider(args.input, args, training=False)
    grid_maker = setup.setup_grid_maker(args)

    # TODO: Allow average over different rotations
    loader = dataloaders.GriddedExamplesLoader(
        example_provider=example_provider,
        grid_maker=grid_maker,
        random_translation=0.0,  # No random translations for inference
        random_rotation=False,  # No random rotations for inference
        device=device,
        grids_only=True,
    )

    # Inference only: disable gradient tracking to avoid building autograd graphs
    with torch.no_grad():
        for batch in loader:
            # Ensembles additionally return the variance of the predicted affinity
            if not ensemble:
                log_pose, affinity = model(batch)
            else:
                log_pose, affinity, affinity_var = model(batch)

            # The model outputs log-scores; exponentiate the last column to get
            # the pose score
            pose = torch.exp(log_pose[:, -1])

            for i, (p, a) in enumerate(zip(pose, affinity)):
                print(f"CNNscore: {p:.5f}")
                print(f"CNNaffinity: {a:.5f}")
                if ensemble:
                    print(f"CNNvariance: {affinity_var[i]:.5f}")
                print("")
def _header():
    """
    Print GNINA header.

    Notes
    -----
    The logo and the introductory text are read from the ``logo`` and ``intro``
    files in the ``data`` folder next to this module, and are followed by the
    package version and git revision.
    """
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

    def _read(name):
        # Return the whole content of a text file from the data folder
        with open(os.path.join(data_dir, name), "r") as handle:
            return handle.read()

    logo = _read("logo")
    intro = _read("intro")

    print(logo, "\n\n", intro)
    print(f"Version: {gninatorch.__version__} ({gninatorch.__git_revision__})\n")
if __name__ == "__main__":
    # Script entry point: print the banner first, then parse the command line
    # and run inference with the parsed arguments
    _header()

    cli_args = options()
    main(cli_args)