File tree Expand file tree Collapse file tree 2 files changed +20
-0
lines changed Expand file tree Collapse file tree 2 files changed +20
-0
lines changed Original file line number Diff line number Diff line change @@ -540,6 +540,25 @@ def jax_autojit(
540
540
See Also
541
541
--------
542
542
jax.jit : JAX JIT compilation function.
543
+
544
+ Notes
545
+ -----
546
+ These are useful choices *for testing purposes only*, which is how this function is
547
+ intended to be used. The output of ``jax.jit`` is a C++ level callable, that
548
+ directly dispatches to the compiled kernel after the initial call. In comparison,
549
+ ``jax_autojit`` incurs in a much higher dispatch time.
550
+
551
+ Additionally, consider::
552
+
553
+ def f(x: Array, y: float, plus: bool) -> Array:
554
+ return x + y if plus else x - y
555
+
556
+ j1 = jax.jit(f, static_argnames="plus")
557
+ j2 = jax_autojit(f)
558
+
559
+ In the above example, ``j2`` requires a lot less setup to be tested effectively than
560
+ ``j1``, but on the flip side it means that it will be re-traced for every different
561
+ value of ``y``, which likely makes it not fit for purpose in production.
543
562
"""
544
563
import jax
545
564
Original file line number Diff line number Diff line change @@ -96,6 +96,7 @@ def lazy_xp_function( # type: ignore[explicit-any]
96
96
jax_jit : bool, optional
97
97
Set to True to replace `func` with a smart variant of ``jax.jit(func)`` after
98
98
calling the :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``.
99
+ This is the default behaviour.
99
100
Set to False if `func` is only compatible with eager (non-jitted) JAX.
100
101
101
102
Unlike with vanilla ``jax.jit``, all arguments and return types that are not JAX
You can’t perform that action at this time.
0 commit comments