AttributeError: 'function' object has no attribute 'log_prob'

Running into this error: AttributeError: 'function' object has no attribute 'log_prob'

I can't find where it's originating. Can someone offer some thoughts? This is my traceback:

```
    svi.step(CodeEvent[0].view(-1, 1, 1), CodeEvent[1].view(-1, 1, 1))
  File "/home/jeremy/anaconda3/envs/cibo/lib/python3.7/site-packages/pyro/infer/svi.py", line 99, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/jeremy/anaconda3/envs/cibo/lib/python3.7/site-packages/pyro/infer/trace_elbo.py", line 125, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
  File "/home/jeremy/anaconda3/envs/cibo/lib/python3.7/site-packages/pyro/infer/elbo.py", line 163, in _get_traces
    yield self._get_trace(model, guide, *args, **kwargs)
  File "/home/jeremy/anaconda3/envs/cibo/lib/python3.7/site-packages/pyro/infer/trace_elbo.py", line 52, in _get_trace
    "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
  File "/home/jeremy/anaconda3/envs/cibo/lib/python3.7/site-packages/pyro/infer/enum.py", line 51, in get_importance_trace
    model_trace.compute_log_prob()
  File "/home/jeremy/anaconda3/envs/cibo/lib/python3.7/site-packages/pyro/poutine/trace_struct.py", line 163, in compute_log_prob
    log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
AttributeError: 'function' object has no attribute 'log_prob'
```

If you don't post your code, it's hard to make sense of the error.

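My guess is that you passed a (stochastic) function instead of a distribution to a sample statement and then ran an inference algorithm that computes log probabilities (e.g. SVI or HMC). If so, that will not work: the last frame of your traceback shows Pyro calling `site["fn"].log_prob(...)` on whatever was passed to the sample statement, and a plain function has no such method. You can use a stochastic function to generate samples. However, if you want to compute log probabilities, you have to implement a `log_prob` method for your stochastic function.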
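To make the failure mode concrete, here is a minimal pure-Python sketch that does not use Pyro at all. The helper `compute_site_log_prob` is hypothetical: it only mimics the failing line from the traceback, `site["fn"].log_prob(site["value"], ...)`. The names `coin_flip` and `CoinFlip` are illustrative, not Pyro APIs. A plain function works fine as a sampler but fails as soon as a log probability is requested, while an object that also implements `log_prob` passes.

```python
import math
import random


def compute_site_log_prob(fn, value):
    # Hypothetical stand-in for the failing line in the traceback:
    #   log_p = site["fn"].log_prob(site["value"], ...)
    # Whatever was passed to the sample statement must provide log_prob.
    return fn.log_prob(value)


def coin_flip(p=0.3):
    # A plain stochastic function: fine for drawing samples...
    return 1.0 if random.random() < p else 0.0


class CoinFlip:
    """Hypothetical stochastic function that also implements log_prob."""

    def __init__(self, p):
        self.p = p

    def __call__(self):
        # Sampling behaves like coin_flip above.
        return 1.0 if random.random() < self.p else 0.0

    def log_prob(self, value):
        # Bernoulli log-likelihood of a 0/1 value.
        return math.log(self.p) if value == 1.0 else math.log(1.0 - self.p)


try:
    compute_site_log_prob(coin_flip, 1.0)
except AttributeError as err:
    print(err)  # 'function' object has no attribute 'log_prob'

print(compute_site_log_prob(CoinFlip(0.3), 1.0))  # log(0.3), about -1.204
```

In real Pyro code you would usually not hand-roll this; using one of the library's distribution classes gives you both sampling and `log_prob`.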

Apologies, it was a stupid mistake on my part. One of those mistakes where I stared at it and couldn’t find it for nearly an hour, posted here out of desperation, and then discovered the problem 15 minutes later:

This was my error:

```python
pyro.sample("obs_val", dist.Bernoulli(outVal.view(y.shape[0])).independent, obs=y.view(y.shape[0]))
```

And the fix:

```python
pyro.sample("obs_val", dist.Bernoulli(outVal.view(y.shape[0])).independent(1), obs=y.view(y.shape[0]))
```
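The bug itself is plain Python: without the parentheses, `.independent` evaluates to the method object, not the distribution the method would return, so the sample statement receives something with no `log_prob`. The sketch below reproduces this with hypothetical pure-Python stand-ins (`Bernoulli`, `Independent`), not Pyro's real classes. As with Pyro's `independent(1)`, the wrapper's `log_prob` sums the per-element log probabilities over the reinterpreted batch dimension. (In later Pyro releases this method is named `to_event`, so the fixed call would be `.to_event(1)`.)

```python
import math


class Bernoulli:
    """Hypothetical stand-in for a batched Bernoulli distribution."""

    def __init__(self, probs):
        self.probs = list(probs)

    def log_prob(self, values):
        # One log probability per batch element.
        return [math.log(p) if v == 1 else math.log(1 - p)
                for p, v in zip(self.probs, values)]

    def independent(self, reinterpreted_batch_ndims):
        # Returns a NEW distribution object; it must be called to get one.
        return Independent(self, reinterpreted_batch_ndims)


class Independent:
    """Hypothetical wrapper treating the batch dimension as an event dimension."""

    def __init__(self, base, reinterpreted_batch_ndims):
        self.base = base
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    def log_prob(self, values):
        # Sum over the reinterpreted dimension: one joint log probability.
        return sum(self.base.log_prob(values))


probs = [0.9, 0.2, 0.6]
obs = [1, 0, 1]

buggy = Bernoulli(probs).independent     # the method itself, never called
fixed = Bernoulli(probs).independent(1)  # the distribution the method returns

print(hasattr(buggy, "log_prob"))  # False
print(hasattr(fixed, "log_prob"))  # True
print(fixed.log_prob(obs))         # equals log(0.9 * 0.8 * 0.6)
```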

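Yeah, a simple and dumb mistake: without the parentheses, the sample statement received the `independent` method itself rather than the distribution it returns, and that object has no `log_prob`. Thanks for providing some thoughts here though!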