import abc
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Final, Literal, Optional
import beartype
import equinox
import jax
import jax.numpy as jnp
import jaxtyping as jt
from jaxtyping import Array, Bool, Float, Int
import jraph
import optax.losses
from pytray import tree
import reax
from . import _tree, graph_ops, keys, typing, utils
from .. import base
if TYPE_CHECKING:
from tensorial import gcnn
__all__ = (
    "PureLossFn",
    "GraphLoss",
    "WeightedLoss",
    "Loss",
)

# Signature of a graph-agnostic loss: it is called as ``fn(predictions, targets)`` on plain arrays
# and returns a loss array; it knows nothing about graph structure.
PureLossFn = Callable[[jax.Array, jax.Array], jax.Array]
class GraphLoss(equinox.Module):
    """Base class for losses computed between two ``jraph.GraphsTuple`` instances.

    Subclasses implement :meth:`_call`; users call the instance directly.
    """

    # Human-readable name of this loss
    _label: str

    def __init__(self, label: str):
        self._label = label

    def label(self) -> str:
        """Get a label for this loss function"""
        return self._label

    def __call__(
        self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple | None = None
    ) -> jax.Array:
        """Return the scalar loss between predictions and targets.

        Args:
            predictions: Graph holding the predicted values.
            targets: Graph holding the target values. If ``None``, ``predictions`` is used for
                both arguments, i.e. predictions and targets are read from the same graph.

        Returns:
            The loss computed by :meth:`_call`.
        """
        if targets is None:
            targets = predictions
        return self._call(predictions, targets)

    @abc.abstractmethod
    def _call(self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple) -> jax.Array:
        """Return the scalar loss between predictions and targets"""
class Loss(GraphLoss):
    """Simple loss function that passes values from the graph to a function taking numerical values
    such as optax losses
    """

    _loss_fn: PureLossFn
    _target_field: "gcnn.typing.TreePath"
    _prediction_field: "gcnn.typing.TreePath"
    _mask_field: "Optional[gcnn.typing.TreePath]"
    _reduction: Literal["sum", "mean"]

    @jt.jaxtyped(typechecker=beartype.beartype)
    def __init__(
        self,
        loss_fn: str | PureLossFn,
        targets: str,
        predictions: str | None = None,
        *,
        reduction: Literal["sum", "mean"] = "mean",
        label: str | None = None,
        mask_field: str | None = None,
    ):
        """Initializes the loss function wrapper with specified parameters.

        Args:
            loss_fn: The loss function to use, either the name of a function in
                ``optax.losses`` or a callable implementing the PureLossFn protocol,
                called as ``loss_fn(predictions, targets)`` on arrays.
            targets: The path (parsed by ``utils.path_from_str``) to the target field in
                the graph. If its first component is ``nodes`` or ``edges``, the loss is
                first reduced per graph before being reduced across graphs.
            predictions: The path to the prediction field. If None (or empty), the target
                field path is used.
            reduction: How losses are combined, ``"sum"`` or ``"mean"``. It is used both
                within each graph and across graphs.
            label: The label of this loss. If None (or empty), the prediction field path
                is used.
            mask_field: Optional path, looked up in the targets, to a mask. It is combined
                (with ``&``) with any mask the graph already carries for the target field.
                If None, only the graph's own masks are applied.

        Raises:
            AttributeError: If ``loss_fn`` is a string that names nothing in ``optax.losses``.
            ValueError: If ``loss_fn`` is neither a string nor a callable and runtime type
                checking did not already reject it.

        Argument types (including ``reduction``) are checked at call time by
        jaxtyping/beartype.
        """
        self._loss_fn = _get_pure_loss_fn(loss_fn)
        self._target_field: Final[typing.TreePath] = utils.path_from_str(targets)
        self._prediction_field: Final[typing.TreePath] = utils.path_from_str(predictions or targets)
        self._mask_field = None if mask_field is None else utils.path_from_str(mask_field)
        self._reduction = reduction
        super().__init__(label or utils.path_to_str(self._prediction_field))

    def _call(self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple) -> jax.Array:
        """Compute the masked, reduced loss between ``predictions`` and ``targets``.

        The pure loss function is applied to the selected prediction and target arrays.
        For node or edge fields, the result is first reduced per graph (segments given by
        ``n_node``/``n_edge``) and then across graphs. For any other field, it is reduced
        across graphs directly. Any dimensions remaining after that are averaged.
        """
        pred_values = base.as_array(
            tree.get_by_path(predictions._asdict(), self._prediction_field)
        )
        target_values = base.as_array(tree.get_by_path(targets._asdict(), self._target_field))
        loss = self._loss_fn(pred_values, target_values)

        # If there is a mask in the graph, then use it by default
        mask = _tree.get_mask(targets, self._target_field)
        if mask is not None:
            mask = reax.metrics.utils.prepare_mask(loss, mask)

        # Now, check for the presence of a user-defined mask
        if self._mask_field is not None:
            user_mask = base.as_array(tree.get_by_path(targets._asdict(), self._mask_field))
            user_mask = reax.metrics.utils.prepare_mask(loss, user_mask)
            mask = user_mask if mask is None else mask & user_mask

        # Graph-level mask stored in the globals (if any); guard against graphs without globals
        graph_mask: Bool[jax.Array, "n_graph ..."] | None = (
            None if targets.globals is None else targets.globals.get(keys.MASK)
        )

        root: str = self._target_field[0]
        if root in ("nodes", "edges"):
            # Reduce node/edge losses to one value per graph; the element mask is consumed here
            segments: Int[Array, "n_graph"] = targets.n_node if root == "nodes" else targets.n_edge
            loss: Float[Array, "n_graph ..."] = graph_ops.segment_reduce(
                loss, segments, reduction=self._reduction, mask=mask, segment_mask=graph_mask
            )
            final_mask = graph_mask
        elif mask is None:
            final_mask = graph_mask
        elif graph_mask is None:
            final_mask = mask
        else:
            # The loss is already per graph, so the element mask (including any user mask) must
            # be applied in the final reduction. Previously it was silently dropped here.
            final_mask = mask & reax.metrics.utils.prepare_mask(loss, graph_mask)

        # Reduce across all graphs as a single segment
        loss = graph_ops.segment_reduce(
            loss, jnp.array([loss.shape[0]]), reduction=self._reduction, mask=final_mask
        )
        return jnp.mean(loss)
class WeightedLoss(GraphLoss):
    """A weighted sum of several :class:`GraphLoss` instances."""

    # We have to use tuples here, otherwise jax will treat these as a dynamic type
    _weights: tuple[float, ...]
    _loss_fns: tuple[GraphLoss, ...]

    @jt.jaxtyped(typechecker=beartype.beartype)
    def __init__(
        self,
        loss_fns: Sequence[GraphLoss],
        weights: Sequence[float] | None = None,
    ):
        """A weighted combination of multiple graph loss functions.

        The combined loss is ``sum(w_i * loss_i)`` over the wrapped loss functions.

        Args:
            loss_fns: Sequence of graph loss functions to combine.
            weights: Sequence of weights, one per loss function. If None, all weights are
                set to 1.0.

        Raises:
            ValueError: If any element in loss_fns is not a GraphLoss instance, or if the
                number of weights does not match the number of loss functions.
        """
        super().__init__("weighted loss")
        for loss in loss_fns:
            if not isinstance(loss, GraphLoss):
                raise ValueError(
                    f"loss_fns must all be subclasses of GraphLoss, got {type(loss).__name__}"
                )

        if weights is None:
            weights = (1.0,) * len(loss_fns)
        elif len(weights) != len(loss_fns):
            raise ValueError(
                f"the number of weights and loss functions must be equal, got {len(weights)} "
                f"and {len(loss_fns)}"
            )

        self._weights = tuple(weights)
        self._loss_fns = tuple(loss_fns)

    @property
    def weights(self) -> jax.Array:
        """The weights as an array, wrapped in ``stop_gradient`` so no gradient flows into them."""
        return jax.lax.stop_gradient(jnp.array(self._weights))

    def _losses(
        self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple | None
    ) -> jax.Array:
        """Evaluate every wrapped loss and stack the results into a 1-D array."""
        return jnp.array([loss_fn(predictions, targets) for loss_fn in self._loss_fns])

    def _call(self, predictions: jraph.GraphsTuple, targets: jraph.GraphsTuple) -> jax.Array:
        return jnp.dot(self.weights, self._losses(predictions, targets))

    def loss_with_contributions(
        self, predictions: jraph.GraphsTuple, target: jraph.GraphsTuple | None = None
    ) -> tuple[jax.Array, dict[str, jax.Array]]:
        """Return the weighted loss together with the unweighted loss of each component.

        Args:
            predictions: Graph holding the predicted values.
            target: Graph holding the target values. If None, each wrapped loss reads its
                targets from ``predictions`` (same behaviour as calling the loss directly).

        Returns:
            A tuple ``(total, contributions)``. ``contributions`` maps each wrapped loss's
            ``label()`` to its unweighted value. If two losses share a label, the later
            one overwrites the earlier.
        """
        losses = self._losses(predictions, target)
        # Call label() on each instance (not GraphLoss.label) so that subclass overrides are honoured
        contribs = {loss_fn.label(): value for loss_fn, value in zip(self._loss_fns, losses)}
        return jnp.dot(self.weights, losses), contribs
def _get_pure_loss_fn(loss_fn: str | PureLossFn) -> PureLossFn:
    """Resolve ``loss_fn`` to a callable taking ``(predictions, targets)`` arrays.

    Args:
        loss_fn: Either the name of an attribute of ``optax.losses`` or a callable, which is
            returned unchanged.

    Returns:
        The resolved loss callable.

    Raises:
        AttributeError: If ``loss_fn`` is a string that is not an attribute of ``optax.losses``.
        ValueError: If ``loss_fn`` is neither a string nor callable.
    """
    if isinstance(loss_fn, str):
        try:
            return getattr(optax.losses, loss_fn)
        except AttributeError:
            # Same exception type as before, but name the offending loss explicitly
            raise AttributeError(f"Unknown optax loss function: {loss_fn!r}") from None
    if callable(loss_fn):
        return loss_fn
    raise ValueError(f"Unknown loss function type: {type(loss_fn).__name__}")