Inferring discrete variables with funsor HMM

I am trying to extract the discrete hidden states from the funsor HMM example, but I get an odd error when using this pattern to infer discrete sites. My model exactly matches the example posted here, and I added the following chunk of code to the end of main() to extract the hidden states:

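# Run the trained guide to get values for the continuous latent sites.
guide_trace = handlers.trace(guide).get_trace(sequences, lengths)
# Replay the model against those guide values.
trained_model = handlers.replay(model, trace=guide_trace)

# MAP-decode the enumerated discrete sites (temperature=0).
inferred_model = infer.infer_discrete(
    trained_model, temperature=0,
    first_available_dim=first_available_dim)
trace = handlers.trace(inferred_model).get_trace(sequences, lengths)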

The error is raised when .get_trace() is called, in its downstream call to forward_backward() in funsor's adjoint.py. It occurs no matter which model structure I pick within the funsor HMM example (i.e. it affects every funsor HMM in the example, not just model_7() with its vectorized time dimension). Also, this routine for inferring discrete sites works just fine when I analyze the Bach chorales with the standard Pyro HMM example.
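For reference, infer_discrete(..., temperature=0) is meant to return a jointly most probable (MAP) assignment of the enumerated discrete sites, given the values of the other sites. For a single-chain HMM with fixed probability tables, that is the quantity computed by the Viterbi algorithm. The plain-Python sketch below (not Pyro or funsor code) shows that computation. The function name and the list-of-lists layout of the probability tables are illustrative assumptions, and it assumes one categorical observation per time step; the chorale models instead emit vectors of binary note indicators (Bernoulli emissions), so there the emission term would have to become a per-state log-likelihood.

```python
import math


def _log(p):
    # log(0) -> -inf so impossible transitions/emissions are never chosen
    return math.log(p) if p > 0 else -math.inf


def viterbi(init_probs, trans_probs, emit_probs, observations):
    """Most probable hidden-state path of a discrete HMM (max-product decoding).

    init_probs[i]     = p(x_0 = i)
    trans_probs[i][j] = p(x_t = j | x_{t-1} = i)
    emit_probs[i][k]  = p(y_t = k | x_t = i)
    observations      = observed symbol indices y_0, ..., y_{T-1} (T >= 1)
    Returns (path, log joint probability of that path and the observations).
    """
    states = range(len(init_probs))
    # delta[j]: best log-probability of any state path ending in state j
    delta = [_log(init_probs[j]) + _log(emit_probs[j][observations[0]]) for j in states]
    backpointers = []
    for y in observations[1:]:
        pointers, new_delta = [], []
        for j in states:
            # Max (not sum) over the previous state: this is what makes it MAP decoding.
            best_i = max(states, key=lambda i: delta[i] + _log(trans_probs[i][j]))
            pointers.append(best_i)
            new_delta.append(
                delta[best_i] + _log(trans_probs[best_i][j]) + _log(emit_probs[j][y])
            )
        backpointers.append(pointers)
        delta = new_delta
    # Backtrack from the best final state.
    state = max(states, key=lambda j: delta[j])
    best_log_prob = delta[state]
    path = [state]
    for pointers in reversed(backpointers):
        state = pointers[state]
        path.append(state)
    path.reverse()
    return path, best_log_prob


if __name__ == "__main__":
    # Toy 2-state, 3-symbol HMM (numbers chosen for illustration only).
    init = [0.6, 0.4]
    trans = [[0.7, 0.3], [0.4, 0.6]]
    emit = [[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]]
    path, log_p = viterbi(init, trans, emit, [0, 1, 2])
    # expected path [0, 0, 1], joint probability about 0.01512
    print(path, math.exp(log_p))
```

Because a decoder like this only needs the transition and emission tables, it could also serve as a cross-check on the states returned by infer_discrete, with the tables read from the guide trace (the site names depend on the model).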

Here’s the full error trace:

<ipython-input-4-43327c88dced> in main(args)
    162         trained_model, temperature=0,
    163         first_available_dim=first_available_dim)
--> 164     trace = handlers.trace(inferred_model).get_trace(sequences, lengths)
    165 
    166     return trace, sequences, lengths

/juno/work/venv2/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    196         Calls this poutine and returns its trace instead of the function's return value.
    197         """
--> 198         self(*args, **kwargs)
    199         return self.msngr.get_trace()

/juno/work/venv2/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    172             )
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:
    176                 exc_type, exc_value, traceback = sys.exc_info()

/juno/work/venv2/lib/python3.7/site-packages/pyro/contrib/funsor/infer/discrete.py in _sample_posterior(model, first_available_dim, temperature, *args, **kwargs)
     44 
     45     with approx:
---> 46         approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
     47 
     48     # construct a result trace to replay against the model

/juno/work/venv2/lib/python3.7/site-packages/funsor/adjoint.py in adjoint(sum_op, bin_op, expr)
    140 
    141 def adjoint(sum_op, bin_op, expr):
--> 142     forward, backward = forward_backward(sum_op, bin_op, expr)
    143     return backward
    144 

/juno/work/venv2/lib/python3.7/site-packages/funsor/adjoint.py in forward_backward(sum_op, bin_op, expr, batch_vars)
    135         # TODO fix traversal order in AdjointTape instead of using stack_reinterpret
    136         forward = stack_reinterpret(expr)
--> 137     backward = tape.adjoint(sum_op, bin_op, forward, batch_vars=batch_vars)
    138     return forward, backward
    139 

/juno/work/venv2/lib/python3.7/site-packages/funsor/adjoint.py in adjoint(self, sum_op, bin_op, root, targets, batch_vars)
    113                 self._eager_to_lazy[output] = lazy_output
    114 
--> 115             in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
    116             for v, adjv in in_adjs:
    117                 # Marginalize out message variables that don't appear in recipients.

/juno/work/venv2/lib/python3.7/site-packages/funsor/registry.py in __call__(self, key, *args)
    104 
    105     def __call__(self, key, *args):
--> 106         return self[key](*args)
    107 
    108     def dispatch(self, key, *args):

/juno/work/venv2/lib/python3.7/site-packages/funsor/registry.py in __call__(self, *args)
     61 
     62     def __call__(self, *args):
---> 63         return self.partial_call(*args)(*args)
     64 
     65 

/juno/work/venv2/lib/python3.7/site-packages/funsor/adjoint.py in adjoint_contract_generic(adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, terms)
    215     adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, terms
    216 ):
--> 217     assert len(terms) == 1 or len(terms) == 2
    218     return adjoint_ops(
    219         Contraction,

AssertionError: 

What versions of Pyro, PyTorch and Funsor are you using? Are you using a device other than a CPU?

I’m using a CPU on a linux cluster with these versions:

funsor==0.4.3
pyro-api==0.1.2
pyro-ppl==1.8.1
torch==1.11.0
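A version list like the one above can also be collected programmatically when filing a bug report. The sketch below uses the standard library's importlib.metadata, which requires Python 3.8+; on Python 3.7 (as in the traceback paths) the importlib_metadata backport package provides the same version() function. The helper name package_versions is hypothetical.

```python
from importlib import metadata


def package_versions(names=("funsor", "pyro-api", "pyro-ppl", "torch")):
    """Return {distribution name: installed version string, or None if missing}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print in the same "name==version" style used in the post.
    for name, version in package_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```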

Update: since I didn't have the most recent versions, I upgraded to pyro-ppl==1.8.2 and torch==1.12.1, but re-running with these upgrades still produces the same error as in my initial post.

I see. It looks like a longstanding bug in pyro.contrib.funsor: infer_discrete doesn't interact correctly with replay. See this unit test for a workaround that uses condition instead of replay.

Would you mind opening a bug report issue on the Pyro GitHub repo with the information you posted here?


@eb8680_2 I just opened an issue here. Note that the workaround using condition instead of replay still produces the same error at the final .get_trace() call. Here is my code snippet for the attempted workaround:

# get a trace of the guide's (continuous) latent sites
guide_trace = handlers.trace(guide).get_trace(sequences, lengths)
guide_data = {
    name: site["value"]
    for name, site in guide_trace.nodes.items()
    if site["type"] == "sample"
}

# MAP-estimate the discrete sites, conditioned on continuous latents sampled from the guide.
actual_trace = handlers.trace(
    infer.infer_discrete(
        handlers.condition(infer.config_enumerate(model), guide_data),
        temperature=0,
    )
).get_trace(sequences, lengths)