Thanks so much for your response, @fritzo. Oddly, I seem to be getting a `KeyError` from some internal Pyro code when I call `compute_marginals()`. Any idea whether this is expected?
def predict(obs, cluster_centers, cluster_stds):
    """Mixture model: draw a cluster index per observation and score the data.

    Inside a plate of size ``len(obs)``, samples a discrete ``class_idx``
    site from a Categorical prior, then conditions an ``obs`` site on the
    observed data under a Normal whose location and scale are looked up by
    the sampled class index.

    Args:
        obs: observed data; ``len(obs)`` sets the size of the ``num_obs`` plate.
        cluster_centers: per-cluster locations, indexed by the sampled class.
        cluster_stds: per-cluster scales, indexed by the sampled class.

    Returns:
        The value of the ``class_idx`` sample site.
    """
    # NOTE(review): ``prior_weights`` is not a parameter of this function; it
    # is presumably a module-level tensor of mixture weights defined
    # elsewhere -- confirm it is in scope where this model runs.
    with pyro.plate("num_obs", len(obs)):
        cls = pyro.sample("class_idx", Categorical(prior_weights))
        # ``obs=`` is a keyword of pyro.sample (it conditions the site on the
        # data); it must not be passed to the Normal constructor.
        pyro.sample(
            "obs",
            Normal(cluster_centers[cls], cluster_stds[cls]),
            obs=obs,
        )
    return cls
# This works. ``losser`` appears to be a TraceEnum_ELBO (see the traceback
# below); its ``loss`` takes (model, guide, *args, **kwargs), so the guide
# must be passed before the model's own arguments.
losser.loss(predict, guide_for_inference, x, **params)
# This throws an exception:
losser.compute_marginals(predict, guide_for_inference, x, **params)
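This throws an exception. The traceback below was captured from a `sample_posterior` call on my full model (`multi_choice_single_sign_model`):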
KeyError Traceback (most recent call last)
<ipython-input-15-ae6e2d72ce18> in <module>
14 losser.loss(multi_choice_single_sign_model, guide_for_inference,x.unsqueeze(0), **params, **model_hyperparameters_int)
15 # this throws an exception
---> 16 losser.sample_posterior(multi_choice_single_sign_model, guide_for_inference,x.unsqueeze(0), **params, **model_hyperparameters_int)
/anaconda3/lib/python3.8/site-packages/pyro/infer/traceenum_elbo.py in sample_posterior(self, model, guide, *args, **kwargs)
455 # TODO replace BackwardSample with torch_sample backend to ubersum
456 with BackwardSampleMessenger(model_trace, guide_trace):
--> 457 return poutine.replay(model, trace=guide_trace)(*args, **kwargs)
458
459
/anaconda3/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
/anaconda3/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
<ipython-input-2-09a2e5e01f78> in multi_choice_single_sign_model(data, transition_matrix, emission_logits, process_hands, hand_strategy, **kwargs)
77
78 with pyro.plate("datapoints", num_datapoints):
---> 79 sign_idx = pyro.sample("sign",
80 distributions.Categorical(maybe_cuda(torch.tensor([1 / num_signs for _ in range(num_signs)]))))
81 state = maybe_cuda(torch.zeros(num_datapoints).long())
/anaconda3/lib/python3.8/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
111 msg["is_observed"] = True
112 # apply the stack and return its return value
--> 113 apply_stack(msg)
114 return msg["value"]
115
/anaconda3/lib/python3.8/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
191 pointer = pointer + 1
192
--> 193 frame._process_message(msg)
194
195 if msg["stop"]:
/anaconda3/lib/python3.8/site-packages/pyro/poutine/messenger.py in _process_message(self, msg)
137 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
138 if method is not None:
--> 139 return method(msg)
140 return None
141
/anaconda3/lib/python3.8/site-packages/pyro/infer/traceenum_elbo.py in _pyro_sample(self, msg)
245 with shared_intermediates(self.cache):
246 ordinal = _find_ordinal(self.enum_trace, msg)
--> 247 logits = contract_to_tensor(self.log_factors, self.sum_dims,
248 target_ordinal=ordinal, target_dims={enum_symbol},
249 cache=self.cache)
/anaconda3/lib/python3.8/site-packages/pyro/ops/contract.py in contract_to_tensor(tensor_tree, sum_dims, target_ordinal, target_dims, cache, ring)
259 term = ring.sumproduct(contracted_terms, set())
260 assert sum_dims.intersection(term._pyro_dims) <= target_dims
--> 261 return ring.broadcast(term, target_ordinal)
262
263
/anaconda3/lib/python3.8/site-packages/pyro/ops/rings.py in broadcast(self, term, ordinal)
78 else:
79 missing_shape = tuple(self._dim_to_size[dim] for dim in missing_dims)
---> 80 term = term.expand(missing_shape + term.shape)
81 dims = missing_dims + dims
82 self._cache[key] = term
/anaconda3/lib/python3.8/site-packages/pyro/ops/rings.py in <genexpr>(.0)
78 else:
79 missing_shape = tuple(self._dim_to_size[dim] for dim in missing_dims)
---> 80 term = term.expand(missing_shape + term.shape)
81 dims = missing_dims + dims
82 self._cache[key] = term
KeyError: 'a'
This only occurs when I’m trying to infer on a single datapoint (i.e. with `pyro.plate("num_obs", 1)`), so my current workaround is to drop the plate when there’s only one observation.