Package tjax
This library implements a variety of tools for the differentiable programming library JAX.
Expand source code
"""
This library implements a variety of tools for the differentiable programming library
[JAX](https://github.com/google/jax).
"""
from .annotations import *
from .color_stub import *
from .dataclass import *
from .display import *
from .dtypes import *
from .generator import *
from .graph import *
from .log_cotangent import *
from .partial import *
from .pytree_like import *
from .shims import *
from .testing import *
from .tools import *
__pdoc__ = {}
__pdoc__['real_dtype'] = False
__pdoc__['complex_dtype'] = False
__pdoc__['PyTreeLike'] = False
__pdoc__['Field'] = False
__pdoc__['InitVar'] = False
__pdoc__['FrozenInstanceError'] = False
from .dataclass import document_dataclass
document_dataclass(__pdoc__, 'Generator')
document_dataclass(__pdoc__, 'Partial')
document_dataclass(__pdoc__, 'LogCotangent')
del document_dataclass
__all__ = list(locals())
Sub-modules
tjax.annotations
tjax.color_stub
tjax.display
tjax.dtypes
tjax.generator
tjax.graph
tjax.log_cotangent
tjax.partial
tjax.pytree_like
tjax.shims
tjax.testing
tjax.tools
Functions
def assert_jax_allclose(actual: PyTree, desired: PyTree, original_name: Optional[str] = None, original_value: Optional[PyTree] = None, *, rtol: Optional[float] = None, atol: Optional[float] = None)
-
Asserts that every tensor in an actual pytree matches the corresponding tensor in a desired pytree. If the assertion fails, a test string that would make the test pass is printed::
    from tjax import assert_jax_allclose, dataclass, Tensor

    @dataclass
    class A:
        x: Tensor
        y: Tensor

    @dataclass
    class B:
        z: A

    original = B(A(1.2, 3.4))
    desired = B(A(3.0, 4.0))
    actual = B(A(1.2, 5.2))

    assert_jax_allclose(actual, desired, 'original', original)
This prints::
    JAX trees don't match. Actual:
    B
        z=A
            x=1.2
            y=5.2
    Desired:
    B
        z=A
            x=3.0
            y=4.0
    Test string:
    original.replace(z=original.z.replace(x=1.2, y=5.2))
The test string can then be pasted into the test as the new desired value.
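For instance, continuing the example above, the pasted string becomes the new desired value (a sketch of the intended workflow)::

    # Paste the printed test string so the test passes against the new actual.
    desired = original.replace(z=original.z.replace(x=1.2, y=5.2))
    assert_jax_allclose(actual, desired, 'original', original)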
Args
actual
- The actual value.
desired
- The desired value.
original_name
- The variable name that contains the original value.
original_value
- The original value. This is usually a pytree like a dataclass that has the same type as actual and desired, but contains different values.
rtol
- The relative tolerance of the comparisons in the assertion.
atol
- The absolute tolerance of the comparisons in the assertion.
Expand source code
def assert_jax_allclose(actual: PyTree,
                        desired: PyTree,
                        original_name: Optional[str] = None,
                        original_value: Optional[PyTree] = None,
                        *,
                        rtol: Optional[float] = None,
                        atol: Optional[float] = None) -> None:
    """
    Asserts that every tensor in an actual pytree matches the corresponding tensor in a desired
    pytree.  If the assertion fails, a test string that would make the test pass is printed::

    ```python
    from tjax import assert_jax_allclose, dataclass, Tensor

    @dataclass
    class A:
        x: Tensor
        y: Tensor

    @dataclass
    class B:
        z: A

    original = B(A(1.2, 3.4))
    desired = B(A(3.0, 4.0))
    actual = B(A(1.2, 5.2))

    assert_jax_allclose(actual, desired, 'original', original)
    ```

    This prints::

    ```
    JAX trees don't match. Actual:
    B
        z=A
            x=1.2
            y=5.2
    Desired:
    B
        z=A
            x=3.0
            y=4.0
    Test string:
    original.replace(z=original.z.replace(x=1.2, y=5.2))
    ```

    The test string can then be pasted into the test as the new desired value.

    Args:
        actual: The actual value.
        desired: The desired value.
        original_name: The variable name that contains the original value.
        original_value: The original value.  This is usually a pytree like a dataclass that has
            the same type as actual and desired, but contains different values.
        rtol: The relative tolerance of the comparisons in the assertion.
        atol: The absolute tolerance of the comparisons in the assertion.
    """
    if rtol is None:
        rtol = default_rtol
    if atol is None:
        atol = default_atol
    try:
        tree_multimap(partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
                      actual, desired)
    except Exception:
        print("JAX trees don't match. Actual:")
        print(actual)
        print("Desired:")
        print(desired)
        if original_name is not None and original_value is not None:
            print("Test string:")
            print(get_test_string(original_name, actual, original_value, rtol, atol))
        raise
def dataclass(clz: Type[~T]) ‑> Type[~T]
-
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
Examines PEP 526 annotations to determine fields. Default values for fields are provided using assignment. To mark fields as JAX static fields rather than JAX pytree fields, use the field() function. For example::
    from __future__ import annotations
    from typing import ClassVar

    from tjax import dataclass, field, Tensor
    from jax import numpy as jnp
    from jax import grad

    @dataclass
    class LearnedParameter:
        weight: Tensor
        constrain_positive: bool = field(pytree_like=False)

        minimum_positive_weight: ClassVar[Tensor] = 1e-6

        def trained(self,
                    self_bar: LearnedParameter,
                    learning_rate: float) -> LearnedParameter:
            weight_bar = self_bar.weight
            weight = self.weight - weight_bar * learning_rate
            if self.constrain_positive:
                weight = jnp.maximum(weight, self.minimum_positive_weight)
            return LearnedParameter(weight=weight,
                                    constrain_positive=self.constrain_positive)

    def loss(w: LearnedParameter) -> float:
        return jnp.square(w.weight - 3.3)

    w = LearnedParameter(2.0, True)
    w_bar = grad(loss)(w)
    new_w = w.trained(w_bar, 1e-4)
dataclass() includes a convenient replace method::

    w.replace(weight=3.4)
Since this dataclass is a pytree, all of JAX's functions that accept pytrees work with it, including iteration, differentiation, and jax.tree_util functions. Another benefit is the display of dataclasses.
print(new_w) gives::

    LearnedParameter
        weight=Jax Array ()
            2.0003
        constrain_positive=True
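As a minimal sketch of the pytree behaviour (assuming the example above has run), JAX's tree utilities operate directly on the dataclass::

    from jax import tree_util

    # Only pytree fields appear as leaves; the static field
    # constrain_positive travels as auxiliary data instead.
    leaves = tree_util.tree_leaves(new_w)            # roughly [2.0003]

    # tree_map rebuilds a LearnedParameter with every leaf transformed.
    doubled = tree_util.tree_map(lambda t: 2.0 * t, new_w)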
Expand source code
def dataclass(clz: Type[T]) -> Type[T]:
    """
    Returns the same class as was passed in, with dunder methods added based on the fields
    defined in the class.

    Examines PEP 526 annotations to determine fields.  Default values for fields are provided
    using assignment.  To mark fields as JAX static fields rather than JAX pytree fields, use the
    `field` function.

    For example::
    ```python
    from __future__ import annotations
    from typing import ClassVar

    from tjax import dataclass, field, Tensor
    from jax import numpy as jnp
    from jax import grad

    @dataclass
    class LearnedParameter:
        weight: Tensor
        constrain_positive: bool = field(pytree_like=False)

        minimum_positive_weight: ClassVar[Tensor] = 1e-6

        def trained(self,
                    self_bar: LearnedParameter,
                    learning_rate: float) -> LearnedParameter:
            weight_bar = self_bar.weight
            weight = self.weight - weight_bar * learning_rate
            if self.constrain_positive:
                weight = jnp.maximum(weight, self.minimum_positive_weight)
            return LearnedParameter(weight=weight,
                                    constrain_positive=self.constrain_positive)

    def loss(w: LearnedParameter) -> float:
        return jnp.square(w.weight - 3.3)

    w = LearnedParameter(2.0, True)
    w_bar = grad(loss)(w)
    new_w = w.trained(w_bar, 1e-4)
    ```

    `dataclass` includes a convenient replace method::

        w.replace(weight=3.4)

    Since this dataclass is a pytree, all of JAX's functions that accept pytrees work with it,
    including iteration, differentiation, and `jax.tree_util` functions.

    Another benefit is the display of dataclasses.  `print(new_w)` gives::
    ```
    LearnedParameter
        weight=Jax Array ()
            2.0003
        constrain_positive=True
    ```
    """
    # pylint: disable=protected-access

    # Apply dataclass function to clz.
    data_clz: Type[T] = dataclasses.dataclass(frozen=True)(clz)  # type: ignore

    # Partition fields into hashed, tree, and uninitialized.
    hashed_fields: List[str] = []
    tree_fields: List[str] = []
    for field_info in dataclasses.fields(data_clz):  # type: ignore
        if not field_info.init:
            continue
        if field_info.metadata.get('pytree_like', True):
            tree_fields.append(field_info.name)
        else:
            hashed_fields.append(field_info.name)

    # Generate additional methods.
    def __repr__(self: T) -> str:
        return str(self.display())

    def display(self: T, show_values: bool = True, indent: int = 0) -> str:
        retval = display_class(type(self))
        for field_info in dataclasses.fields(data_clz):  # type: ignore
            retval += display_key_and_value(
                field_info.name, getattr(self, field_info.name), "=", show_values, indent)
        return retval

    def tree_flatten(x: T) -> Tuple[Sequence[PyTree], Hashable]:
        hashed = tuple(getattr(x, name) for name in hashed_fields)
        trees = tuple(getattr(x, name) for name in tree_fields)
        return trees, hashed

    def tree_unflatten(cls: Type[T], hashed: Hashable, trees: Sequence[PyTree]) -> T:
        if not isinstance(hashed, tuple):
            raise TypeError
        hashed_args = dict(zip(hashed_fields, hashed))
        tree_args = dict(zip(tree_fields, trees))
        return cls(**hashed_args, **tree_args)

    # Assign methods to the class.
    data_clz.__repr__ = __repr__  # type: ignore
    data_clz.display = display  # type: ignore
    data_clz.tree_flatten = tree_flatten  # type: ignore
    data_clz.tree_unflatten = classmethod(tree_unflatten)  # type: ignore

    # Assign field lists to the class.
    data_clz.tree_fields = tree_fields  # type: ignore
    data_clz.hashed_fields = hashed_fields  # type: ignore

    # Register the class as a JAX PyTree.
    register_pytree_node(data_clz, tree_flatten, data_clz.tree_unflatten)  # type: ignore

    # Verify that the generated class is PyTreeLike.
    assert isinstance(data_clz, PyTreeLike)

    return data_clz
def field(pytree_like: bool = True, **kwargs: Any) ‑> cooperative_dataclasses.dataclasses.Field
-
Args
pytree_like
- Indicates whether a field is a pytree or static. Pytree fields are differentiated and traced.
kwargs
- Any of the keyword arguments from
dataclasses.field
.
Expand source code
def field(pytree_like: bool = True, **kwargs: Any) -> dataclasses.Field:
    """
    Args:
        pytree_like: Indicates whether a field is a pytree or static.  Pytree fields are
            differentiated and traced.
        kwargs: Any of the keyword arguments from `dataclasses.field`.
    """
    return dataclasses.field(metadata={**kwargs.pop('metadata', {}),
                                       'pytree_like': pytree_like},
                             **kwargs)
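A short, hypothetical sketch of declaring a static field alongside a pytree field (the class and field names are illustrative)::

    from tjax import dataclass, field, Tensor

    @dataclass
    class Layer:
        # Pytree field: traced and differentiated.
        weight: Tensor
        # Static field: hashed as auxiliary data, so changing it triggers
        # recompilation.  Extra keywords (here, default) pass through to
        # dataclasses.field.
        activation: str = field(pytree_like=False, default='relu')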
def is_scalar(x: Any) ‑> bool
-
Expand source code
def is_scalar(x: Any) -> bool:
    return isinstance(x, Number) or isinstance(x, (np.ndarray, jnp.ndarray)) and x.shape == ()
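A few illustrative calls, which follow directly from the definition above::

    from jax import numpy as jnp

    is_scalar(4.2)               # True: a Python Number
    is_scalar(jnp.zeros(()))     # True: a zero-dimensional array
    is_scalar(jnp.zeros((3,)))   # False: the shape is nonempty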
def jax_allclose(actual: PyTree, desired: PyTree, rtol: Optional[float] = None, atol: Optional[float] = None) ‑> bool
-
Args
actual
- The actual value.
desired
- The desired value.
rtol
- The relative tolerance of the comparison.
atol
- The absolute tolerance of the comparison.
Expand source code
def jax_allclose(actual: PyTree,
                 desired: PyTree,
                 rtol: Optional[float] = None,
                 atol: Optional[float] = None) -> bool:
    """
    Args:
        actual: The actual value.
        desired: The desired value.
        rtol: The relative tolerance of the comparison.
        atol: The absolute tolerance of the comparison.
    """
    if rtol is None:
        rtol = default_rtol
    if atol is None:
        atol = default_atol
    return cast(
        bool,
        tree_reduce(jnp.logical_and,
                    tree_multimap(partial(np.allclose, rtol=rtol, atol=atol), actual, desired),
                    True))
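For instance (a sketch), the comparison runs leaf by leaf over the two pytrees::

    from jax import numpy as jnp

    # Nested tuples, lists, dicts, and tjax dataclasses are compared with
    # np.allclose under the given tolerances.
    jax_allclose((jnp.ones(2), {'a': 1.0}),
                 (jnp.ones(2), {'a': 1.0 + 1e-9}))   # True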
def print_generic(*args: Any, **kwargs: Any) ‑> NoneType
-
Expand source code
def print_generic(*args: Any, **kwargs: Any) -> None:
    for value in args:
        print(display_generic(value))
    for key, value in kwargs.items():
        print(display_key_and_value(key, value, "=", True, 0))
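A small usage sketch: positional arguments are displayed generically, while keyword arguments are displayed as key=value pairs::

    from jax import numpy as jnp
    from tjax import print_generic

    print_generic(jnp.zeros(3))        # generic display of the array
    print_generic(step_size=1e-4)      # displayed as a key=value pair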
def sum_tensors(tensors: Collection[Tensor], shape: Optional[ShapeLike] = None) ‑> Tensor
-
Expand source code
def sum_tensors(tensors: Collection[Tensor],
                shape: Optional[ShapeLike] = None) -> Tensor:
    if not tensors:
        return jnp.zeros(shape)
    return reduce(add, tensors)
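The shape argument only matters when the collection is empty; a sketch::

    from jax import numpy as jnp

    sum_tensors([jnp.ones(3), 2.0 * jnp.ones(3)])   # elementwise sum: [3., 3., 3.]
    sum_tensors([], shape=(3,))                     # empty input: jnp.zeros((3,))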
Classes
class Displayable (*args, **kwds)
-
This protocol identifies classes that support the display_generic mechanism.
Expand source code
class Displayable(Protocol):
    """
    This protocol identifies classes that support the `display_generic` mechanism.
    """
    def display(self, show_values: bool = True, indent: int = 0) -> str:
        ...
Ancestors
- typing.Protocol
- typing.Generic
Methods
def display(self, show_values: bool = True, indent: int = 0) ‑> str
-
Expand source code
def display(self, show_values: bool = True, indent: int = 0) -> str:
    ...
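A minimal, hypothetical class satisfying the protocol (the name Point and its output format are illustrative)::

    class Point:
        def __init__(self, x: float, y: float):
            self.x = x
            self.y = y

        def display(self, show_values: bool = True, indent: int = 0) -> str:
            pad = ' ' * indent
            if not show_values:
                return pad + 'Point'
            return f'{pad}Point x={self.x} y={self.y}'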
class Generator (*, seed: Optional[int] = None, key: Optional[Tensor] = None, **kwargs: Any)
-
This class represents a JAX random number generator. Unlike numpy.Generator, tjax.Generator has no mutating methods. Instead, its generation methods return a new instance along with the generated tensor.
Expand source code
class Generator:
    """
    This class represents a JAX random number generator.  Unlike `numpy.Generator`,
    `tjax.Generator` has no mutating methods.  Instead, its generation methods return a new
    instance along with the generated tensor.
    """

    key: Tensor

    def __init__(self, *, seed: Optional[int] = None, key: Optional[Tensor] = None,
                 **kwargs: Any):
        super().__init__(**kwargs)
        if key is None:
            if seed is None:
                raise ValueError
            key = jax.random.PRNGKey(seed)
        object.__setattr__(self, 'key', key)

    # New methods ------------------------------------------------------------------------------
    def split(self, n: int = 2) -> List[Generator]:
        keys = jax.random.split(self.key, n)
        return [Generator(key=key) for key in keys]

    def normal(self, std_dev: Tensor, shape: Shape = ()) -> Tuple[Generator, Tensor]:
        g1, g2 = self.split()
        return g1, std_dev * jax.random.normal(g2.key, shape)

    def gamma(self, gamma_shape: Tensor, shape: Shape = ()) -> Tuple[Generator, Tensor]:
        g1, g2 = self.split()
        return g1, jax.random.gamma(g2.key, gamma_shape, shape)

    # Magic methods ----------------------------------------------------------------------------
    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Generator):
            return NotImplemented
        return jnp.all(self.key == other.key)

    def __ne__(self, other: Any) -> bool:
        if not isinstance(other, Generator):
            return NotImplemented
        return not self.__eq__(other)

    def __hash__(self) -> int:
        return hash((int(self.key[0]), int(self.key[1])))
Methods
def gamma(self, gamma_shape: Tensor, shape: Shape = ()) ‑> Tuple[Generator, Tensor]
-
Expand source code
def gamma(self, gamma_shape: Tensor, shape: Shape = ()) -> Tuple[Generator, Tensor]:
    g1, g2 = self.split()
    return g1, jax.random.gamma(g2.key, gamma_shape, shape)
def normal(self, std_dev: Tensor, shape: Shape = ()) ‑> Tuple[Generator, Tensor]
-
Expand source code
def normal(self, std_dev: Tensor, shape: Shape = ()) -> Tuple[Generator, Tensor]:
    g1, g2 = self.split()
    return g1, std_dev * jax.random.normal(g2.key, shape)
def split(self, n: int = 2) ‑> List[Generator]
-
Expand source code
def split(self, n: int = 2) -> List[Generator]:
    keys = jax.random.split(self.key, n)
    return [Generator(key=key) for key in keys]
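A usage sketch: each generation method returns a fresh generator, so the random state is threaded functionally rather than mutated::

    from tjax import Generator

    g = Generator(seed=123)
    g, x = g.normal(1.0, shape=(3,))   # x: three standard normal samples
    g, y = g.gamma(2.0)                # y: a Gamma(2.0) sample
    g1, g2 = g.split()                 # two independent generators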
class LogCotangent (cotangent: Tensor, **kwargs)
-
LogCotangent logs cotangents in a differentiated JAX function. For example::

    from jax import grad
    from jax import numpy as jnp
    from tjax import LogCotangent

    def loss(x, w, log_cotangent):
        y = x * w
        z = log_cotangent.forward(y)
        return jnp.sum(jnp.square(2.0 - z))

    x = jnp.array([1.0, 2.0])
    w = jnp.array([2.2, 3.5])
    lg_bar = grad(loss, 2)(x, w, LogCotangent.create(shape=x.shape))
    # lg_bar.cotangent now holds the transmitted cotangent.
Expand source code
class LogCotangent:
    """
    LogCotangent logs cotangents in a differentiated JAX function.  For example::

        from jax import grad
        from jax import numpy as jnp
        from tjax import LogCotangent

        def loss(x, w, log_cotangent):
            y = x * w
            z = log_cotangent.forward(y)
            return jnp.sum(jnp.square(2.0 - z))

        x = jnp.array([1.0, 2.0])
        w = jnp.array([2.2, 3.5])
        lg_bar = grad(loss, 2)(x, w, LogCotangent.create(shape=x.shape))
        # lg_bar.cotangent now holds the transmitted cotangent.
    """

    cotangent: Tensor

    @classmethod
    def create(cls: Type[T], shape: Tuple[int, ...]) -> T:
        """
        Factory to create a LogCotangent object.

        Args:
            shape: The shape of the transmitted tensor and its cotangent.
        """
        return LogCotangent(jnp.zeros(shape))

    def forward(self, x: Tensor) -> Tensor:
        """
        This method is called in the forward pass.  It will automatically log the cotangent in
        the backward pass.

        Args:
            x: The tensor to be transmitted.
        Returns:
            The same tensor that was passed in, x.
        """
        if x.shape != self.cotangent.shape:
            raise ValueError
        return x + self.cotangent
Static methods
def create(shape: Tuple[int, ...]) ‑> ~T
-
Factory to create a LogCotangent object.
Args
shape
- The shape of the transmitted tensor and its cotangent.
Expand source code
@classmethod
def create(cls: Type[T], shape: Tuple[int, ...]) -> T:
    """
    Factory to create a LogCotangent object.

    Args:
        shape: The shape of the transmitted tensor and its cotangent.
    """
    return LogCotangent(jnp.zeros(shape))
Methods
def forward(self, x: Tensor) ‑> Tensor
-
This method is called in the forward pass. It will automatically log the cotangent in the backward pass.
Args
x
- The tensor to be transmitted.
Returns: The same tensor that was passed in, x.
Expand source code
def forward(self, x: Tensor) -> Tensor:
    """
    This method is called in the forward pass.  It will automatically log the cotangent in the
    backward pass.

    Args:
        x: The tensor to be transmitted.
    Returns:
        The same tensor that was passed in, x.
    """
    if x.shape != self.cotangent.shape:
        raise ValueError
    return x + self.cotangent
class Partial (func: Callable[..., R], /, *args: Any, callable_is_static: bool = True, static_argnums: Tuple[int, ...] = (), static_kwargs: Mapping[str, Any] = {}, **kwargs: Any)
-
A version of functools.partial that returns a pytree.
Use it for partial function evaluation in a way that is compatible with JAX's transformations, e.g., Partial(func, *args, **kwargs).
Expand source code
class Partial(partial, Generic[R]):
    """
    A version of functools.partial that returns a pytree.

    Use it for partial function evaluation in a way that is compatible with JAX's
    transformations, e.g., ``Partial(func, *args, **kwargs)``.
    """

    callable_is_static: bool
    static_argnums: Tuple[int, ...]
    static_kwargs: Mapping[str, Any]

    def __new__(cls: Type[T],
                func: Callable[..., R],
                /,
                *args: Any,
                callable_is_static: bool = True,
                static_argnums: Tuple[int, ...] = (),
                static_kwargs: Mapping[str, Any] = {},
                **kwargs: Any) -> T:
        """
        Args:
            func: The function being applied.
            args: The applied positional arguments.
            callable_is_static: Whether the function callable is static.
            static_argnums: The indices of the applied positional arguments that are static.
            static_kwargs: The key-value pairs representing applied keyword arguments that are
                static.
            kwargs: The applied keyword arguments.
        """
        if callable_is_static and isinstance(func, Partial):
            raise TypeError
        retval = super().__new__(cls, func, *args, **kwargs)  # type: ignore
        retval.callable_is_static = callable_is_static
        retval.static_argnums = set(static_argnums)
        retval.static_kwargs = static_kwargs
        return retval

    def tree_flatten(self: Partial[R]) -> Tuple[Sequence[PyTree], Hashable]:
        static_args = []
        tree_args = []

        def _append(is_static: bool, value: Any) -> None:
            if is_static:
                static_args.append(value)
            else:
                tree_args.append(value)

        _append(self.callable_is_static, self.func)
        for i, value in enumerate(self.args):
            _append(i in self.static_argnums, value)
        return ((list(reversed(tree_args)), self.keywords),
                (self.callable_is_static, self.static_argnums, list(reversed(static_args)),
                 self.static_kwargs))

    @classmethod
    def tree_unflatten(cls: Type[R], static: Hashable, trees: Sequence[PyTree]) -> Partial[R]:
        if not isinstance(static, Iterable):
            raise RuntimeError
        callable_is_static, static_argnums, static_args, static_kwargs = static
        if not isinstance(static_args, list):
            raise RuntimeError
        tree_args, tree_kwargs = trees
        if not isinstance(tree_args, list):
            raise RuntimeError
        if not isinstance(tree_kwargs, dict):
            raise RuntimeError
        tree_kwargs = cast(Dict[str, Any], tree_kwargs)
        args = []
        for i in range(len(static_args) + len(tree_args)):
            if i == 0:
                is_static = callable_is_static
            else:
                is_static = i - 1 in static_argnums
            if is_static:
                args.append(static_args.pop())
            else:
                args.append(tree_args.pop())
        return Partial[R](*args,
                          callable_is_static=callable_is_static,
                          static_argnums=static_argnums,
                          static_kwargs=static_kwargs,
                          **tree_kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> R:
        return super().__call__(*args, **self.static_kwargs, **kwargs)
Ancestors
- functools.partial
- typing.Generic
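A usage sketch: because Partial is a pytree, it can be passed through JAX transformations, with applied arguments traced and the callable kept static by default::

    from jax import jit
    from jax import numpy as jnp
    from tjax import Partial

    def apply_function(f, x):
        return f(x)

    # w is stored inside the Partial as a pytree leaf, so jit traces it.
    f = Partial(lambda w, x: jnp.sum(w * x), jnp.array([1.0, 2.0]))
    y = jit(apply_function)(f, jnp.array([3.0, 4.0]))   # 1*3 + 2*4 = 11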
class custom_vjp (*args, **kwds)
-
Set up a JAX-transformable function for a custom VJP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad()) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp, which defines the custom VJP rule.
This decorator precludes the use of forward-mode automatic differentiation.
This is a shim class to work around an issue with JAX's custom_vjp. It provides both:
- static arguments, and
- nondifferentiable arguments.
Static arguments are passed in to both the forward and the backward pass. They must be hashable. Different values for static arguments will generate recompilation.
The generated backward pass will generate zeroed-out cotangents. Ideally, no corresponding cotangents would be created, but such a change would have to be done in JAX itself.
For example::
    from functools import partial

    from tjax import custom_vjp
    from jax import numpy as jnp

    @partial(custom_vjp, nondiff_argnums=2)
    def f(x, y, z):
        return jnp.sin(x) * y + z

    def f_fwd(x, y, z):
        return f(x, y, z), (jnp.cos(x), jnp.sin(x), y)

    def f_bwd(residuals, output_bar):
        cos_x, sin_x, y = residuals
        x_bar = cos_x * output_bar * y
        y_bar = sin_x * output_bar
        # z_bar is not returned because it's nondifferentiable.
        return x_bar, y_bar

    f.defvjp(f_fwd, f_bwd)
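Reverse-mode differentiation then uses the custom rule; a brief sketch (the numeric value is illustrative)::

    from jax import grad

    # d/dx [sin(x) * y + z] = cos(x) * y, computed by f_bwd.
    grad(f)(0.5, 2.0, 1.0)   # approximately 2.0 * cos(0.5)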
Args
fun
- The function to decorate.
static_argnums
- The indices of the static arguments.
nondiff_argnums
- The indices of the nondifferentiable arguments.
Expand source code
class custom_vjp(Generic[R]):
    """
    Set up a JAX-transformable function for a custom VJP rule definition.

    This class is meant to be used as a function decorator.  Instances are callables that behave
    similarly to the underlying function to which the decorator was applied, except when a
    reverse-mode differentiation transformation (like `jax.grad()`) is applied, in which case a
    custom user-supplied VJP rule function is used instead of tracing into and performing
    automatic differentiation of the underlying function’s implementation.  There is a single
    instance method, defvjp, which defines the custom VJP rule.

    This decorator precludes the use of forward-mode automatic differentiation.

    This is a shim class to work around an
    [issue with JAX's custom_vjp](https://github.com/google/jax/issues/2912).  It provides both:

    - static arguments, and
    - nondifferentiable arguments.

    Static arguments are passed in to both the forward and the backward pass.  They must be
    hashable.  Different values for static arguments will generate recompilation.

    The generated backward pass will generate zeroed-out cotangents.  Ideally, no corresponding
    cotangents would be created, but such a change would have to be done in JAX itself.

    For example::

        from functools import partial

        from tjax import custom_vjp
        from jax import numpy as jnp

        @partial(custom_vjp, nondiff_argnums=2)
        def f(x, y, z):
            return jnp.sin(x) * y + z

        def f_fwd(x, y, z):
            return f(x, y, z), (jnp.cos(x), jnp.sin(x), y)

        def f_bwd(residuals, output_bar):
            cos_x, sin_x, y = residuals
            x_bar = cos_x * output_bar * y
            y_bar = sin_x * output_bar
            # z_bar is not returned because it's nondifferentiable.
            return x_bar, y_bar

        f.defvjp(f_fwd, f_bwd)
    """

    def __init__(self,
                 fun: Callable[..., R],
                 static_argnums: Union[int, Tuple[int, ...]] = (),
                 nondiff_argnums: Union[int, Tuple[int, ...]] = ()):
        """
        Args:
            fun: The function to decorate.
            static_argnums: The indices of the static arguments.
            nondiff_argnums: The indices of the nondifferentiable arguments.
        """
        static_argnums = as_tuple(static_argnums)
        nondiff_argnums = as_tuple(nondiff_argnums)
        if intersection := set(static_argnums) & set(nondiff_argnums):
            raise ValueError(
                f"Arguments {intersection} cannot be both static and nondifferentiable.")
        self.nondiff_argnums = nondiff_argnums
        self.vjp = jax_custom_vjp(fun, nondiff_argnums=static_argnums)

    def defvjp(self, fwd: Callable[..., Tuple[R, Any]], bwd: Callable[..., Any]) -> None:
        """
        Implement the custom forward and backward passes of the custom derivative.

        Args:
            fwd: The custom forward pass.
            bwd: The custom backward pass.  Cotangents for the nondifferentiable arguments
                should not be provided by the user-provided backward pass.
        """
        def new_fwd(*args: Any) -> Tuple[R, Any]:
            zeroed_args = tuple([tree_map(jnp.zeros_like, args[i])
                                 for i in self.nondiff_argnums])
            primal, internal_residuals = fwd(*args)
            return primal, (zeroed_args, internal_residuals)

        def new_bwd(residuals: Any, output_bar: R) -> Any:
            zeroed_args, internal_residuals = residuals
            input_bar = bwd(internal_residuals, output_bar)
            input_bar = list(input_bar)
            for i, index in enumerate(self.nondiff_argnums):
                input_bar[index: index] = [zeroed_args[i]]
            return tuple(input_bar)

        self.vjp.defvjp(new_fwd, new_bwd)

    def __call__(self, *args: Any) -> R:
        return self.vjp(*args)

    def __get__(self, instance: Any, owner: Any = None) -> Callable[..., R]:
        # https://github.com/google/jax/issues/2483
        return self.vjp.__get__(instance, owner)
Ancestors
- typing.Generic
Methods
def defvjp(self, fwd: Callable[..., Tuple[~R, Any]], bwd: Callable[..., Any]) ‑> NoneType
-
Implement the custom forward and backward passes of the custom derivative.
Args
fwd
- The custom forward pass.
bwd
- The custom backward pass. Cotangents for the nondifferentiable arguments should not be provided by the user-provided backward pass.
Expand source code
def defvjp(self, fwd: Callable[..., Tuple[R, Any]], bwd: Callable[..., Any]) -> None:
    """
    Implement the custom forward and backward passes of the custom derivative.

    Args:
        fwd: The custom forward pass.
        bwd: The custom backward pass.  Cotangents for the nondifferentiable arguments should
            not be provided by the user-provided backward pass.
    """
    def new_fwd(*args: Any) -> Tuple[R, Any]:
        zeroed_args = tuple([tree_map(jnp.zeros_like, args[i])
                             for i in self.nondiff_argnums])
        primal, internal_residuals = fwd(*args)
        return primal, (zeroed_args, internal_residuals)

    def new_bwd(residuals: Any, output_bar: R) -> Any:
        zeroed_args, internal_residuals = residuals
        input_bar = bwd(internal_residuals, output_bar)
        input_bar = list(input_bar)
        for i, index in enumerate(self.nondiff_argnums):
            input_bar[index: index] = [zeroed_args[i]]
        return tuple(input_bar)

    self.vjp.defvjp(new_fwd, new_bwd)