Here’s the full error trace:
```
<ipython-input-4-43327c88dced> in main(args)
162 trained_model, temperature=0,
163 first_available_dim=first_available_dim)
--> 164 trace = handlers.trace(inferred_model).get_trace(sequences, lengths)
165
166 return trace, sequences, lengths
/juno/work/venv2/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
196 Calls this poutine and returns its trace instead of the function's return value.
197 """
--> 198 self(*args, **kwargs)
199 return self.msngr.get_trace()
/juno/work/venv2/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
172 )
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176 exc_type, exc_value, traceback = sys.exc_info()
/juno/work/venv2/lib/python3.7/site-packages/pyro/contrib/funsor/infer/discrete.py in _sample_posterior(model, first_available_dim, temperature, *args, **kwargs)
44
45 with approx:
---> 46 approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
47
48 # construct a result trace to replay against the model
/juno/work/venv2/lib/python3.7/site-packages/funsor/adjoint.py in adjoint(sum_op, bin_op, expr)
140
141 def adjoint(sum_op, bin_op, expr):
--> 142 forward, backward = forward_backward(sum_op, bin_op, expr)
143 return backward
144
/juno/work/venv2/lib/python3.7/site-packages/funsor/adjoint.py in forward_backward(sum_op, bin_op, expr, batch_vars)
135 # TODO fix traversal order in AdjointTape instead of using stack_reinterpret
136 forward = stack_reinterpret(expr)
--> 137 backward = tape.adjoint(sum_op, bin_op, forward, batch_vars=batch_vars)
138 return forward, backward
139
/juno/work/venv2/lib/python3.7/site-packages/funsor/adjoint.py in adjoint(self, sum_op, bin_op, root, targets, batch_vars)
113 self._eager_to_lazy[output] = lazy_output
114
--> 115 in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
116 for v, adjv in in_adjs:
117 # Marginalize out message variables that don't appear in recipients.
/juno/work/venv2/lib/python3.7/site-packages/funsor/registry.py in __call__(self, key, *args)
104
105 def __call__(self, key, *args):
--> 106 return self[key](*args)
107
108 def dispatch(self, key, *args):
/juno/work/venv2/lib/python3.7/site-packages/funsor/registry.py in __call__(self, *args)
61
62 def __call__(self, *args):
---> 63 return self.partial_call(*args)(*args)
64
65
/juno/work/venv2/lib/python3.7/site-packages/funsor/adjoint.py in adjoint_contract_generic(adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, terms)
215 adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, terms
216 ):
--> 217 assert len(terms) == 1 or len(terms) == 2
218 return adjoint_ops(
219 Contraction,
AssertionError:
```