Source code for gninatorch.dataloaders

from typing import Optional

import torch


[docs]class GriddedExamplesLoader: """ Load example and compute atomic density on a grid. Parameters ---------- example_provider : :class:`molgrid.ExampleProvider` :package:`molgrid` example provider grid_maker : :class:`molgrid.GridMaker` :package:`molgrid` grid maker label_pos: int Ligand pose annotation label position affinity_pos: Optional[int] Affinity annotation position flexlabel_pos: Optional[int] Receptor (side chains) pose annotation label position random_translation : float Random translation applied to each example on each cartesian axis random_rotation : bool Uniform random rotation applied to each example device : torch.device Device grid_only: bool If True, return only the grid, otherwise return grid and labels Notes ----- The batch size is defined in the :code:`example_provider` as :code:`default_batch_size`. The number of batches actually depend on the :code:`molgrid.IterationScheme` used, also defined in the :code:`example_provider`. If :code:`molgrid.IterationScheme.SmallEpoch` is used, examples are seen at most once. If :code:`molgrid.IterationScheme.LargeEpoch` is used, examples are seen at least once. The last batch is not padded with examples of the next epoch, in contrast with :code:`molgrid.ExampleProvider` default behaviour. """ def __init__( self, example_provider, grid_maker, label_pos: int = 0, affinity_pos: Optional[int] = None, flexlabel_pos: Optional[int] = None, random_translation: float = 0.0, random_rotation: bool = False, device: torch.device = torch.device("cpu"), grids_only: bool = False, ): # Check that example provider is populated assert example_provider.size() > 0 self.example_provider = example_provider self.grid_maker = grid_maker self.label_pos = label_pos self.affinity_pos = affinity_pos self.flexlabel_pos = flexlabel_pos self.random_translation = random_translation self.random_rotation = random_rotation self.device = device self.grids_only = grids_only # Total number of examples in file # This is not necessarily the same as the number of examples seen in an epoch # The number of examples in an epoch depends on the example provider settings self.num_examples_tot = self.example_provider.size() self.num_labels = self.example_provider.num_labels() self.num_types = self.example_provider.num_types() self.dims = grid_maker.grid_dimensions(self.num_types) example_provider_settings = example_provider.settings() self.iteration_scheme = str(example_provider_settings.iteration_scheme) if self.iteration_scheme == "SmallEpoch": self.num_examples_per_epoch = example_provider.small_epoch_size() elif self.iteration_scheme == "LargeEpoch": self.num_examples_per_epoch = example_provider.large_epoch_size() else: raise ValueError(f"Unknown iteration scheme {self.iteration_scheme}") self.batch_size = example_provider_settings.default_batch_size if example_provider_settings.balanced and self.batch_size == 1: raise ValueError("Balanced batches incompatible with batch size 1.") self.num_batches = (self.num_examples_per_epoch) // self.batch_size self.last_batch_size = self.num_examples_per_epoch % self.batch_size if self.last_batch_size != 0: self.num_batches += 1 self.batch_idx = 0 self.last_epoch = False def __next__(self): """ Get next batch of gridded examples and corresponding labels. """ if self.last_epoch: raise StopIteration if self.batch_idx < self.num_batches - 1: # All batches except the last batch = self.example_provider.next_batch(self.batch_size) self.batch_idx += 1 else: # Treat last batch differently if self.last_batch_size != 0: # Last batch has a different size # This avoids padding with examples from the next epoch batch = self.example_provider.next_batch(self.last_batch_size) else: # Last batch has the same size as previous batches batch = self.example_provider.next_batch(self.batch_size) # Last epoch, raise StopIteration at the next iteration attempt # Reset index self.last_epoch = True batch_size = len(batch) # Compute grids from examples grids = torch.zeros((batch_size, *self.dims), device=self.device) self.grid_maker.forward( batch, grids, random_translation=self.random_translation, random_rotation=self.random_rotation, ) if not self.grids_only: # Ligand pose labels # Convert labels to integers; libmolgrid only supports float labels labels = torch.zeros((batch_size,), device=self.device) batch.extract_label(self.label_pos, labels) labels = labels.long() # Convert labels to integer # Affinity values if self.affinity_pos is not None: affinities = torch.zeros((batch_size,), device=self.device) batch.extract_label(self.affinity_pos, affinities) # Flexible side chains pose labels # Convert labels to integers; libmolgrid only supports float labels if self.flexlabel_pos is not None: flexlabels = torch.zeros((batch_size,), device=self.device) batch.extract_label(self.flexlabel_pos, flexlabels) flexlabels = flexlabels.long() # Return appropriate tensors depending on labels extracted if self.grids_only: return grids elif self.affinity_pos is None and self.flexlabel_pos is None: # Return grids and labels return grids, labels elif self.affinity_pos is not None and self.flexlabel_pos is None: # Return grids, labels and affinities return grids, labels, affinities elif self.affinity_pos is None and self.flexlabel_pos is not None: # Return grids, labels and flexlabels return grids, labels, flexlabels elif self.affinity_pos is not None and self.flexlabel_pos is not None: # Return grids, labels, affinities and flexlabels return grids, labels, affinities, flexlabels else: # This should never occur... raise NotImplementedError def __iter__(self): """ Return iterator for a new epoch. Notes ----- At every epoch the iterator is reset. See https://pytorch.org/ignite/concepts.html#engine for details about the internals of :code:`ignite.ending.engine.Engine` and the associated abstraction loop. """ # Reset iterator self.last_epoch = False self.batch_idx = 0 # TODO: Check that we actually need to reset the iterator self.example_provider.reset() return self