Hi @fehiepsi,
thanks again for your help. I am having trouble finding the log probability in the trace, but I recreated the error in a small example and got get_trace() to run. Here is a rundown:
Sampling from SineBivariateVonMises
data = SineBivariateVonMises(0, 0, 1, 1, -1).sample(PRNGKey(42), (1000, ))
Model
@numpyro.handlers.reparam(
    config={
        "phi_loc": CircularReparam(),
        "psi_loc": CircularReparam(),
    }
)
def min_example(data_2d):
    phi_loc = sample('phi_loc', VonMises(0, 10))
    psi_loc = sample('psi_loc', VonMises(0, 10))
    phi_conc = sample('phi_conc', Beta(1, 1))
    psi_conc = sample('psi_conc', Beta(1, 1))
    conc = -sample('conc', HalfNormal(1))
    depInd = SineBivariateVonMises(phi_loc, psi_loc, 70 * phi_conc, 70 * psi_conc, conc)
    obs = sample('obs', depInd, obs=data_2d)
Run model
rng_key = PRNGKey(0)
num_warmup, num_samples = 100, 200
# Run NUTS.
kernel = NUTS(min_example)
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
)
mcmc.run(rng_key, data)
mcmc.print_summary()
posterior_samples = mcmc.get_samples()
leads to:
Error
RuntimeError: Cannot find valid initial parameters. Please check your model again.
I can also provide the full call stack.
Trace
trace = numpyro.handlers.trace(numpyro.handlers.seed(min_example, rng_seed=0)).get_trace(data)
Error
File .../env/lib/python3.9/site-packages/numpyro/distributions/distribution.py:250, in Distribution.sample(self, key, sample_shape)
238 def sample(self, key, sample_shape=()):
239 """
240 Returns a sample from the distribution having shape given by
241 `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
(...)
248 :rtype: numpy.ndarray
249 """
--> 250 raise NotImplementedError
NotImplementedError:
This error seems to be caused by the reparam configuration on the model. If I remove it, the trace works, but the model then gives a warning about bad performance when using continuous variables for circular parameters.
Trace output with removed circular reparam
OrderedDict([('phi_loc',
{'type': 'sample',
'name': 'phi_loc',
'fn': <numpyro.distributions.directional.VonMises at 0x7fe43c24ab20>,
'args': (),
'kwargs': {'rng_key': DeviceArray([2718843009, 1272950319], dtype=uint32),
'sample_shape': ()},
'value': DeviceArray(0.29627275, dtype=float32),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('psi_loc',
{'type': 'sample',
'name': 'psi_loc',
'fn': <numpyro.distributions.directional.VonMises at 0x7fe3f87c6880>,
'args': (),
'kwargs': {'rng_key': DeviceArray([1278412471, 2182328957], dtype=uint32),
'sample_shape': ()},
'value': DeviceArray(-0.03310037, dtype=float32),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('phi_conc',
{'type': 'sample',
'name': 'phi_conc',
'fn': <numpyro.distributions.continuous.Beta at 0x7fe4204706a0>,
'args': (),
'kwargs': {'rng_key': DeviceArray([4104543539, 3483300570], dtype=uint32),
'sample_shape': ()},
'value': DeviceArray(0.2734751, dtype=float32),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('psi_conc',
{'type': 'sample',
'name': 'psi_conc',
'fn': <numpyro.distributions.continuous.Beta at 0x7fe3f8457220>,
'args': (),
'kwargs': {'rng_key': DeviceArray([1194623263, 2038155241], dtype=uint32),
'sample_shape': ()},
'value': DeviceArray(0.6324667, dtype=float32),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('conc',
{'type': 'sample',
'name': 'conc',
'fn': <numpyro.distributions.continuous.HalfNormal at 0x7fe4205215b0>,
'args': (),
'kwargs': {'rng_key': DeviceArray([2205739499, 3850766070], dtype=uint32),
'sample_shape': ()},
'value': DeviceArray(0.6486334, dtype=float32),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('obs',
{'type': 'sample',
'name': 'obs',
'fn': <numpyro.distributions.directional.SineBivariateVonMises at 0x7fe3f824cd60>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': DeviceArray([[-0.7632046 , -2.5725281 ],
[-1.4259993 , 0.9532311 ],
[ 0.84518456, 0.53240395],
...,
[ 0.89588714, -0.06103492],
[ 0.50887036, -0.34725928],
[-1.3985238 , 0.53051925]], dtype=float32),
'scale': None,
'is_observed': True,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}})])
Now I am a bit confused about where to find the log probability in this output. Sorry, I have never used trace before.
Thank you!