Package tjax

This library implements a variety of tools for the differentiable programming library JAX.

Expand source code
"""
This library implements a variety of tools for the differential programming library
[JAX](https://github.com/google/jax).
"""
from .annotations import *
from .color_stub import *
from .dataclass import *
from .display import *
from .dtypes import *
from .generator import *
from .graph import *
from .log_cotangent import *
from .partial import *
from .pytree_like import *
from .shims import *
from .testing import *
from .tools import *

__pdoc__ = {}
__pdoc__['real_dtype'] = False
__pdoc__['complex_dtype'] = False
__pdoc__['PyTreeLike'] = False
__pdoc__['Field'] = False
__pdoc__['InitVar'] = False
__pdoc__['FrozenInstanceError'] = False
from .dataclass import document_dataclass

document_dataclass(__pdoc__, 'Generator')
document_dataclass(__pdoc__, 'Partial')
document_dataclass(__pdoc__, 'LogCotangent')
del document_dataclass


__all__ = list(locals())

Sub-modules

tjax.annotations
tjax.color_stub
tjax.display
tjax.dtypes
tjax.generator
tjax.graph
tjax.log_cotangent
tjax.partial
tjax.pytree_like
tjax.shims
tjax.testing
tjax.tools

Functions

def assert_jax_allclose(actual: Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, ForwardRef('PyTreeLike'), Tuple[ForwardRef('PyTree'), ...], List[ForwardRef('PyTree')], Dict[Hashable, ForwardRef('PyTree')], NoneType], desired: Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, ForwardRef('PyTreeLike'), Tuple[ForwardRef('PyTree'), ...], List[ForwardRef('PyTree')], Dict[Hashable, ForwardRef('PyTree')], NoneType], original_name: Union[str, NoneType] = None, original_value: Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, ForwardRef('PyTreeLike'), Tuple[ForwardRef('PyTree'), ...], List[ForwardRef('PyTree')], Dict[Hashable, ForwardRef('PyTree')], NoneType] = None, *, rtol: Union[float, NoneType] = None, atol: Union[float, NoneType] = None)

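Asserts that every tensor in an actual pytree matches the corresponding tensor in a desired pytree. If the assertion fails, both trees are printed and, when original_name and original_value are provided, so is a test string that would make the assertion pass::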

from tjax import assert_jax_allclose, dataclass, Tensor

@dataclass
class A:
    x: Tensor
    y: Tensor

@dataclass
class B:
    z: A

original = B(A(1.2, 3.4))
desired = B(A(1.2, 5.2))
actual = B(A(3.0, 4.0))

assert_jax_allclose(actual, desired, 'original', original)

This prints::

JAX trees don't match.  Actual:
B
    z=A
        x=3.0
        y=4.0

Desired:
B
    z=A
        x=1.2
        y=5.2

Test string:
original.replace(z=original.z.replace(x=3.0, y=4.0))

The test string can then be pasted into the test as the new desired value so that the assertion passes.

Args

actual
The actual value.
desired
The desired value.
original_name
The variable name that contains the original value.
original_value
The original value. This is usually a pytree like a dataclass that has the same type as actual and desired, but contains different values.
rtol
The relative tolerance of the comparisons in the assertion.
atol
The absolute tolerance of the comparisons in the assertion.
Expand source code
def assert_jax_allclose(actual: PyTree,
                        desired: PyTree,
                        original_name: Optional[str] = None,
                        original_value: Optional[PyTree] = None,
                        *,
                        rtol: Optional[float] = None,
                        atol: Optional[float] = None) -> None:
    """
    Asserts that every tensor in an actual pytree matches the corresponding tensor in a desired
    pytree.  If the assertion fails, both trees are printed and, when `original_name` and
    `original_value` are both provided, a test string that would make the assertion pass is
    printed too::

    ```python
    from tjax import assert_jax_allclose, dataclass, Tensor

    @dataclass
    class A:
        x: Tensor
        y: Tensor

    @dataclass
    class B:
        z: A

    original = B(A(1.2, 3.4))
    desired = B(A(1.2, 5.2))
    actual = B(A(3.0, 4.0))

    assert_jax_allclose(actual, desired, 'original', original)
    ```
    This prints::
    ```
    JAX trees don't match.  Actual:
    B
        z=A
            x=3.0
            y=4.0

    Desired:
    B
        z=A
            x=1.2
            y=5.2

    Test string:
    original.replace(z=original.z.replace(x=3.0, y=4.0))
    ```
    The test string can then be pasted into the test as the new desired value.

    Args:
        actual: The actual value.
        desired: The desired value.
        original_name: The variable name that contains the original value.
        original_value: The original value.  This is usually a pytree like a dataclass that has the
            same type as actual and desired, but contains different values.
        rtol: The relative tolerance of the comparisons in the assertion.  Defaults to
            `default_rtol`.
        atol: The absolute tolerance of the comparisons in the assertion.  Defaults to
            `default_atol`.
    Raises:
        Whatever exception the leaf-wise `np.testing.assert_allclose` comparison raised, re-raised
        after the diagnostic output has been printed.
    """
    if rtol is None:
        rtol = default_rtol
    if atol is None:
        atol = default_atol

    try:
        tree_multimap(partial(np.testing.assert_allclose, rtol=rtol, atol=atol), actual, desired)
    except Exception:
        # Print diagnostics, then re-raise so the original failure is still reported.
        print("JAX trees don't match.  Actual:")
        print(actual)
        print("Desired:")
        print(desired)
        if original_name is not None and original_value is not None:
            print("Test string:")
            print(get_test_string(original_name, actual, original_value, rtol, atol))
        raise
def dataclass(clz: Type[~T]) ‑> Type[~T]

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 annotations to determine fields. Default values for fields are provided using assignment. To mark fields as JAX static fields rather than JAX pytree fields, use the field() function.

For example::

from __future__ import annotations

from typing import ClassVar

from tjax import dataclass, field, Tensor
from jax import numpy as jnp
from jax import grad

@dataclass
class LearnedParameter:
    weight: Tensor
    constrain_positive: bool = field(pytree_like=False)
    minimum_positive_weight: ClassVar[Tensor] = 1e-6

    def trained(self,
                self_bar: LearnedParameter,
                learning_rate: float) -> LearnedParameter:
        weight_bar = self_bar.weight
        weight = self.weight - weight_bar * learning_rate
        if self.constrain_positive:
            weight = jnp.maximum(weight, self.minimum_positive_weight)
        return LearnedParameter(weight=weight,
                                constrain_positive=self.constrain_positive)

def loss(w: LearnedParameter) -> float:
    return jnp.square(w.weight - 3.3)

w = LearnedParameter(2.0, True)
w_bar = grad(loss)(w)
new_w = w.trained(w_bar, 1e-4)

dataclass() includes a convenient replace method::

w.replace(weight=3.4)

Since this dataclass is a pytree, all of JAX's functions that accept pytrees work with it, including iteration, differentiation, and jax.tree_util functions.

Another benefit is the display of dataclasses. print(new_w) gives::

LearnedParameter
    weight=Jax Array ()
            2.0003
    constrain_positive=True
Expand source code
def dataclass(clz: Type[T]) -> Type[T]:
    """
    Returns the same class as was passed in, with dunder methods added based on the fields defined
    in the class.

    Examines PEP 526 annotations to determine fields.  Default values for fields are provided using
    assignment.  To mark fields as JAX static fields rather than JAX pytree fields, use the `field`
    function.

    For example::
    ```python
    from __future__ import annotations

    from typing import ClassVar

    from tjax import dataclass, field, Tensor
    from jax import numpy as jnp
    from jax import grad

    @dataclass
    class LearnedParameter:
        weight: Tensor
        constrain_positive: bool = field(pytree_like=False)
        minimum_positive_weight: ClassVar[Tensor] = 1e-6

        def trained(self,
                    self_bar: LearnedParameter,
                    learning_rate: float) -> LearnedParameter:
            weight_bar = self_bar.weight
            weight = self.weight - weight_bar * learning_rate
            if self.constrain_positive:
                weight = jnp.maximum(weight, self.minimum_positive_weight)
            return LearnedParameter(weight=weight,
                                    constrain_positive=self.constrain_positive)

    def loss(w: LearnedParameter) -> float:
        return jnp.square(w.weight - 3.3)

    w = LearnedParameter(2.0, True)
    w_bar = grad(loss)(w)
    new_w = w.trained(w_bar, 1e-4)
    ```

    `dataclass` includes a convenient replace method::

        w.replace(weight=3.4)

    Since this dataclass is a pytree, all of JAX's functions that accept pytrees work with it,
    including iteration, differentiation, and `jax.tree_util` functions.

    Another benefit is the display of dataclasses.  `print(new_w)` gives::
    ```
    LearnedParameter
        weight=Jax Array ()
                2.0003
        constrain_positive=True
    ```

    The generated class is frozen, gains `__repr__`, `display`, `tree_flatten`, and
    `tree_unflatten` methods plus `tree_fields`/`hashed_fields` lists, and is registered with
    `register_pytree_node`.
    """
    # pylint: disable=protected-access

    # Apply dataclass function to clz.
    data_clz: Type[T] = dataclasses.dataclass(frozen=True)(clz)  # type: ignore

    # Partition the fields that appear in __init__ into hashed (static, stored in the treedef's
    # auxiliary data) and tree (pytree children) fields.  Fields with init=False are skipped
    # because unflattening rebuilds instances through __init__.
    hashed_fields: List[str] = []
    tree_fields: List[str] = []

    for field_info in dataclasses.fields(data_clz):  # type: ignore
        if not field_info.init:
            continue
        if field_info.metadata.get('pytree_like', True):
            tree_fields.append(field_info.name)
        else:
            hashed_fields.append(field_info.name)

    # Generate additional methods.
    def __repr__(self: T) -> str:
        return str(self.display())

    def display(self: T, show_values: bool = True, indent: int = 0) -> str:
        retval = display_class(type(self))
        for field_info in dataclasses.fields(data_clz):  # type: ignore
            retval += display_key_and_value(
                field_info.name, getattr(self, field_info.name), "=", show_values, indent)
        return retval

    def tree_flatten(x: T) -> Tuple[Sequence[PyTree], Hashable]:
        # Children are the pytree fields; auxiliary data is the tuple of static field values.
        hashed = tuple(getattr(x, name) for name in hashed_fields)
        trees = tuple(getattr(x, name) for name in tree_fields)
        return trees, hashed

    def tree_unflatten(cls: Type[T], hashed: Hashable, trees: Sequence[PyTree]) -> T:
        if not isinstance(hashed, tuple):
            raise TypeError("Expected the auxiliary data to be a tuple of static field values.")
        hashed_args = dict(zip(hashed_fields, hashed))
        tree_args = dict(zip(tree_fields, trees))
        return cls(**hashed_args, **tree_args)

    # Assign methods to the class.
    data_clz.__repr__ = __repr__  # type: ignore
    data_clz.display = display  # type: ignore
    data_clz.tree_flatten = tree_flatten  # type: ignore
    data_clz.tree_unflatten = classmethod(tree_unflatten)  # type: ignore

    # Assign field lists to the class.
    data_clz.tree_fields = tree_fields  # type: ignore
    data_clz.hashed_fields = hashed_fields  # type: ignore

    # Register the class as a JAX PyTree.
    register_pytree_node(data_clz, tree_flatten, data_clz.tree_unflatten)  # type: ignore

    # Verify that the generated class is PyTreeLike.
    assert isinstance(data_clz, PyTreeLike)

    return data_clz
def field(pytree_like: bool = True, **kwargs: Any) ‑> cooperative_dataclasses.dataclasses.Field

Args

pytree_like
Indicates whether a field is a pytree or static. Pytree fields are differentiated and traced.
kwargs
Any of the keyword arguments from dataclasses.field.
Expand source code
def field(pytree_like: bool = True, **kwargs: Any) -> dataclasses.Field:
    """
    Create a dataclass field that records whether it belongs to the pytree.

    Args:
        pytree_like: Indicates whether a field is a pytree or static.  Pytree fields are
            differentiated and traced.
        kwargs: Any of the keyword arguments from `dataclasses.field`.  A supplied `metadata`
            mapping (or None) is merged with the `'pytree_like'` entry, which takes precedence.
    Returns: The field object produced by `dataclasses.field`.
    """
    # dataclasses.field itself accepts metadata=None, so treat None like an empty mapping instead
    # of failing on the `**` unpacking below.
    metadata = kwargs.pop('metadata', None) or {}
    return dataclasses.field(metadata={**metadata, 'pytree_like': pytree_like}, **kwargs)
def is_scalar(x: Any) ‑> bool
Expand source code
def is_scalar(x: Any) -> bool:
    """
    Return True when `x` is a `numbers.Number`, or a NumPy/JAX array with an empty shape.
    """
    if isinstance(x, Number):
        return True
    is_array = isinstance(x, (np.ndarray, jnp.ndarray))
    return is_array and x.shape == ()
def jax_allclose(actual: Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, ForwardRef('PyTreeLike'), Tuple[ForwardRef('PyTree'), ...], List[ForwardRef('PyTree')], Dict[Hashable, ForwardRef('PyTree')], NoneType], desired: Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, ForwardRef('PyTreeLike'), Tuple[ForwardRef('PyTree'), ...], List[ForwardRef('PyTree')], Dict[Hashable, ForwardRef('PyTree')], NoneType], rtol: Union[float, NoneType] = None, atol: Union[float, NoneType] = None) ‑> bool

Args

actual
The actual value.
desired
The desired value.
rtol
The relative tolerance of the comparisons in the comparison.
atol
The absolute tolerance of the comparisons in the comparison.
Expand source code
def jax_allclose(actual: PyTree,
                 desired: PyTree,
                 rtol: Optional[float] = None,
                 atol: Optional[float] = None) -> bool:
    """
    Return whether every tensor in `actual` is close to the corresponding tensor in `desired`.

    Args:
        actual: The actual value.
        desired: The desired value.
        rtol: The relative tolerance of the comparisons in the comparison.  Defaults to
            `default_rtol`.
        atol: The absolute tolerance of the comparisons in the comparison.  Defaults to
            `default_atol`.
    Returns: A Python bool that is True when `np.allclose` holds for every pair of leaves.
    """
    if rtol is None:
        rtol = default_rtol
    if atol is None:
        atol = default_atol

    leafwise_close = tree_multimap(partial(np.allclose, rtol=rtol, atol=atol), actual, desired)
    # jnp.logical_and produces a JAX array, so convert explicitly to honour the `bool` annotation.
    return bool(tree.reduce(jnp.logical_and, leafwise_close, True))
def print_generic(*args: Any, **kwargs: Any) ‑> NoneType
Expand source code
def print_generic(*args: Any, **kwargs: Any) -> None:
    """
    Print each positional value with `display_generic`, then each keyword argument as a
    `name=value` display, one `print` call per item.
    """
    for positional in args:
        rendered = display_generic(positional)
        print(rendered)
    for name in kwargs:
        rendered = display_key_and_value(name, kwargs[name], "=", True, 0)
        print(rendered)
def sum_tensors(tensors: Collection[Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray]], shape: Union[int, Sequence[int], NoneType] = None) ‑> Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray]
Expand source code
def sum_tensors(tensors: Collection[Tensor],
                shape: Optional[ShapeLike] = None) -> Tensor:
    """
    Add a collection of tensors together.

    Args:
        tensors: The tensors to add with `+`.
        shape: The shape of the zero tensor returned when `tensors` is empty.  None means a
            scalar (shape `()`).
    Returns: The sum of the tensors, or zeros of `shape` if there are none.
    """
    if not tensors:
        # Pass an explicit empty shape rather than None to jnp.zeros.
        return jnp.zeros(() if shape is None else shape)
    return reduce(add, tensors)

Classes

class Displayable (*args, **kwds)

This protocol identifies classes that support the display_generic mechanism.

Expand source code
class Displayable(Protocol):
    """
    Structural type for objects that can be rendered by the `display_generic` mechanism.
    """
    def display(self, show_values: bool = True, indent: int = 0) -> str:
        """
        Return a textual rendering of this object.  Protocol stub; implementations supply the body.

        Args:
            show_values: Presumably whether leaf values are included in the rendering (confirm
                against implementations).
            indent: Presumably the indentation level of the rendering (confirm against
                implementations).
        """

Ancestors

  • typing.Protocol
  • typing.Generic

Methods

def display(self, show_values: bool = True, indent: int = 0) ‑> str
Expand source code
def display(self, show_values: bool = True, indent: int = 0) -> str:
    """Protocol stub for `Displayable.display`; implementations return their text rendering."""
class Generator (*, seed: Optional[int] = None, key: Optional[Tensor] = None, **kwargs: Any)

This class represents a JAX random number generator. Unlike numpy.Generator, Generator has no mutating methods. Instead, its generation methods return a new instance along with the generated tensor.

Expand source code
class Generator:
    """
    This class represents a JAX random number generator.  Unlike `numpy.Generator`, `tjax.Generator`
    has no mutating methods.  Instead, its generation methods return a new instance along with
    the generated tensor.
    """

    key: Tensor

    def __init__(self,
                 *,
                 seed: Optional[int] = None,
                 key: Optional[Tensor] = None,
                 **kwargs: Any):
        """
        Args:
            seed: Seed passed to `jax.random.PRNGKey` when `key` is not given.
            key: A JAX PRNG key.  Takes precedence over `seed`.
            kwargs: Forwarded to the next `__init__` in the method resolution order.
        Raises:
            ValueError: If neither `seed` nor `key` is given.
        """
        super().__init__(**kwargs)
        if key is None:
            if seed is None:
                raise ValueError("Either seed or key must be provided.")
            key = jax.random.PRNGKey(seed)
        # object.__setattr__ presumably bypasses a frozen-dataclass __setattr__ (confirm).
        object.__setattr__(self, 'key', key)

    # New methods ----------------------------------------------------------------------------------
    def split(self, n: int = 2) -> List[Generator]:
        """Return `n` generators, one per key produced by `jax.random.split`."""
        keys = jax.random.split(self.key, n)
        return [Generator(key=key) for key in keys]

    def normal(self, std_dev: Tensor, shape: Shape = ()) -> (
            Tuple[Generator, Tensor]):
        """Return a new generator and normal samples of `shape` scaled by `std_dev`."""
        g1, g2 = self.split()
        return g1, std_dev * jax.random.normal(g2.key, shape)

    def gamma(self, gamma_shape: Tensor, shape: Shape = ()) -> (
            Tuple[Generator, Tensor]):
        """Return a new generator and gamma samples with shape parameter `gamma_shape`."""
        g1, g2 = self.split()
        return g1, jax.random.gamma(g2.key, gamma_shape, shape)

    # Magic methods --------------------------------------------------------------------------------
    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Generator):
            return NotImplemented
        # jnp.all returns a JAX array; convert so == yields a real bool as annotated.
        return bool(jnp.all(self.key == other.key))

    def __ne__(self, other: Any) -> bool:
        if not isinstance(other, Generator):
            return NotImplemented
        return not self.__eq__(other)

    def __hash__(self) -> int:
        # Assumes the key has at least two elements (e.g. a raw uint32[2] PRNGKey) — confirm.
        return hash((int(self.key[0]), int(self.key[1])))

Methods

def gamma(self, gamma_shape: Tensor, shape: Shape = ()) ‑> Tuple[Generator, Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray]]
Expand source code
def gamma(self, gamma_shape: Tensor, shape: Shape = ()) -> (
        Tuple[Generator, Tensor]):
    """
    Sample from a gamma distribution without mutating this generator.

    Returns: The generator to use next, and samples of `shape` with shape parameter `gamma_shape`.
    """
    next_generator, sample_generator = self.split()
    samples = jax.random.gamma(sample_generator.key, gamma_shape, shape)
    return next_generator, samples
def normal(self, std_dev: Tensor, shape: Shape = ()) ‑> Tuple[Generator, Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray]]
Expand source code
def normal(self, std_dev: Tensor, shape: Shape = ()) -> (
        Tuple[Generator, Tensor]):
    """
    Sample from a zero-mean normal distribution without mutating this generator.

    Returns: The generator to use next, and standard-normal samples of `shape` times `std_dev`.
    """
    next_generator, sample_generator = self.split()
    samples = jax.random.normal(sample_generator.key, shape)
    return next_generator, std_dev * samples
def split(self, n: int = 2) ‑> List[Generator]
Expand source code
def split(self, n: int = 2) -> List[Generator]:
    """Return `n` new generators, one for each key produced by `jax.random.split`."""
    return [Generator(key=subkey) for subkey in jax.random.split(self.key, n)]
class LogCotangent (cotangent: Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray], **kwargs)

LogCotangent logs cotangents in a differentiated jax function. For example:

from jax import grad
from jax import numpy as jnp

from tjax import LogCotangent

def loss(x, w, log_cotangent):
    y = x * w
    z = log_cotangent.forward(y)
    return jnp.sum(jnp.square(2.0 - z))

x = jnp.array([1.0, 2.0])
w = jnp.array([2.2, 3.5])
lg_bar = grad(loss, 2)(x, w, LogCotangent.create(shape=x.shape))

# lg_bar.cotangent now holds the transmitted cotangent.
Expand source code
class LogCotangent:
    """
    LogCotangent logs cotangents in a differentiated jax function.  For example:

        from jax import grad
        from jax import numpy as jnp

        from tjax import LogCotangent

        def loss(x, w, log_cotangent):
            y = x * w
            z = log_cotangent.forward(y)
            return jnp.sum(jnp.square(2.0 - z))

        x = jnp.array([1.0, 2.0])
        w = jnp.array([2.2, 3.5])
        lg_bar = grad(loss, 2)(x, w, LogCotangent.create(shape=x.shape))

        # lg_bar.cotangent now holds the transmitted cotangent.
    """
    cotangent: Tensor

    @classmethod
    def create(cls: Type[T], shape: Tuple[int, ...]) -> T:
        """
        Factory that creates an instance whose cotangent is a zero tensor.

        Args:
            shape: The shape of the transmitted tensor and its cotangent.
        Returns: A new instance of `cls` (so subclasses are preserved) holding `jnp.zeros(shape)`.
        """
        return cls(jnp.zeros(shape))

    def forward(self, x: Tensor) -> Tensor:
        """
        This method is called in the forward pass.  It will automatically log the cotangent in the
        backward pass.

        Args:
            x: The tensor to be transmitted.
        Returns: `x + self.cotangent`, which equals `x` when the cotangent is zero (as produced by
            `create`).
        Raises:
            ValueError: If `x` and the cotangent have different shapes.
        """
        if x.shape != self.cotangent.shape:
            raise ValueError(
                f"Shape mismatch: x has shape {x.shape}, but the cotangent has shape "
                f"{self.cotangent.shape}.")
        # The derivative of x + c with respect to c is the identity, so the gradient with respect
        # to self.cotangent equals the cotangent flowing back into the returned tensor.
        return x + self.cotangent

Static methods

def create(shape: Tuple[int, ...]) ‑> ~T

Factory to create a LogCotangent object whose cotangent is a zero tensor of the given shape.

Args

shape
The shape of the transmitted tensor and its cotangent.
Expand source code
@classmethod
def create(cls: Type[T], shape: Tuple[int, ...]) -> T:
    """
    Factory that creates an instance whose cotangent is a zero tensor.

    Args:
        shape: The shape of the transmitted tensor and its cotangent.
    Returns: A new instance of `cls` (so subclasses are preserved) holding `jnp.zeros(shape)`.
    """
    return cls(jnp.zeros(shape))

Methods

def forward(self, x: Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray]) ‑> Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray]

This method is called in the forward pass. It will automatically log the cotangent in the backward pass.

Args

x
The tensor to be transmitted.

Returns: x + cotangent, which equals the input x when the cotangent is zero (as produced by create).

Expand source code
def forward(self, x: Tensor) -> Tensor:
    """
    This method is called in the forward pass.  It will automatically log the cotangent in the
    backward pass.

    Args:
        x: The tensor to be transmitted.
    Returns: `x + self.cotangent`, which equals `x` when the cotangent is zero (as produced by
        `create`).
    Raises:
        ValueError: If `x` and the cotangent have different shapes.
    """
    if x.shape != self.cotangent.shape:
        raise ValueError(
            f"Shape mismatch: x has shape {x.shape}, but the cotangent has shape "
            f"{self.cotangent.shape}.")
    # The derivative of x + c with respect to c is the identity, so the gradient with respect
    # to self.cotangent equals the cotangent flowing back into the returned tensor.
    return x + self.cotangent
class Partial (func: Callable[..., R], /, *args: Any, callable_is_static: bool = True, static_argnums: Tuple[int, ...] = (), static_kwargs: Mapping[str, Any] = {}, **kwargs: Any)

A version of functools.partial that returns a pytree.

Use it for partial function evaluation in a way that is compatible with JAX's transformations, e.g., Partial(func, *args, **kwargs).

Expand source code
class Partial(partial, Generic[R]):
    """
    A version of functools.partial that returns a pytree.

    Use it for partial function evaluation in a way that is compatible with JAX's transformations,
    e.g., ``Partial(func, *args, **kwargs)``.
    """

    callable_is_static: bool
    static_argnums: Tuple[int, ...]
    static_kwargs: Mapping[str, Any]

    def __new__(cls: Type[T],
                func: Callable[..., R],
                /,
                *args: Any,
                callable_is_static: bool = True,
                static_argnums: Tuple[int, ...] = (),
                static_kwargs: Mapping[str, Any] = {},
                **kwargs: Any) -> T:
        """
        Args:
            func: The function being applied.
            args: The applied positional arguments.
            callable_is_static: Whether the function callable is static.
            static_argnums: The indices of the applied positional arguments that are static.
            static_kwargs: The key-value pairs representing applied keyword arguments that are
                static.  The default mapping is shared but never mutated here.
            kwargs: The applied keyword arguments.
        Raises:
            TypeError: If `callable_is_static` is true and `func` is itself a `Partial`.
        """
        if callable_is_static and isinstance(func, Partial):
            raise TypeError("A static callable cannot itself be a Partial.")
        retval = super().__new__(cls, func, *args, **kwargs)  # type: ignore
        retval.callable_is_static = callable_is_static
        retval.static_argnums = set(static_argnums)
        retval.static_kwargs = static_kwargs
        return retval

    def tree_flatten(self: Partial[R]) -> Tuple[Sequence[PyTree], Hashable]:
        """
        Split the callable and positional arguments into static and pytree parts.

        Returns: Children `(tree_args, keywords)` and auxiliary data
            `(callable_is_static, static_argnums, static_args, static_kwargs)`.  Both argument
            lists are stored reversed so that `tree_unflatten` can rebuild the order by popping.
        """
        # NOTE(review): the auxiliary data contains a set, a list, and a mapping; JAX may expect
        # hashable auxiliary data (e.g. for jit caching) — confirm.
        static_args = []
        tree_args = []

        def _append(is_static: bool, value: Any) -> None:
            if is_static:
                static_args.append(value)
            else:
                tree_args.append(value)

        _append(self.callable_is_static, self.func)
        for i, value in enumerate(self.args):
            _append(i in self.static_argnums, value)

        return ((list(reversed(tree_args)), self.keywords),
                (self.callable_is_static, self.static_argnums,
                 list(reversed(static_args)), self.static_kwargs))

    @classmethod
    def tree_unflatten(cls: Type[R],
                       static: Hashable,
                       trees: Sequence[PyTree]) -> Partial[R]:
        """
        Rebuild a `Partial` from the output of `tree_flatten`.

        Raises:
            RuntimeError: If the auxiliary data or children do not have the expected structure.
        """
        if not isinstance(static, Iterable):
            raise RuntimeError("Malformed Partial auxiliary data.")

        callable_is_static, static_argnums, static_args, static_kwargs = static

        if not isinstance(static_args, list):
            raise RuntimeError("Malformed Partial static arguments.")

        tree_args, tree_kwargs = trees

        if not isinstance(tree_args, list):
            raise RuntimeError("Malformed Partial tree arguments.")
        if not isinstance(tree_kwargs, dict):
            raise RuntimeError("Malformed Partial tree keyword arguments.")

        tree_kwargs = cast(Dict[str, Any], tree_kwargs)

        # Pop from copies: static_args belongs to the treedef's auxiliary data and is reused every
        # time that treedef is unflattened, so consuming it in place would break later unflattens.
        remaining_static = list(static_args)
        remaining_tree = list(tree_args)

        # Slot 0 is the callable; slot i > 0 is positional argument i - 1.
        args = []
        for i in range(len(remaining_static) + len(remaining_tree)):
            if i == 0:
                is_static = callable_is_static
            else:
                is_static = i - 1 in static_argnums
            if is_static:
                args.append(remaining_static.pop())
            else:
                args.append(remaining_tree.pop())

        return Partial[R](*args,
                          callable_is_static=callable_is_static,
                          static_argnums=static_argnums,
                          static_kwargs=static_kwargs,
                          **tree_kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> R:
        """Call the wrapped function with the stored arguments plus the static keyword arguments."""
        return super().__call__(*args, **self.static_kwargs, **kwargs)

Ancestors

  • functools.partial
  • typing.Generic
class custom_vjp (*args, **kwds)

Set up a JAX-transformable function for a custom VJP rule definition.

This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad()) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp, which defines the custom VJP rule.

This decorator precludes the use of forward-mode automatic differentiation.

This is a shim class to work around an issue with JAX's custom_vjp. It provides both:

  • static arguments, and
  • nondifferentiable arguments.

Static arguments are passed in to both the forward and the backward pass. They must be hashable. Different values for static arguments will generate recompilation.

The generated backward pass will generate zeroed-out cotangents. Ideally, no corresponding cotangents would be created, but such a change would have to be done in JAX itself.

For example::

from functools import partial

from jax import numpy as jnp
from tjax import custom_vjp

@partial(custom_vjp, nondiff_argnums=2)
def f(x, y, z):
    return jnp.sin(x) * y + z

def f_fwd(x, y, z):
    return f(x, y, z), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(residuals, output_bar):
    cos_x, sin_x, y = residuals
    x_bar = cos_x * output_bar * y
    y_bar = sin_x * output_bar
    # z_bar is not returned because it's nondifferentiable.
    return x_bar, y_bar

f.defvjp(f_fwd, f_bwd)

Args

fun
the function to decorate.
static_argnums
The indices of the static arguments.
nondiff_argnums
The indices of the nondifferentiable arguments.
Expand source code
class custom_vjp(Generic[R]):
    """
    Set up a JAX-transformable function for a custom VJP rule definition.

    This class is meant to be used as a function decorator. Instances are callables that behave
    similarly to the underlying function to which the decorator was applied, except when a
    reverse-mode differentiation transformation (like `jax.grad()`) is applied, in which case a
    custom user-supplied VJP rule function is used instead of tracing into and performing automatic
    differentiation of the underlying function’s implementation. There is a single instance method,
    defvjp, which defines the custom VJP rule.

    This decorator precludes the use of forward-mode automatic differentiation.

    This is a shim class to work around an
    [issue with JAX's custom_vjp](https://github.com/google/jax/issues/2912).  It provides both:

    - static arguments, and
    - nondifferentiable arguments.

    Static arguments are passed in to both the forward and the backward pass.  They must be
    hashable.  Different values for static arguments will generate recompilation.  The backward
    pass receives them first, followed by the residuals and the output cotangent.

    The generated backward pass will generate zeroed-out cotangents.  Ideally, no corresponding
    cotangents would be created, but such a change would have to be done in JAX itself.

    For example::

        from functools import partial

        from jax import numpy as jnp
        from tjax import custom_vjp

        @partial(custom_vjp, nondiff_argnums=2)
        def f(x, y, z):
            return jnp.sin(x) * y + z

        def f_fwd(x, y, z):
            return f(x, y, z), (jnp.cos(x), jnp.sin(x), y)

        def f_bwd(residuals, output_bar):
            cos_x, sin_x, y = residuals
            x_bar = cos_x * output_bar * y
            y_bar = sin_x * output_bar
            # z_bar is not returned because it's nondifferentiable.
            return x_bar, y_bar

        f.defvjp(f_fwd, f_bwd)
    """
    def __init__(self,
                 fun: Callable[..., R],
                 static_argnums: Union[int, Tuple[int, ...]] = (),
                 nondiff_argnums: Union[int, Tuple[int, ...]] = ()):
        """
        Args:
            fun: the function to decorate.
            static_argnums: The indices of the static arguments.
            nondiff_argnums: The indices of the nondifferentiable arguments.
        Raises:
            ValueError: If an index is both static and nondifferentiable.
        """
        static_argnums = as_tuple(static_argnums)
        nondiff_argnums = as_tuple(nondiff_argnums)
        if intersection := set(static_argnums) & set(nondiff_argnums):
            raise ValueError(
                f"Arguments {intersection} cannot be both static and nondifferentiable.")
        self.nondiff_argnums = nondiff_argnums
        # Static arguments are delegated to JAX's own nondiff_argnums mechanism.
        self.vjp = jax_custom_vjp(fun, nondiff_argnums=static_argnums)

    def defvjp(self, fwd: Callable[..., Tuple[R, Any]], bwd: Callable[..., Any]) -> None:
        """
        Implement the custom forward and backward passes of the custom derivative.

        Args:
            fwd: The custom forward pass.  It receives all arguments in their original order.
            bwd: The custom backward pass.  It receives the static arguments, then the residuals,
                then the output cotangent.  Cotangents for the nondifferentiable arguments should
                not be provided by the user-provided backward pass.
        """
        # Ascending order lets the zero cotangents be inserted left to right below.
        nondiff_argnums = sorted(self.nondiff_argnums)
        # The static indices, as stored on the underlying jax.custom_vjp object.
        static_argnums = tuple(self.vjp.nondiff_argnums)

        def new_fwd(*args: Any) -> Tuple[R, Any]:
            zeroed_args = tuple([tree.map(jnp.zeros_like, args[i])
                                 for i in nondiff_argnums])
            primal, internal_residuals = fwd(*args)
            return primal, (zeroed_args, internal_residuals)

        def new_bwd(*args: Any) -> Any:
            # JAX passes the static arguments first, then the residuals, then the output cotangent.
            *static_args, residuals, output_bar = args
            zeroed_args, internal_residuals = residuals
            input_bar = list(bwd(*static_args, internal_residuals, output_bar))
            for zeroed_arg, index in zip(zeroed_args, nondiff_argnums):
                # JAX expects no cotangents for static arguments, so shift the argument index left
                # by the number of static arguments that precede it.
                position = index - sum(1 for s in static_argnums if s < index)
                input_bar[position:position] = [zeroed_arg]
            return tuple(input_bar)

        self.vjp.defvjp(new_fwd, new_bwd)

    def __call__(self, *args: Any) -> R:
        return self.vjp(*args)

    def __get__(self, instance: Any, owner: Any = None) -> Callable[..., R]:
        # https://github.com/google/jax/issues/2483
        return self.vjp.__get__(instance, owner)

Ancestors

  • typing.Generic

Methods

def defvjp(self, fwd: Callable[..., Tuple[~R, Any]], bwd: Callable[..., Any]) ‑> NoneType

Implement the custom forward and backward passes of the custom derivative.

Args

fwd
The custom forward pass.
bwd
The custom backward pass. Cotangents for the nondifferentiable arguments should not be provided by the user-provided backward pass.
Expand source code
def defvjp(self, fwd: Callable[..., Tuple[R, Any]], bwd: Callable[..., Any]) -> None:
    """
    Implement the custom forward and backward passes of the custom derivative.

    Args:
        fwd: The custom forward pass.  It receives all arguments in their original order.
        bwd: The custom backward pass.  It receives the static arguments, then the residuals,
            then the output cotangent.  Cotangents for the nondifferentiable arguments should
            not be provided by the user-provided backward pass.
    """
    # Ascending order lets the zero cotangents be inserted left to right below.
    nondiff_argnums = sorted(self.nondiff_argnums)
    # The static indices, as stored on the underlying jax.custom_vjp object.
    static_argnums = tuple(self.vjp.nondiff_argnums)

    def new_fwd(*args: Any) -> Tuple[R, Any]:
        zeroed_args = tuple([tree.map(jnp.zeros_like, args[i])
                             for i in nondiff_argnums])
        primal, internal_residuals = fwd(*args)
        return primal, (zeroed_args, internal_residuals)

    def new_bwd(*args: Any) -> Any:
        # JAX passes the static arguments first, then the residuals, then the output cotangent.
        *static_args, residuals, output_bar = args
        zeroed_args, internal_residuals = residuals
        input_bar = list(bwd(*static_args, internal_residuals, output_bar))
        for zeroed_arg, index in zip(zeroed_args, nondiff_argnums):
            # JAX expects no cotangents for static arguments, so shift the argument index left
            # by the number of static arguments that precede it.
            position = index - sum(1 for s in static_argnums if s < index)
            input_bar[position:position] = [zeroed_arg]
        return tuple(input_bar)

    self.vjp.defvjp(new_fwd, new_bwd)