Source code for tensorial.gcnn.losses

import abc
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Final, Literal, Optional

import beartype
import equinox
import jax
import jax.numpy as jnp
import jaxtyping as jt
from jaxtyping import Array, Bool, Float, Int
import jraph
import optax.losses
from pytray import tree
import reax

from . import _tree, graph_ops, keys, typing, utils
from .. import base

if TYPE_CHECKING:
    from tensorial import gcnn

__all__ = "PureLossFn", "GraphLoss", "WeightedLoss", "Loss"

# A pure loss function that knows nothing about graphs: it takes two arrays (this module calls it
# as ``loss_fn(predictions, targets)``) and returns an array of loss values.  Any reduction over
# nodes, edges or graphs is done by the graph-aware wrappers, not by the pure function.
PureLossFn = Callable[[jax.Array, jax.Array], jax.Array]


class GraphLoss(equinox.Module):
    """Base class for loss functions that are evaluated on ``jraph.GraphsTuple`` objects.

    Subclasses implement :meth:`_call`.  Users call the instance itself, which fills in the
    default for ``targets`` before delegating to :meth:`_call`.
    """

    _label: str

    def __init__(self, label: str):
        """Store the label used to identify this loss function.

        Args:
            label: the value returned by :meth:`label`.
        """
        self._label = label

    def label(self) -> str:
        """Get a label for this loss function"""
        return self._label

    def __call__(
        self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple | None = None
    ) -> jax.Array:
        """Return the scalar loss between predictions and targets

        Args:
            predictions: the graph(s) holding the predicted values.
            targets: the graph(s) holding the target values.  If ``None``, ``predictions`` is
                used for both, i.e. predictions and targets are read from the same graph.

        Returns:
            Whatever :meth:`_call` returns for the resolved ``(predictions, targets)`` pair.
        """
        if targets is None:
            targets = predictions
        return self._call(predictions, targets)

    @abc.abstractmethod
    def _call(self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple) -> jax.Array:
        """Return the scalar loss between predictions and targets"""
class Loss(GraphLoss):
    """Simple loss function that passes values from the graph to a function taking numerical values
    such as optax losses
    """

    _loss_fn: PureLossFn
    _target_field: "gcnn.typing.TreePath"
    _prediction_field: "gcnn.typing.TreePath"
    _mask_field: "Optional[gcnn.typing.TreePath]"
    _reduction: Literal["sum", "mean"]

    @jt.jaxtyped(typechecker=beartype.beartype)
    def __init__(
        self,
        loss_fn: str | PureLossFn,
        targets: str,
        predictions: str | None = None,
        *,
        reduction: Literal["sum", "mean"] = "mean",
        label: str | None = None,
        mask_field: str | None = None,
    ):
        """
        Initializes the loss function wrapper with specified parameters.

        Args:
            loss_fn: The loss function to use, either the name of an attribute of
                ``optax.losses`` or a callable taking ``(predictions, targets)`` arrays.
            targets: The path to the target field in the graph.
            predictions: The path to the prediction field in the graph. If None (or empty), the
                target field path will be used.
            reduction: The reduction method to apply to the computed losses. Either "sum" or
                "mean".
            label: The label to use for the loss function. If None (or empty), the prediction
                field path converted back to a string will be used as the label.
            mask_field: The path to a user-defined mask field in the targets graph. If None, no
                user-defined mask will be applied.

        Raises:
            ValueError: If ``loss_fn`` is neither a string nor a callable.
            AttributeError: If ``loss_fn`` is a string that does not name an attribute of
                ``optax.losses``.

        Note:
            Argument types are also checked by the ``jaxtyped``/``beartype`` decorator before the
            body runs, so ill-typed arguments (e.g. an unsupported ``reduction``) are rejected
            there.
        """
        self._loss_fn = _get_pure_loss_fn(loss_fn)
        self._target_field: Final[typing.TreePath] = utils.path_from_str(targets)
        self._prediction_field: Final[typing.TreePath] = utils.path_from_str(predictions or targets)
        self._mask_field = utils.path_from_str(mask_field) if mask_field is not None else None
        self._reduction = reduction
        super().__init__(label or utils.path_to_str(self._prediction_field))

    def _call(self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple) -> jax.Array:
        """Evaluate the pure loss on the configured fields and reduce it to a single value.

        The prediction values are read from ``predictions`` and the target values (and any masks)
        from ``targets``.  For node/edge fields the loss is first reduced per graph, then across
        graphs, and finally averaged with ``jnp.mean``.
        """
        targets_dict = targets._asdict()
        pred_values = base.as_array(tree.get_by_path(predictions._asdict(), self._prediction_field))
        target_values = base.as_array(tree.get_by_path(targets_dict, self._target_field))

        loss = self._loss_fn(pred_values, target_values)

        # If there is a mask in the graph, then use it by default
        mask = _tree.get_mask(targets, self._target_field)
        if mask is not None:
            mask = reax.metrics.utils.prepare_mask(loss, mask)

        # Now, check for the presence of a user-defined mask
        if self._mask_field:
            user_mask = base.as_array(tree.get_by_path(targets_dict, self._mask_field))
            user_mask = reax.metrics.utils.prepare_mask(loss, user_mask)
            # Both masks must agree for an entry to contribute
            mask = user_mask if mask is None else mask & user_mask

        # Optional per-graph mask stored in the globals of the targets
        graph_mask: Bool[jax.Array, "n_graph ..."] | None = targets.globals.get(keys.MASK)

        root: str = self._target_field[0]
        if root in ("nodes", "edges"):
            # Per-node/per-edge losses: reduce within each graph first
            segments: Int[Array, "n_graph"] = targets.n_node if root == "nodes" else targets.n_edge
            loss: Float[Array, "n_graph ..."] = graph_ops.segment_reduce(
                loss, segments, reduction=self._reduction, mask=mask, segment_mask=graph_mask
            )

        # Reduce across graphs by treating all of them as one segment.
        # NOTE(review): nesting reconstructed from a flattened listing; this step is assumed to
        # run for every root (not only nodes/edges) - confirm against the repository.
        loss = graph_ops.segment_reduce(
            loss, jnp.array([loss.shape[0]]), reduction=self._reduction, mask=graph_mask
        )

        return jnp.mean(loss)
class WeightedLoss(GraphLoss):
    """A weighted combination of multiple graph loss functions.

    The combined loss is the dot product of the weights with the values returned by the
    individual loss functions.
    """

    _weights: tuple[float, ...]
    _loss_fns: tuple[GraphLoss, ...]

    @jt.jaxtyped(typechecker=beartype.beartype)
    def __init__(
        self,
        loss_fns: Sequence[GraphLoss],
        weights: Sequence[float] | None = None,
    ):
        """
        Combine several graph loss functions with the given weights.

        Args:
            loss_fns: Sequence of graph loss functions to combine.
            weights: Sequence of weights for each loss function. If None, all weights are set
                to 1.0.

        Raises:
            ValueError: If any element in loss_fns is not an instance of GraphLoss, or if the
                number of weights does not match the number of loss functions.
        """
        super().__init__("weighted loss")
        for loss in loss_fns:
            if not isinstance(loss, GraphLoss):
                raise ValueError(
                    f"loss_fns must all be subclasses of GraphLoss, got {type(loss).__name__}"
                )

        if weights is None:
            weights = (1.0,) * len(loss_fns)
        elif len(weights) != len(loss_fns):
            raise ValueError(
                f"the number of weights and loss functions must be equal, got {len(weights)} "
                f"and {len(loss_fns)}"
            )

        # We have to use a tuple here, otherwise jax will treat this as a dynamic type
        self._weights = tuple(weights)
        self._loss_fns = tuple(loss_fns)

    @property
    def weights(self):
        """The weights as an array, wrapped in ``stop_gradient``."""
        return jax.lax.stop_gradient(jnp.array(self._weights))

    def _evaluate_all(self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple) -> jax.Array:
        """Evaluate every loss function and stack the results in ``loss_fns`` order."""
        return jnp.array([loss_fn(predictions, targets) for loss_fn in self._loss_fns])

    def _call(self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple) -> jax.Array:
        """Return the weighted sum of the individual losses."""
        return jnp.dot(self.weights, self._evaluate_all(predictions, targets))

    def loss_with_contributions(
        self, predictions: jraph.GraphsTuple, target: jraph.GraphsTuple
    ) -> tuple[jax.Array, dict[str, jax.Array]]:
        """Return the weighted loss together with the unweighted value of each loss function.

        Args:
            predictions: the graph(s) holding the predicted values.
            target: the graph(s) holding the target values.

        Returns:
            A tuple ``(total, contribs)`` where ``total`` is the weighted sum and ``contribs``
            maps each loss function's label to its unweighted value.  Loss functions sharing a
            label overwrite one another in the dictionary (the last one wins).
        """
        # Calculate the loss for each function
        losses = self._evaluate_all(predictions, target)
        # Group the contributions into a dictionary keyed by the label
        contribs = dict(zip((loss_fn.label() for loss_fn in self._loss_fns), losses))
        return jnp.dot(self.weights, losses), contribs
def _get_pure_loss_fn(loss_fn: str | PureLossFn) -> PureLossFn: if isinstance(loss_fn, str): return getattr(optax.losses, loss_fn) if isinstance(loss_fn, Callable): return loss_fn raise ValueError(f"Unknown loss function type: {type(loss_fn).__name__}")