Source code for gninatorch.training

"""
PyTorch implementation of GNINA scoring function's Caffe training script.
"""

import argparse
import os
import sys
from collections import defaultdict
from typing import List, Optional

import ignite
import molgrid
import numpy as np
import pandas as pd
import torch
from ignite.contrib.handlers import ProgressBar
from ignite.contrib.handlers.mlflow_logger import MLflowLogger, global_step_from_engine
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, timing
from torch import nn, optim

from gninatorch import metrics, setup, utils
from gninatorch.dataloaders import GriddedExamplesLoader
from gninatorch.losses import AffinityLoss, ScaledNLLLoss
from gninatorch.models import models_dict, weights_and_biases_init


[docs]def options(args: Optional[List[str]] = None): """ Define options and parse arguments. Parameters ---------- args: Optional[List[str]] List of command line arguments """ parser = argparse.ArgumentParser( description="GNINA scoring function", ) # Data # TODO: Allow multiple train files? parser.add_argument("trainfile", type=str, help="Training file") parser.add_argument("--testfile", type=str, default=None, help="Test file") parser.add_argument( "-d", "--data_root", type=str, default="", help="Root folder for relative paths in train files", ) parser.add_argument( "--balanced", action="store_true", help="Balanced sampling of receptors" ) parser.add_argument( "--no_shuffle", action="store_false", help="Deactivate random shuffling of samples", dest="shuffle", # Variable name (shuffle is False when --no_shuffle is used) ) parser.add_argument( "--label_pos", type=int, default=0, help="Pose label position in training file" ) parser.add_argument( "--affinity_pos", type=int, default=None, help="Affinity value position in training file", ) parser.add_argument( "--flexlabel_pos", type=int, default=None, help="Flexible residues pose label position in training file", ) parser.add_argument( "--stratify_receptor", action="store_true", help="Sample uniformly across receptors", ) parser.add_argument( "--stratify_pos", type=int, default=1, help="Sample uniformly across bins", ) parser.add_argument( "--stratify_max", type=float, default=0, help="Maximum range for value stratification", ) parser.add_argument( "--stratify_min", type=float, default=0, help="Minimum range for value stratification", ) parser.add_argument( "--stratify_step", type=float, default=0, help="Step size for value stratification", ) parser.add_argument( "--ligmolcache", type=str, default="", help=".molcache2 file for ligands", ) parser.add_argument( "--recmolcache", type=str, default="", help=".molcache2 file for receptors", ) parser.add_argument( "-o", "--out_dir", type=str, default=os.getcwd(), help="Output directory" ) parser.add_argument( "--log_file", type=str, default="training.log", help="Log file name" ) # Scoring function parser.add_argument( "-m", "--model", type=str, default="default2017", help="Model name", choices=set([k[0] for k in models_dict.keys()]), # Model names ) parser.add_argument("--dimension", type=float, default=23.5, help="Grid dimension") parser.add_argument("--resolution", type=float, default=0.5, help="Grid resolution") # TODO: ligand type file and receptor type file (default: 28 types) # Learning parser.add_argument( "--base_lr", type=float, default=0.01, help="Base (initial) learning rate" ) parser.add_argument("--momentum", type=float, default=0.9, help="Momentum") parser.add_argument( "--weight_decay", type=float, help="Weight decay", default=0.001 ) parser.add_argument("--batch_size", type=int, default=64, help="Batch size") parser.add_argument( "--no_random_rotation", action="store_false", help="Deactivate random rotation of samples", dest="random_rotation", ) parser.add_argument( "--random_translation", type=float, default=6.0, help="Random translation" ) parser.add_argument( "-i", "--iterations", type=int, default=250000, help="Number of iterations (epochs)", ) parser.add_argument( "--iteration_scheme", type=str, default="small", help="molgrid iteration scheme", choices=setup._iteration_schemes.keys(), ) # lr_dynamic, originally called --dynamic parser.add_argument( "--lr_dynamic", action="store_true", help="Adjust learning rate in response to training", ) # lr_patience, originally called --step_when # Acts on epochs, not on iterations parser.add_argument( "--lr_patience", type=int, default=5, help="Number of epochs without improvement before learning rate update", ) # lr_reduce, originally called --step_reduce parser.add_argument( "--lr_reduce", type=float, default=0.1, help="Learning rate reduction factor" ) # lr_min default value set to match --step_end_cnt default value (3 reductions) parser.add_argument("--lr_min", type=float, default=0.01 * 0.1**3) parser.add_argument( "--clip_gradients", type=float, default=10.0, help="Gradients threshold (for clipping)", ) parser.add_argument( "--pseudo_huber_affinity_loss", action="store_true", help="Use pseudo-Huber loss for affinity loss", ) parser.add_argument( "--delta_affinity_loss", type=float, default=4.0, help="Delta factor for affinity loss", ) parser.add_argument( "--scale_affinity_loss", type=float, default=1.0, help="Scale factor for affinity loss", ) parser.add_argument( "--penalty_affinity_loss", type=float, default=1.0, help="Penalty for affinity loss", ) parser.add_argument( "--scale_pose_loss", type=float, default=1.0, help="Scale factor for pose loss", ) parser.add_argument( "--scale_flexpose_loss", type=float, default=1.0, help="Scale factor for flexible residues pose loss", ) # Misc parser.add_argument( "-t", "--test_every", type=int, default=1000, help="Test interval" ) parser.add_argument( "--checkpoint_every", type=int, default=100, help="Number of epochs per checkpoint", ) parser.add_argument( "--num_checkpoints", type=int, default=1, help="Number of checkpoints to keep" ) parser.add_argument( "--checkpoint_prefix", type=str, default="", help="Checkpoint file prefix" ) parser.add_argument( "--checkpoint_dir", type=str, default="", help="Checkpoint directory (appended to output directory)", ) parser.add_argument("--progress_bar", action="store_true", help="Show progress bar") parser.add_argument("-g", "--gpu", type=str, default="cuda:0", help="Device name") # ROC AUC fails when there is only one class (i.e. all poses are good poses) # This happens when training with crystal structures only parser.add_argument( "--no_roc_auc", action="store_false", help="Disable ROC AUC (useful for crystal poses)", dest="roc_auc", ) parser.add_argument( "--no_cache", action="store_false", help="Disable structure caching", dest="cache_structures", ) parser.add_argument("-s", "--seed", type=int, default=None, help="Random seed") parser.add_argument("--silent", action="store_true", help="No console output") return parser.parse_args(args)
def _train_step_pose( trainer: Engine, batch, model: nn.Module, optimizer, pose_loss: nn.Module, clip_gradients: float, ) -> float: """ Training step for pose prediction. Parameters ---------- trainer: Engine PyTorch Ignite engine for training batch: Batch of data model: PyTorch model optimizer: PyTorch optimizer pose_loss: Loss function for pose prediction clip_gradients: Gradient clipping threshold Returns ------- float Loss Notes ----- Gradients are clipped by norm and not by value. """ model.train() optimizer.zero_grad() # Data is already on the correct device thanks to the ExampleProvider grids, labels = batch pose_log = model(grids) # Compute loss for pose prediction loss = pose_loss(pose_log, labels) loss.backward() # TODO: Double check that gradient clipping by norm corresponds to the Caffe # implementation nn.utils.clip_grad_norm_(model.parameters(), clip_gradients) optimizer.step() return loss.item() def _train_step_pose_and_affinity( trainer: Engine, batch, model: nn.Module, optimizer, pose_loss: nn.Module, affinity_loss: nn.Module, clip_gradients: float, ) -> float: """ Training step for pose and affinity prediction. Parameters ---------- trainer: Engine PyTorch Ignite engine for training batch: Batch of data model: PyTorch model optimizer: PyTorch optimizer pose_loss: Loss function for pose prediction affinity_loss: Loss function for binding affinity prediction clip_gradients: Gradient clipping threshold Returns ------- float Loss Notes ----- Gradients are clipped by norm and not by value. """ model.train() optimizer.zero_grad() # Data is already on the correct device thanks to the ExampleProvider grids, labels, affinities = batch pose_log, affinities_pred = model(grids) # Compute combined loss for pose prediction and affinity prediction loss = pose_loss(pose_log, labels) + affinity_loss(affinities_pred, affinities) loss.backward() # TODO: Double check that gradient clipping by norm corresponds to the Caffe # implementation nn.utils.clip_grad_norm_(model.parameters(), clip_gradients) optimizer.step() return loss.item() def _train_step_flex( trainer: Engine, batch, model: nn.Module, optimizer, pose_loss: nn.Module, flexpose_loss: nn.Module, clip_gradients: float, ) -> float: """ Training step for pose prediction. Parameters ---------- trainer: Engine PyTorch Ignite engine for training batch: Batch of data model: PyTorch model optimizer: PyTorch optimizer pose_loss: Loss function for pose prediction flexpose_loss: Loss function for flexible residues pose prediction clip_gradients: Gradient clipping threshold Returns ------- float Loss Notes ----- Gradients are clipped by norm and not by value. """ model.train() optimizer.zero_grad() # Data is already on the correct device thanks to the ExampleProvider grids, labels, flexlabels = batch pose_log, flexpose_log = model(grids) # Compute loss for pose prediction loss = pose_loss(pose_log, labels) + flexpose_loss(flexpose_log, flexlabels) loss.backward() # TODO: Double check that gradient clipping by norm corresponds to the Caffe # implementation nn.utils.clip_grad_norm_(model.parameters(), clip_gradients) optimizer.step() return loss.item() def _setup_trainer( model, optimizer, pose_loss, affinity_loss, flexpose_loss, clip_gradients: float ) -> Engine: """ Setup training engine for binding pose prediction or binding pose and affinity prediction. Patameters ---------- model: Model to train optimizer: Optimizer pose_loss: Loss function for pose prediction affinity_loss: Loss function for affinity prediction flexpose_loss: Loss function for flexible residues pose prediction clip_gradients: Gradient clipping threshold Notes ----- The arguments :code:`affinity_loss` and :code:`flexpose_loss` determine the type of training to be performed. If :code:`affinity_loss is not None`, multi-task learning on both the ligand pose and the binding affinity is performed using the training function :fun:`_train_step_pose_and_affinity`. If :code:`flexpose_loss is not None`, multi-task learning on both the ligand pose and the pose of the flexible residues is performed using the training function :fun:`_train_step_flex`. """ # Affinity prediction is currently incompatible with flexible residues pose # prediction assert affinity_loss is None or flexpose_loss is None if affinity_loss is not None: # Pose prediction and binding affinity prediction # Create engine based on custom train step trainer = Engine( lambda trainer, batch: _train_step_pose_and_affinity( trainer, batch, model, optimizer, pose_loss=pose_loss, affinity_loss=affinity_loss, clip_gradients=clip_gradients, ) ) elif flexpose_loss is not None: # Ligand and flexible residues pose prediction # Create engine based on custom train step trainer = Engine( lambda trainer, batch: _train_step_flex( trainer, batch, model, optimizer, pose_loss=pose_loss, flexpose_loss=flexpose_loss, clip_gradients=clip_gradients, ) ) else: # Pose prediction and binding affinity prediction # Create engine based on custom train step trainer = Engine( lambda trainer, batch: _train_step_pose( trainer, batch, model, optimizer, pose_loss=pose_loss, clip_gradients=clip_gradients, ) ) return trainer def _evaluation_step_pose_and_affinity(evaluator: Engine, batch, model): """ Evaluate model for binding pose and affinity prediction. Parameters ---------- evaluator: PyTorch Ignite :code:`Engine` batch: Batch data model: Model Returns ------- Tuple[torch.Tensor] Class probabilities for pose prediction, affinity prediction, true pose labels and experimental binding affinities Notes ----- The model returns the log softmax of the last linear layer for binding pose prediction (log class probabilities) and the raw output of the last linear layer for binding affinity predictions. """ model.eval() with torch.no_grad(): grids, labels, affinities = batch pose_log, affinities_pred = model(grids) output = { "pose_log": pose_log, "affinities_pred": affinities_pred, "labels": labels, "affinities": affinities, } return output def _evaluation_step_pose(evaluator: Engine, batch, model): """ Evaluate model for binding pose prediction only. Parameters ---------- evaluator: PyTorch Ignite :code:`Engine` batch: Batch data model: Model Returns ------- Tuple[torch.Tensor] Class probabilities for pose prediction and true pose labels Notes ----- While not strictly necessary (the default PyTorch Ignite evaluator would work well in the case of pose-prediction only), this function is used to return a dictionary of the output with the same key used in :fun:`_evaluation_step_pose_and_affinity`. This allows to simplify the code of the learning rate scheduler function. This function also allows consistency in allowing the use of :fun:`transforms.output_transform_select_pose` for both pose prediction only and binding pose prediction with binding affinity prediction. """ model.eval() with torch.no_grad(): grids, labels = batch pose_log = model(grids) output = { "pose_log": pose_log, "labels": labels, } return output def _evaluation_step_flex(evaluator: Engine, batch, model): """ Evaluate model for ligand and flexible residues pose prediction. Parameters ---------- evaluator: PyTorch Ignite :code:`Engine` batch: Batch data model: Model Returns ------- Tuple[torch.Tensor] Log class probabilities for pose prediction, log class probabilities for flexible residues pose prediction, true pose labels, and true flexible residues pose labels Notes ----- The model returns the log softmax of the last linear layer for binding pose prediction (log class probabilities). """ model.eval() with torch.no_grad(): grids, labels, flexlabels = batch pose_log, flexpose_log = model(grids) output = { "pose_log": pose_log, "flexpose_log": flexpose_log, "labels": labels, "flexlabels": flexlabels, } return output def _setup_evaluator( model, metrics, affinity: bool = False, flex: bool = False ) -> Engine: """ Setup PyTorch Ignite :code:`Engine` for evaluation. Parameters ---------- model: PyTorch model metrics: Evaluation metrics affinity: bool Flag for affinity prediction (in addition to ligand pose prediction) flex: bool Flag for flexible residues pose prediction (in addition to ligand pose prediction) Returns ------- ignite.Engine PyTorch Ignite engine for evaluation """ assert not (affinity and flex) if affinity: evaluator = Engine( lambda evaluator, batch: _evaluation_step_pose_and_affinity( evaluator, batch, model ) ) elif flex: evaluator = Engine( lambda evaluator, batch: _evaluation_step_flex(evaluator, batch, model) ) else: evaluator = Engine( lambda evaluator, batch: _evaluation_step_pose(evaluator, batch, model) ) # Add metrics to the evaluator engine # Metrics need an output_tranform method in order to select the correct output # from _evaluation_step_pose_and_affinity for name, metric in metrics.items(): metric.attach(evaluator, name) return evaluator
[docs]def training(args): """ Main function for training GNINA scoring function. Parameters ---------- args: Command line arguments Notes ----- Training might start off slow because the :code:`molgrid.ExampleProvider` is caching the structures that are read from .gninatypes files. The training then speeds up considerably. """ # Affinity prediction not supported with flexible residues (and vice versa) assert args.affinity_pos is None or args.flexlabel_pos is None # Create necessary directories if not already present os.makedirs(args.out_dir, exist_ok=True) # Define output streams for logging logfilename = os.path.join(args.out_dir, args.log_file) logfile = open(logfilename, "w") if not args.silent: outstreams = [sys.stdout, logfile] else: outstreams = [logfile] mlflogger = MLflowLogger() # Log parameters from argument parser # Add additional parameters params = vars(args) params.update( { "pytorch": torch.__version__, "ignite": ignite.__version__, "cuda": torch.version.cuda if torch.cuda.is_available() else "None", } ) mlflogger.log_params(params) # Print command line arguments for outstream in outstreams: utils.print_args(args, "--- GNINA TRAINING ---", stream=outstream) # Set random seed for reproducibility if args.seed is not None: molgrid.set_random_seed(args.seed) torch.manual_seed(args.seed) np.random.seed(args.seed) # Set device device = utils.set_device(args.gpu) # Create example providers train_example_provider = setup.setup_example_provider( args.trainfile, args, training=True ) if args.testfile is not None: test_example_provider = setup.setup_example_provider( args.testfile, args, training=False ) # Create grid maker grid_maker = setup.setup_grid_maker(args) train_loader = GriddedExamplesLoader( example_provider=train_example_provider, grid_maker=grid_maker, label_pos=args.label_pos, affinity_pos=args.affinity_pos, flexlabel_pos=args.flexlabel_pos, random_translation=args.random_translation, random_rotation=args.random_rotation, device=device, ) if args.testfile is not None: test_loader = GriddedExamplesLoader( example_provider=test_example_provider, grid_maker=grid_maker, label_pos=args.label_pos, affinity_pos=args.affinity_pos, flexlabel_pos=args.flexlabel_pos, random_translation=args.random_translation, random_rotation=args.random_rotation, device=device, ) assert test_loader.dims == train_loader.dims affinity: bool = args.affinity_pos is not None flex: bool = args.flexlabel_pos is not None # Create model # Select model based on architecture and affinity flag (pose vs affinity) model = models_dict[(args.model, affinity, flex)](train_loader.dims).to(device) model.apply(weights_and_biases_init) # Compile model into TorchScript model = torch.jit.script(model) optimizer = optim.SGD( model.parameters(), lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay, ) # Define loss functions pose_loss = torch.jit.script(ScaledNLLLoss(scale=args.scale_pose_loss)) affinity_loss = ( torch.jit.script( AffinityLoss( delta=args.delta_affinity_loss, penalty=args.penalty_affinity_loss, pseudo_huber=args.pseudo_huber_affinity_loss, scale=args.scale_affinity_loss, ) ) if affinity else None ) flexpose_loss = ( torch.jit.script(ScaledNLLLoss(scale=args.scale_flexpose_loss)) if flex else None ) trainer = _setup_trainer( model, optimizer, pose_loss=pose_loss, affinity_loss=affinity_loss, flexpose_loss=flexpose_loss, clip_gradients=args.clip_gradients, ) mlflogger.attach_opt_params_handler( trainer, event_name=Events.ITERATION_STARTED, optimizer=optimizer, param_name="lr", # optional ) allmetrics = metrics.setup_metrics( affinity, flex, pose_loss, affinity_loss, flexpose_loss, args.roc_auc, device ) # Storage for metrics # This is for manual logging of metrics # Metrics are outputted to CSV files in the output folder and to the MLflow logger # TODO: Remove redundancy? CSV files are quite useful... metrics_train = defaultdict(list) metrics_test = defaultdict(list) train_evaluator = _setup_evaluator(model, allmetrics, affinity=affinity, flex=flex) test_evaluator = _setup_evaluator(model, allmetrics, affinity=affinity, flex=flex) mlflogger.attach_output_handler( train_evaluator, event_name=Events.EPOCH_COMPLETED, tag="Train", metric_names=list(allmetrics.keys()), global_step_transform=global_step_from_engine(trainer), # Get training epoch ) mlflogger.attach_output_handler( test_evaluator, event_name=Events.EPOCH_COMPLETED, tag="Test", metric_names=list(allmetrics.keys()), global_step_transform=global_step_from_engine(trainer), # Get training epoch ) # Define LR scheduler if args.lr_dynamic: torch_scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="max", factor=args.lr_reduce, patience=args.lr_patience, min_lr=args.lr_min, verbose=False, ) # Elapsed time timer, training time only elapsed_time = timing.Timer() elapsed_time.attach( trainer, start=Events.STARTED, resume=Events.EPOCH_STARTED, pause=Events.EPOCH_COMPLETED, step=Events.EPOCH_COMPLETED, ) @trainer.on(Events.EPOCH_COMPLETED(every=args.test_every)) def log_training_results(trainer): """ Evaluate metrics on the training set and update the LR according to the loss function, if needed. """ train_evaluator.run(train_loader) for outstream in outstreams: utils.log_print( train_evaluator.state.metrics, title="Train Results", epoch=trainer.state.epoch, epoch_time=trainer.state.times["EPOCH_COMPLETED"], elapsed_time=elapsed_time.total, stream=outstream, ) mts = train_evaluator.state.metrics metrics_train["Epoch"].append(trainer.state.epoch) for key, value in mts.items(): metrics_train[key].append(value) # Update LR based on the loss on the training set if args.lr_dynamic: loss = mts["Pose Loss"] if affinity: loss += mts["Affinity Loss"] if flex: loss += mts["Flex Pose Loss"] torch_scheduler.step(loss) assert len(optimizer.param_groups) == 1 for oustream in outstreams: print( f" Learning rate: {optimizer.param_groups[0]['lr']}", file=oustream, ) if args.testfile is not None: @trainer.on(Events.EPOCH_COMPLETED(every=args.test_every)) def log_test_results(trainer): test_evaluator.run(test_loader) for outstream in outstreams: utils.log_print( test_evaluator.state.metrics, title="Test Results", epoch=trainer.state.epoch, stream=outstream, ) metrics_test["Epoch"].append(trainer.state.epoch) for key, value in test_evaluator.state.metrics.items(): metrics_test[key].append(value) # TODO: Add checkpoints as artifacts to MLflow # TODO: Save input parameters as well # TODO: Save best models (lowest validation loss) to_save = {"model": model, "optimizer": optimizer} # Requires no checkpoint in the output directory # Since checkpoints are not automatically removed when restarting, it would be # dangerous to run without requiring the directory to have no previous checkpoints checkpoint = Checkpoint( to_save, os.path.join(args.out_dir, args.checkpoint_dir), filename_prefix=args.checkpoint_prefix, n_saved=args.num_checkpoints, global_step_transform=lambda *_: trainer.state.epoch, ) trainer.add_event_handler( Events.EPOCH_COMPLETED(every=args.checkpoint_every), checkpoint ) if args.progress_bar: pbar = ProgressBar() pbar.attach(trainer) trainer.run(train_loader, max_epochs=args.iterations) # Use log file name as prefix of output names log_root = os.path.splitext(args.log_file)[0] metrics_train_outfile = os.path.join(args.out_dir, f"{log_root}_metrics_train.csv") pd.DataFrame(metrics_train).to_csv( metrics_train_outfile, float_format="%.5f", index=False, ) mlflogger.log_artifact(metrics_train_outfile) if args.testfile is not None: metrics_test_outfile = os.path.join( args.out_dir, f"{log_root}_metrics_test.csv" ) pd.DataFrame(metrics_test).to_csv( metrics_test_outfile, float_format="%.5f", index=False, ) mlflogger.log_artifact(metrics_test_outfile) # Close log file and save as artifact logfile.close() mlflogger.log_artifact(logfilename)
if __name__ == "__main__": args = options() training(args)