Simplify dispatch of JAX random variables by handling rng split automatically #1315


Merged: 1 commit, Mar 24, 2025
160 changes: 57 additions & 103 deletions pytensor/link/jax/dispatch/random.py
@@ -105,14 +105,24 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
         assert_size_argument_jax_compatible(node)
 
         def sample_fn(rng, size, *parameters):
-            return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
+            rng_key = rng["jax_state"]
+            rng_key, sampling_key = jax.random.split(rng_key, 2)
+            rng["jax_state"] = rng_key
+            sample = jax_sample_fn(op, node=node)(
+                sampling_key, size, out_dtype, *parameters
+            )
+            return (rng, sample)
 
     else:
 
         def sample_fn(rng, size, *parameters):
-            return jax_sample_fn(op, node=node)(
-                rng, static_size, out_dtype, *parameters
+            rng_key = rng["jax_state"]
+            rng_key, sampling_key = jax.random.split(rng_key, 2)
+            rng["jax_state"] = rng_key
+            sample = jax_sample_fn(op, node=node)(
+                sampling_key, static_size, out_dtype, *parameters
             )
+            return (rng, sample)
 
     return sample_fn
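The hunk above moves the key bookkeeping that every per-op sampler used to repeat into the generic wrapper. Below is a condensed, standalone sketch of the resulting split of responsibilities; the names make_sample_fn, jax_sampler, and normal_sampler are illustrative and not part of the codebase. The wrapper owns the shared rng["jax_state"] entry and performs one jax.random.split per draw, while each dispatch implementation becomes a pure function of a JAX key that returns only the sample.

import jax
import jax.numpy as jnp


def make_sample_fn(jax_sampler, out_dtype):
    """Wrapper in the spirit of jax_funcify_RandomVariable after this change."""

    def sample_fn(rng, size, *parameters):
        rng_key = rng["jax_state"]
        # Split once here so individual samplers never touch the rng dict.
        rng_key, sampling_key = jax.random.split(rng_key, 2)
        rng["jax_state"] = rng_key
        sample = jax_sampler(sampling_key, size, out_dtype, *parameters)
        return rng, sample  # callers still receive the updated rng state

    return sample_fn


def normal_sampler(rng_key, size, dtype, loc, scale):
    """A per-op implementation is now a pure function of the key."""
    return loc + jax.random.normal(rng_key, shape=size, dtype=dtype) * scale


rng = {"jax_state": jax.random.PRNGKey(0)}
draw = make_sample_fn(normal_sampler, jnp.float32)
rng, x = draw(rng, (3,), 0.0, 1.0)

With this split, adding a new distribution only requires registering a key-to-sample function; the stateful plumbing is written once.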

@@ -133,12 +143,9 @@ def jax_sample_fn_generic(op, node):
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-        sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+    def sample_fn(rng_key, size, dtype, *parameters):
+        sample = jax_op(rng_key, *parameters, shape=size, dtype=dtype)
+        return sample
 
     return sample_fn

@@ -159,29 +166,23 @@ def jax_sample_fn_loc_scale(op, node):
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, *parameters):
         loc, scale = parameters
         if size is None:
             size = jax.numpy.broadcast_arrays(loc, scale)[0].shape
-        sample = loc + jax_op(sampling_key, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = loc + jax_op(rng_key, size, dtype) * scale
+        return sample
 
     return sample_fn
 
 
 @jax_sample_fn.register(ptr.MvNormalRV)
 def jax_sample_mvnormal(op, node):
-    def sample_fn(rng, size, dtype, mean, cov):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, mean, cov):
         sample = jax.random.multivariate_normal(
-            sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method
+            rng_key, mean, cov, shape=size, dtype=dtype, method=op.method
         )
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return sample
 
     return sample_fn

@@ -191,12 +192,9 @@ def jax_sample_fn_bernoulli(op, node):
     """JAX implementation of `BernoulliRV`."""
 
     # We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
-    def sample_fn(rng, size, dtype, p):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-        sample = jax.random.bernoulli(sampling_key, p, shape=size)
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+    def sample_fn(rng_key, size, dtype, p):
+        sample = jax.random.bernoulli(rng_key, p, shape=size)
+        return sample
 
     return sample_fn

@@ -206,14 +204,10 @@ def jax_sample_fn_categorical(op, node):
     """JAX implementation of `CategoricalRV`."""
 
    # We need a separate dispatch because Categorical expects logits in JAX
-    def sample_fn(rng, size, dtype, p):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-
+    def sample_fn(rng_key, size, dtype, p):
         logits = jax.scipy.special.logit(p)
-        sample = jax.random.categorical(sampling_key, logits=logits, shape=size)
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = jax.random.categorical(rng_key, logits=logits, shape=size)
+        return sample
 
     return sample_fn

@@ -233,15 +227,10 @@ def jax_sample_fn_uniform(op, node):
         name = "randint"
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, *parameters):
         minval, maxval = parameters
-        sample = jax_op(
-            sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
-        )
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = jax_op(rng_key, shape=size, dtype=dtype, minval=minval, maxval=maxval)
+        return sample
 
     return sample_fn

@@ -258,14 +247,11 @@ def jax_sample_fn_shape_scale(op, node):
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, shape, scale):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, shape, scale):
         if size is None:
             size = jax.numpy.broadcast_arrays(shape, scale)[0].shape
-        sample = jax_op(sampling_key, shape, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = jax_op(rng_key, shape, size, dtype) * scale
+        return sample
 
     return sample_fn

@@ -274,14 +260,11 @@ def sample_fn(rng, size, dtype, shape, scale):
 def jax_sample_fn_exponential(op, node):
     """JAX implementation of `ExponentialRV`."""
 
-    def sample_fn(rng, size, dtype, scale):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, scale):
         if size is None:
             size = jax.numpy.asarray(scale).shape
-        sample = jax.random.exponential(sampling_key, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = jax.random.exponential(rng_key, size, dtype) * scale
+        return sample
 
     return sample_fn

@@ -290,14 +273,11 @@ def sample_fn(rng, size, dtype, scale):
 def jax_sample_fn_t(op, node):
     """JAX implementation of `StudentTRV`."""
 
-    def sample_fn(rng, size, dtype, df, loc, scale):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, df, loc, scale):
         if size is None:
             size = jax.numpy.broadcast_arrays(df, loc, scale)[0].shape
-        sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
+        return sample
 
     return sample_fn

@@ -315,10 +295,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
             "A default JAX rewrite should have materialized the implicit arange"
         )
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-
+    def sample_fn(rng_key, size, dtype, *parameters):
         if op.has_p_param:
             a, p, core_shape = parameters
         else:
@@ -327,9 +304,7 @@ def sample_fn(rng, size, *parameters):
         core_shape = tuple(np.asarray(core_shape)[(0,) * batch_ndim])
 
         if batch_ndim == 0:
-            sample = jax.random.choice(
-                sampling_key, a, shape=core_shape, replace=False, p=p
-            )
+            sample = jax.random.choice(rng_key, a, shape=core_shape, replace=False, p=p)
 
         else:
             if size is None:
@@ -345,7 +320,7 @@ def sample_fn(rng, size, *parameters):
             if p is not None:
                 p = jax.numpy.broadcast_to(p, size + p.shape[batch_ndim:])
 
-            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+            batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
 
             # Ravel the batch dimensions because vmap only works along a single axis
             raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
@@ -366,8 +341,7 @@ def sample_fn(rng, size, *parameters):
             # Reshape the batch dimensions
             sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
 
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return sample
 
     return sample_fn

@@ -378,9 +352,7 @@ def jax_sample_fn_permutation(op, node):
 
     batch_ndim = op.batch_ndim(node)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, *parameters):
         (x,) = parameters
         if batch_ndim:
             # jax.random.permutation has no concept of batch dims
@@ -389,17 +361,16 @@ def sample_fn(rng, size, *parameters):
             else:
                 x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:])
 
-            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+            batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
             raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
             raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
                 batch_sampling_keys, raveled_batch_x
             )
             sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
         else:
-            sample = jax.random.permutation(sampling_key, x)
+            sample = jax.random.permutation(rng_key, x)
 
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return sample
 
     return sample_fn
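For ops whose JAX primitives have no native batch support (choice without replacement, permutation), the dispatch still fans the single key out into one subkey per batch entry and maps the sampler with jax.vmap, as the two hunks above show. Below is a standalone sketch of that split-then-vmap pattern with assumed toy shapes; none of the names or values come from the PR itself.

import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(0)
size = (2, 3)                      # batch shape of independent draws
x = jnp.arange(5)                  # core input, shared by every batch entry
x_batched = jnp.broadcast_to(x, size + x.shape)

# One subkey per batch entry, flattened so vmap can map over a single axis.
batch_keys = jax.random.split(key, np.prod(size))
raveled_x = x_batched.reshape((-1,) + x.shape)

# Each (key, row) pair produces one independent permutation.
raveled_samples = jax.vmap(jax.random.permutation)(batch_keys, raveled_x)
samples = raveled_samples.reshape(size + x.shape)  # back to (2, 3, 5)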

@@ -414,15 +385,9 @@ def jax_sample_fn_binomial(op, node):
 
     from numpyro.distributions.util import binomial
 
-    def sample_fn(rng, size, dtype, n, p):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-
-        sample = binomial(key=sampling_key, n=n, p=p, shape=size)
-
-        rng["jax_state"] = rng_key
-
-        return (rng, sample)
+    def sample_fn(rng_key, size, dtype, n, p):
+        sample = binomial(key=rng_key, n=n, p=p, shape=size)
+        return sample
 
     return sample_fn

@@ -437,15 +402,9 @@ def jax_sample_fn_multinomial(op, node):
 
     from numpyro.distributions.util import multinomial
 
-    def sample_fn(rng, size, dtype, n, p):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-
-        sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
-
-        rng["jax_state"] = rng_key
-
-        return (rng, sample)
+    def sample_fn(rng_key, size, dtype, n, p):
+        sample = multinomial(key=rng_key, n=n, p=p, shape=size)
+        return sample
 
     return sample_fn

@@ -460,17 +419,12 @@ def jax_sample_fn_vonmises(op, node):
 
     from numpyro.distributions.util import von_mises_centered
 
-    def sample_fn(rng, size, dtype, mu, kappa):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-
+    def sample_fn(rng_key, size, dtype, mu, kappa):
         sample = von_mises_centered(
-            key=sampling_key, concentration=kappa, shape=size, dtype=dtype
+            key=rng_key, concentration=kappa, shape=size, dtype=dtype
         )
         sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
 
-        rng["jax_state"] = rng_key
-
-        return (rng, sample)
+        return sample
 
     return sample_fn
2 changes: 1 addition & 1 deletion tests/link/jax/test_random.py
@@ -796,7 +796,7 @@ def rng_fn(cls, rng, size):
     @jax_sample_fn.register(CustomRV)
     def jax_sample_fn_custom(op, node):
         def sample_fn(rng, size, dtype, *parameters):
-            return (rng, 0)
+            return 0
 
         return sample_fn
