Source code for tensorial.nn
from collections.abc import Sequence
import functools
from flax import linen
[docs]
class Sequential(linen.Module):
    """Apply a chain of modules in order, feeding each one's output to the next.

    This behaves like :class:`flax.linen.Sequential` *except* in how tuple outputs are handled.
    flax's version expands any tuple it receives when calling the next layer. That does not
    play nicely with types that subclass ``tuple``, for example :class:`jraph.GraphsTuple`,
    because the next layer expects to get the ``GraphsTuple`` itself, not the individual
    values that make it up.

    Here, only an exact ``tuple`` is expanded into positional arguments
    (``layer(*outputs)``, the same as :class:`flax.linen.Sequential`). Instances of ``tuple``
    subclasses are passed to the next layer intact. A ``dict`` output is expanded into keyword
    arguments (``layer(**outputs)``).

    Entries of ``layers`` may also be :class:`functools.partial` objects. In ``setup`` these
    are resolved into modules that receive the preceding layers (see ``_layers``).
    """

    # The layer specification: modules, or partials to be resolved in ``setup``.
    layers: Sequence[linen.Module | functools.partial]

    def setup(self) -> None:
        """Resolve ``self.layers`` (including any partials) into concrete modules."""
        # pylint: disable=attribute-defined-outside-init
        self._layers: list[linen.Module] = _layers(self.layers)

    def __post_init__(self):
        """Validate ``layers`` before flax finishes constructing the module.

        Raises:
            ValueError: if ``layers`` is not a sequence, or is empty.
        """
        if not isinstance(self.layers, Sequence):
            raise ValueError(f"'layers' must be a sequence, got '{type(self.layers).__name__}'.")
        if not self.layers:
            raise ValueError(f"Empty Sequential module {self.name}.")
        super().__post_init__()

    @linen.compact
    def __call__(self, *args, **kwargs):
        """Run the resolved layers in order.

        Args:
            *args: positional arguments forwarded to the first layer.
            **kwargs: keyword arguments forwarded to the first layer.

        Returns:
            The output of the last layer.
        """
        outputs = self._layers[0](*args, **kwargs)
        for layer in self._layers[1:]:
            if isinstance(outputs, dict):
                outputs = layer(**outputs)
            elif type(outputs) is tuple:  # pylint: disable=unidiomatic-typecheck
                # Exact tuples only: subclasses such as GraphsTuple fall through to the
                # ``else`` branch and are passed whole. A tuple must be expanded positionally;
                # ``**`` on a tuple raises TypeError.
                outputs = layer(*outputs)
            else:
                outputs = layer(outputs)
        return outputs
def _layers(layers: Sequence[linen.Module | functools.partial]) -> list[linen.Module]:
    """Create the concrete list of modules from a layer specification.

    Plain entries are appended unchanged. A :class:`functools.partial` entry is a module that
    is only partly constructed. It wraps a function of everything before it, i.e. f(g(x)),
    typically because it needs access to g(x) itself (for example to calculate gradients).

    When such an entry is reached, the modules collected so far are bundled:

    * a single module is used as-is;
    * several modules are wrapped in a :class:`Sequential`.

    The bundle is passed as the sole positional argument to the partial. The resulting module
    then replaces everything collected so far.

    Args:
        layers: the layer specification, a sequence of modules and/or partials.

    Returns:
        The resolved list of modules.

    Raises:
        ValueError: if a partial appears before any module (there is nothing to pass to it),
            or if calling a partial does not produce a ``linen.Module`` instance.
    """
    new_layers: list[linen.Module] = []
    for layer in layers:
        if not isinstance(layer, functools.partial):
            new_layers.append(layer)
            continue

        if not new_layers:
            raise ValueError(
                f"Got a partial module, but have no previous modules to pass to it: {layer}"
            )
        # Avoid needlessly wrapping a single module in a Sequential.
        nested = new_layers[0] if len(new_layers) == 1 else Sequential(new_layers)
        module = layer(nested)
        if not isinstance(module, linen.Module):
            # Name the partial's callable, not the type of the (invalid) result.
            func_name = getattr(layer.func, "__name__", repr(layer.func))
            raise ValueError(
                f"Calling partial module {func_name}() did not resolve to a "
                f"linen.Module instance (got {type(module).__name__})"
            )
        new_layers = [module]
    return new_layers