Source code for tensorial.gcnn.losses

import abc
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Final, Literal, Optional

import beartype
import equinox
import jax
import jax.numpy as jnp
import jaxtyping as jt
import jraph
import optax.losses
from pytray import tree
import reax

from . import _tree, graph_ops, keys, typing, utils
from .. import base

if TYPE_CHECKING:
    from tensorial import gcnn

__all__ = "PureLossFn", "GraphLoss", "WeightedLoss", "Loss"

# A pure loss function that doesn't know about graphs, just takes arrays and produces a loss array
PureLossFn = Callable[[jax.Array, jax.Array], jax.Array]


[docs] class GraphLoss(equinox.Module): _label: str def __init__(self, label: str): self._label = label
[docs] def label(self) -> str: """Get a label for this loss function""" return self._label
def __call__( self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple = None ) -> jax.Array: """Return the scalar loss between predictions and targets""" if targets is None: targets = predictions return self._call(predictions, targets) @abc.abstractmethod def _call(self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple) -> jax.Array: """Return the scalar loss between predictions and targets"""
[docs] class Loss(GraphLoss): """Simple loss function that passes values from the graph to a function taking numerical values such as optax losses """ _loss_fn: PureLossFn _target_field: "gcnn.typing.TreePath" _prediction_field: "gcnn.typing.TreePath" _mask_field: "Optional[gcnn.typing.TreePath]" _reduction: Literal["sum", "mean"] @jt.jaxtyped(typechecker=beartype.beartype) def __init__( self, loss_fn: str | PureLossFn, targets: str, predictions: str | None = None, *, reduction: Literal["sum", "mean"] = "mean", label: str = None, mask_field: str | None = None, ): """ Initializes the loss function wrapper with specified parameters. This constructor sets up a loss function wrapper that can be used to compute loss values based on target and prediction fields. It supports various loss functions and provides options for reduction and masking. Args: loss_fn: The loss function to use, either as a string identifier or a callable implementing the PureLossFn protocol. targets: The path to the target field in the data structure. predictions: The path to the prediction field in the data structure. If None, the target field path will be used. reduction: The reduction method to apply to the computed losses. Either "sum" or "mean". label: The label to use for the loss function. If None, the prediction field path will be used as the label. mask_field: The path to the mask field in the data structure. If None, no masking will be applied. Raises: ValueError: If the reduction method is not "sum" or "mean". TypeError: If the loss function is not a valid string or callable. """ self._loss_fn = _get_pure_loss_fn(loss_fn) self._target_field: Final[typing.TreePath] = utils.path_from_str(targets) self._prediction_field: Final[typing.TreePath] = utils.path_from_str(predictions or targets) if mask_field is not None: self._mask_field = utils.path_from_str(mask_field) else: self._mask_field = None self._reduction = reduction super().__init__(label or utils.path_to_str(self._prediction_field)) def _call(self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple) -> jax.Array: predictions_dict = predictions._asdict() pred_values = base.as_array(tree.get_by_path(predictions_dict, self._prediction_field)) target_values = base.as_array(tree.get_by_path(targets._asdict(), self._target_field)) loss = self._loss_fn(pred_values, target_values) # If there is a mask in the graph, then use it by default mask = _tree.get_mask(targets, self._target_field) if mask is not None: mask = reax.metrics.utils.prepare_mask(loss, mask) # Now, check for the presence of a user-defined mask if self._mask_field: user_mask = base.as_array(tree.get_by_path(targets._asdict(), self._mask_field)) user_mask = reax.metrics.utils.prepare_mask(loss, user_mask) if mask is None: mask = user_mask else: mask = mask & user_mask graph_mask: jt.Bool[jax.Array, "n_graph ..."] | None = targets.globals.get(keys.MASK) root: str = self._target_field[0] if root in ("nodes", "edges"): segments: jt.Int[jax.Array, "n_graph"] = ( targets.n_node if root == "nodes" else targets.n_edge ) loss: jt.Float[jax.Array, "n_graph ..."] = graph_ops.segment_reduce( loss, segments, reduction=self._reduction, mask=mask, segment_mask=graph_mask ) loss = graph_ops.segment_reduce( loss, jnp.array([loss.shape[0]]), reduction=self._reduction, mask=graph_mask ) loss = jnp.mean(loss) return loss
[docs] class WeightedLoss(GraphLoss): _weights: tuple[float, ...] _loss_fns: tuple[GraphLoss, ...] @jt.jaxtyped(typechecker=beartype.beartype) def __init__( self, loss_fns: Sequence[GraphLoss], weights: Sequence[float] | None = None, ): """ A weighted combination of multiple graph loss functions. This class combines multiple graph loss functions with specified weights to create a composite loss function. The weights determine the contribution of each individual loss function to the final combined loss. Args: loss_fns: Sequence of graph loss functions to combine. weights: Sequence of weights for each loss function. If None, all weights are set to 1.0. Raises: ValueError: If any element in loss_fns is not a subclass of GraphLoss, or if the number of weights does not match the number of loss functions. """ super().__init__("weighted loss") for loss in loss_fns: if not isinstance(loss, GraphLoss): raise ValueError( f"loss_fns must all be subclasses of GraphLoss, got {type(loss).__name__}" ) if weights is None: weights = (1.0,) * len(loss_fns) else: if len(weights) != len(loss_fns): raise ValueError( f"the number of weights and loss functions must be equal, got {len(weights)} " f"and {len(loss_fns)}" ) self._weights = tuple( weights ) # We have to use a tuple here, otherwise jax will treat this as a dynamic type self._loss_fns = tuple(loss_fns) @property def weights(self): return jax.lax.stop_gradient(jnp.array(self._weights)) def _call(self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple) -> jax.Array: # Calculate the loss for each function losses = jnp.array(list(map(lambda loss_fn: loss_fn(predictions, targets), self._loss_fns))) return jnp.dot(self.weights, losses)
[docs] def loss_with_contributions( self, predictions: jraph.GraphsTuple, target: jraph.GraphsTuple ) -> tuple[jax.Array, dict[str, float]]: # Calculate the loss for each function losses = jax.array(list(map(lambda loss_fn: loss_fn(predictions, target), self._loss_fns))) # Group the contributions into a dictionary keyed by the label contribs = dict(zip(list(map(GraphLoss.label, self._loss_fns)), losses)) return jnp.dot(self.weights, losses), contribs
def _get_pure_loss_fn(loss_fn: str | PureLossFn) -> PureLossFn: if isinstance(loss_fn, str): return getattr(optax.losses, loss_fn) if isinstance(loss_fn, Callable): return loss_fn raise ValueError(f"Unknown loss function type: {type(loss_fn).__name__}")