Skip to content

Quaxed¤

pre-quaxified JAX functions.

quax enables JAX + multiple dispatch + custom array-ish objects. quaxed means you don't have to wrap every function in quax.quaxify wrappers every time.

Installation¤

pip install quaxed

Getting started¤

Import whatever library you need as a drop-in replacement for its JAX counterpart.

To see the API check out the quaxed in the left bar.

import quaxed.numpy as jnp

x = jnp.linspace(0.0, 1.0, num=3)

print(jnp.cos(x))
# [1.         0.87758255 0.5403023 ]

The advantage of quaxed over plain JAX is that every function is quaxify'd and will work with properly formulated array-ish objects.

For this example we use unxt's Quantity for unitful calculations.

from unxt import Quantity

x = Quantity(jnp.linspace(0.0, 1.0, num=3), "deg")

print(jnp.cos(x))
# Quantity['dimensionless'](Array([1.       , 0.9999619, 0.9998477], dtype=float32), unit='')

See also: other libraries in the Quax ecosystem¤

Quax: the base library.

unxt: Units and Quantities in Jax.

coordinax: Vector representations (built on unxt).

galax: Galactic dynamics in Jax (built on unxt and coordinax).