Introduction#

1. Install REAX#

pip install reax

2. Define a ReaxModule#

A reax.Module keeps track of your model parameters and gives you a place to put the code for the various steps in your training loop (training_step, validation_step, etc.).

[1]:
import os
from functools import partial
from flax import linen
import jax
import jax.numpy as jnp
import optax
import reax
from reax import demos


class Autoencoder(linen.Module):
    """Small fully-connected autoencoder with a 3-dimensional latent code.

    Encoder: Dense(128) -> relu -> Dense(3).
    Decoder: Dense(128) -> relu -> Dense(28 * 28).

    In Flax, submodules created in ``setup`` are named after the attribute
    they are assigned to, so ``encoder`` / ``decoder`` are part of the
    parameter tree layout and must stay stable for saved checkpoints to load.
    """

    def setup(self):
        # Compress the (flattened) input down to a 3-value code.
        self.encoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(3)])
        # Expand the 3-value code back to 28 * 28 outputs (one per pixel).
        self.decoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(28 * 28)])

    def __call__(self, x):
        """Reconstruct ``x`` by encoding it and decoding the resulting code."""
        code = self.encode(x)
        return self.decoder(code)

    def encode(self, x):
        """Return the latent code produced by the encoder for ``x``."""
        return self.encoder(x)


class ReaxAutoEncoder(reax.Module):
    """REAX module that trains :class:`Autoencoder` to reconstruct its inputs.

    No parameters are created in the constructor. ``configure_model`` builds
    them from a batch the first time it is called while none are set.
    """

    def __init__(self):
        super().__init__()
        self.ae = Autoencoder()
        # Pre-bind Autoencoder.encode so it can be applied with an explicit params tree.
        self._encode = partial(self.ae.apply, method="encode")

    def configure_model(self, stage: reax.Stage, batch, /):
        """Initialise the parameters from one example of ``batch`` if none are set."""
        if self.parameters() is not None:
            return
        flat_inputs, _ = self.prepare_batch(batch)
        # A single example is enough: the model is vmapped over the batch axis elsewhere.
        self.set_parameters(self.ae.init(self.rngs(), flat_inputs[0]))

    def training_step(self, batch, batch_idx):
        """Return the reconstruction loss for ``batch`` and its gradient w.r.t. the parameters."""
        flat_inputs, _ = self.prepare_batch(batch)
        loss_and_grad = jax.value_and_grad(self.loss_fn, argnums=0)
        loss, grads = loss_and_grad(self.parameters(), flat_inputs, self.ae)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss, grads

    @staticmethod
    @partial(jax.jit, static_argnums=2)
    def loss_fn(params, x_batch, model):
        """Mean squared reconstruction error of ``model`` over ``x_batch``.

        ``model`` (argument index 2) is marked static for ``jax.jit`` so the
        Flax module object itself is not traced.
        """
        batched_apply = jax.vmap(model.apply, in_axes=(None, 0))
        reconstructions = batched_apply(params, x_batch)
        return optax.losses.squared_error(reconstructions, x_batch).mean()

    def encode(self, x_batch):
        """Return latent codes for a batch of inputs (each example is flattened first)."""
        flat, _ = self.prepare_batch((x_batch, None))
        per_example_encode = jax.vmap(self._encode, in_axes=(None, 0))
        return per_example_encode(self.parameters(), flat)

    def configure_optimizers(self):
        """Create an Adam optimiser (learning rate 1e-3) and its state for the current parameters."""
        optimizer = optax.adam(learning_rate=1e-3)
        return optimizer, optimizer.init(self.parameters())

    @staticmethod
    def prepare_batch(batch):
        """Split ``batch`` into (inputs, targets) and flatten every input axis after the first."""
        inputs, targets = batch
        flat_inputs = inputs.reshape(inputs.shape[0], -1)
        return flat_inputs, targets


# Create the module. No parameters exist yet: configure_model initializes
# them from a batch when none are set.
autoencoder = ReaxAutoEncoder()

3. Define a dataset#

REAX supports any iterable (NumPy arrays, lists, etc.) for the train/val/test/predict datasets.

[2]:
# Set up the data: download MNIST (the cell output shows the files land in
# data/MnistDataset) and wrap the dataset in a REAX data loader for training.
dataset = demos.mnist.MnistDataset(download=True)
data_loader = reax.ReaxDataLoader(dataset)
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MnistDataset
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MnistDataset
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MnistDataset
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MnistDataset

4. Train the model#

The REAX Trainer takes the module and dataset and combines them in a training loop, automating away most of the boilerplate.

[3]:
# Train for one epoch, limited to 100 training batches to keep the demo short.
# The trailing semicolon stops the notebook from displaying fit()'s return value.
trainer = reax.Trainer()
trainer.fit(autoencoder, data_loader, limit_train_batches=100, max_epochs=1);
Error loading REAX plugin from entrypoint 'EntryPoint(name='native', value='tensorial._provides:get_batch_sizers', group='reax.plugins.batch_sizers')':
unsupported operand type(s) for |: 'types.UnionType' and 'jaxlib._jax.Device'

5. Use the model#

[4]:
# Load the trained weights back from disk.
# Rather than hard-coding "version_0/.../epoch=0-step=100.ckpt" (which breaks as
# soon as the training arguments change, and presumably goes stale when a
# re-run writes to a new version_N directory), pick the most recently written
# checkpoint under ./reax_logs.
import glob

checkpoint_candidates = glob.glob(
    os.path.join(".", "reax_logs", "version_*", "checkpoints", "*.ckpt")
)
if not checkpoint_candidates:
    raise FileNotFoundError(
        "No checkpoint found under ./reax_logs; run the training cell first."
    )
checkpoint = max(checkpoint_candidates, key=os.path.getmtime)
ckpt = trainer.checkpointing.load(checkpoint)
autoencoder.set_parameters(ckpt["parameters"])

# embed 4 fake images!
fake_image_batch = jax.random.uniform(trainer.rngs(), shape=(4, 28, 28))
fake_image_batch = trainer.engine.to_device(fake_image_batch)
embeddings = autoencoder.encode(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)
⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
Predictions (4 image embeddings):
 [[ 0.19634824  0.781361   -0.21130426]
 [ 0.37009647  1.0074227  -0.4198979 ]
 [ 0.38216683  0.8424615  -0.1129754 ]
 [ 0.12117399  1.0413505  -0.25118512]]
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡