Skip to content

Commit e5aa280

Browse files
committed
DOC: autojit notes
1 parent 28a364d commit e5aa280

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,25 @@ def jax_autojit(
540540
See Also
541541
--------
542542
jax.jit : JAX JIT compilation function.
543+
544+
Notes
545+
-----
546+
These are useful choices *for testing purposes only*, which is how this function is
547+
intended to be used. The output of ``jax.jit`` is a C++ level callable, that
548+
directly dispatches to the compiled kernel after the initial call. In comparison,
549+
``jax_autojit`` incurs in a much higher dispatch time.
550+
551+
Additionally, consider::
552+
553+
def f(x: Array, y: float, plus: bool) -> Array:
554+
return x + y if plus else x - y
555+
556+
j1 = jax.jit(f, static_argnames="plus")
557+
j2 = jax_autojit(f)
558+
559+
In the above example, ``j2`` requires a lot less setup to be tested effectively than
560+
``j1``, but on the flip side it means that it will be re-traced for every different
561+
value of ``y``, which likely makes it not fit for purpose in production.
543562
"""
544563
import jax
545564

src/array_api_extra/testing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def lazy_xp_function( # type: ignore[explicit-any]
9696
jax_jit : bool, optional
9797
Set to True to replace `func` with a smart variant of ``jax.jit(func)`` after
9898
calling the :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``.
99+
This is the default behaviour.
99100
Set to False if `func` is only compatible with eager (non-jitted) JAX.
100101
101102
Unlike with vanilla ``jax.jit``, all arguments and return types that are not JAX

0 commit comments

Comments
 (0)