# Source code for tensorial.nn_utils
from collections.abc import Callable
import jax
import jax.nn
import jax.numpy as jnp
import jaxtyping as jt
# Type alias for an activation function: any callable that takes a ``jax.Array`` and
# returns a ``jax.Array``.
ActivationFunction = Callable[[jax.Array], jax.Array]
[docs]
def get_jaxnn_activation(func: "str | ActivationFunction") -> "ActivationFunction":
    """Return an activation function, looking it up by name in ``jax.nn`` if needed.

    Args:
        func: either a callable, which is returned unchanged, or the name of a
            function in ``jax.nn`` (e.g. ``"relu"``).

    Returns:
        the activation function

    Raises:
        ValueError: if ``func`` is a string that does not name a callable attribute
            of ``jax.nn``.
        TypeError: if ``func`` is neither callable nor a string.
    """
    if callable(func):
        return func
    if not isinstance(func, str):
        raise TypeError(
            "Expected a callable or the name of a jax.nn function, "
            f"got {type(func).__name__}"
        )
    activation = getattr(jax.nn, func, None)
    # ``jax.nn`` may also expose non-callable attributes (such as submodules); those are
    # not activation functions, so treat them the same as a missing name.
    if not callable(activation):
        raise ValueError(f"Activation function '{func}' not found in jax.nn")
    return activation
[docs]
def prepare_mask(
    mask: jt.Bool[jax.Array, "n_elements"], array: jt.Float[jax.Array, "..."]
) -> jt.Bool[jax.Array, "n_elements ..."]:
    """Prepare a mask for use with ``jnp.where(mask, array, ...)``.

    The mask has to be reshaped so that it broadcasts against ``array`` along all of
    ``array``'s trailing axes. The other alternative is

    ``jnp.where(mask, array.T, ...).T``

    but this sometimes leads to creating a copy when doing one or both of the transposes. I'm not
    sure why, but this approach seems to avoid the problem.

    Args:
        mask: the mask to prepare; expected to hold one entry per element along the
            leading axis of ``array``
        array: the array the mask will be applied to (only its shape is used)

    Returns:
        the mask flattened to a single leading axis, followed by one size-1 axis for each
        trailing axis of ``array``, i.e. shape ``(n_elements, 1, ..., 1)`` with
        ``array.ndim - 1`` ones (just ``(n_elements,)`` when ``array`` has at most one axis)
    """
    # One size-1 axis per trailing axis of ``array`` so the mask broadcasts along them.
    trailing_ones = (1,) * len(array.shape[1:])
    return mask.reshape(-1, *trailing_ones)
[docs]
def vwhere(values: jax.Array, types: jax.Array, *, fill_value: int = 0) -> jax.Array:
    """Return, for each entry of ``values``, the index of its first match in ``types``.

    The lookup is mapped over the leading axis of ``values`` with ``jax.vmap``.

    Args:
        values: the values to look up.
        types: the array to search in, e.g. ``jnp.array([1, 6, 8])``.
        fill_value: the index reported for a value that does not occur in ``types``.
            Defaults to ``0`` (the original behaviour), which cannot be distinguished
            from a genuine match at index 0; pass e.g. ``-1`` to detect missing values.

    Returns:
        an integer array with one entry per element of ``values``: the first index
        (along the leading axis of ``types``) at which ``types`` equals that value, or
        ``fill_value`` if there is none.
    """

    def _first_match(value: jax.Array) -> jax.Array:
        # ``size=1`` keeps the output shape static, as required when tracing under
        # ``vmap``/``jit``; with no match, ``argwhere`` pads the result with ``fill_value``.
        return jnp.argwhere(value == types, size=1, fill_value=fill_value)[0]

    return jax.vmap(_first_match)(values)[:, 0]