BlackJAX and Oryx pattern

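BlackJAX expects a log density written in JAX: a function that maps a position (an array or pytree of parameters) to a scalar log density. NUTS differentiates this function with JAX, so it must use jax.numpy operations rather than plain NumPy: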

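import jax
import jax.numpy as jnp
import blackjax

def logdensity(position):
    # Return log p(position, data) as a scalar, up to an additive constant.
    # Data are captured from the enclosing scope, not passed as arguments.
    ...

# initial_position must have the same pytree structure that logdensity expects.
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity)
(last_state, params), _ = warmup.run(jax.random.key(1), initial_position, num_steps=1000)
# params holds the adapted step_size and inverse_mass_matrix;
# the kernel is called as kernel(rng_key, state) -> (new_state, info).
kernel = blackjax.nuts(logdensity, **params).step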

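Oryx is useful when writing the model as a generative program makes it easier to follow. random_variable tags named draws inside an ordinary JAX sampling function, and joint_log_prob turns that function into a log density over those named values: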

from oryx.core.ppl import random_variable, joint_log_prob

For routine ROS examples, CmdStanPy/PyData tools are usually clearer. BlackJAX/Oryx pages are included where they add value: hierarchical models, custom likelihoods, sampler mechanics, and simulation-heavy examples.