BlackJAX and Oryx pattern
BlackJAX expects a JAX-compatible log density:
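```python
import jax
import jax.numpy as jnp
import blackjax

def logdensity(position):
    # Return log p(parameters, data), up to an additive constant.
    ...

# `initial_position` is a pytree of starting values with the same structure as `position`.
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity)
(last_state, params), _ = warmup.run(jax.random.key(1), initial_position, num_steps=1000)
# kernel(rng_key, state) -> (new_state, info); start sampling from `last_state`.
kernel = blackjax.nuts(logdensity, **params).step
```

Oryx is relevant when it clarifies the generative model: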
```python
from oryx.core.ppl import random_variable, joint_log_prob
```

For routine ROS examples, CmdStanPy/PyData tools are usually clearer. BlackJAX/Oryx pages are included where they add value: hierarchical models, custom likelihoods, sampler mechanics, and simulation-heavy examples.