# Batch shape when calling pyro.deterministic

Hi there! I have a small question about how batched variables are handled when calling `pyro.deterministic`. For example, here is a toy example that samples a batch of variances and correlations and computes the sampled covariance factors. I believe this example is reproducible:

``````
import torch
import pyro
import pyro.poutine as poutine
from pyro.distributions import Chi2, LKJCholesky
from pyro.infer import config_enumerate

d = 2

@config_enumerate
def toy_model(batch_size=100):
    with pyro.plate("component", batch_size):
        # component of prior for covariance
        theta = pyro.sample("theta", Chi2(df=torch.ones(d) * (d + 1)).to_event(1))
        omega = pyro.sample("omega", LKJCholesky(d, concentration=1))
        Omega = pyro.deterministic("Omega", torch.bmm(theta.sqrt().diag_embed(), omega))

trace = poutine.trace(toy_model).get_trace()
print(trace.format_shapes())
``````

Running the code to check the shapes of the variables, we have

``````
 Trace Shapes:
  Param Sites:
 Sample Sites:
component dist     |
         value 100 |
       nu dist 100 |
         value 100 |
    theta dist 100 |       2
         value 100 |       2
    omega dist 100 |     2 2
         value 100 |     2 2
    Omega dist 100 | 100 2 2
         value     | 100 2 2
``````

I was expecting the shape of `Omega` to be the same as `omega`. However, there seems to be a duplicated batch size. So I was wondering: am I not using the `deterministic` primitive correctly? How should I deal with this? (Currently I'm using a Normal distribution with a very small variance instead, which seems to be immune to the problem above, although it introduces some small noise.) Thanks for any advice!

Hi @Evan. Can you check what the shape of `torch.bmm(theta.sqrt().diag_embed(), omega)` is? And what do you expect it to be?

I checked that with `print`, which showed it is of size `torch.Size([100, 2, 2])` (just as expected).

I thought perhaps `pyro.deterministic()` is treating this tensor as one event, and therefore adding another dim of size `batch_size`? However, I haven't figured out how to solve this yet.

(By the way, this also happens in other examples, like when I'm using `pyro.deterministic` to record a set of variables that are linked to a set of values through a discrete index. There I also got an extra dim.)

It looks like `pyro.deterministic` sets `event_dim` to `value.ndim` by default. What if you set `event_dim=2`?

Oh I see! It works!

``````
 Sample Sites:
component dist     |
         value 100 |
    theta dist 100 |   2
         value 100 |   2
    omega dist 100 | 2 2
         value 100 | 2 2
    Omega dist 100 | 2 2
         value 100 | 2 2
``````

Thanks a lot!
(I think I should be more careful reading the documentation lol)


Oops, sorry, but when I tried out `pyro.deterministic()` with indexed values, I still ran into trouble at inference time:

``````w = pyro.deterministic("w", pc[u], event_dim=1)
``````

It was of shape `torch.Size([50, 2])`, but later during enumeration it turned out to be `torch.Size([9, 1, 2])` and caused an `IndexError`. I also tried

``````w = pyro.sample("w", MultivariateNormal(pc[u], noise_ub * torch.eye(d)))
``````

which is simply a noisy version of the program above (or so I thought), with the desired values being the mean of a `MultivariateNormal`. This alternative seems to work. But I'm confused about the difference: what's causing the difference between `deterministic` and `sample`, and how should I handle this?

Can you show your code and the full error message?

Sure! (Thanks so much for your patience~) Here’s my model:

``````
import torch
import torch.nn.functional as F
import pyro
from pyro.distributions import Beta, Categorical, Gamma, MultivariateNormal, Normal
from pyro.distributions import constraints
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex

T = 50      # truncation level for mixture components
T_pc = 9    # truncation level for principal components
d = 2       # data dimension
N = 100     # number of observations (placeholder)

def mix_weights(beta):
    # stick-breaking construction of mixture weights
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

@config_enumerate
def model(data=None, gamma=0.1, alpha=0.1, noise_ub=0.001):
    alpha_mu = pyro.param("alpha_mu", lambda: Gamma(1, 1).sample(), constraint=constraints.positive)
    alpha_w = pyro.param("alpha_w", lambda: Gamma(1, 1).sample(), constraint=constraints.positive)
    tau = pyro.param("tau", lambda: Gamma(1, 1).sample(), constraint=constraints.positive)

    with pyro.plate("sticks", T - 1):
        beta = pyro.sample("beta", Beta(1, gamma))
    with pyro.plate("PC_sticks", T_pc - 1):
        beta_ = pyro.sample("beta_", Beta(1, alpha))
    with pyro.plate("PCs", T_pc):
        pc = pyro.sample("pc", Normal(torch.zeros(d), 1 / alpha_w.unsqueeze(-1)).to_event(1))

    with pyro.plate("component", T):
        u = pyro.sample("u", Categorical(mix_weights(beta_)), infer={'enumerate': 'parallel'})
        mu = pyro.sample("mu", Normal(torch.zeros(d), 1 / alpha_mu.unsqueeze(-1)).to_event(1))
        w = pyro.deterministic("w", pc[u], event_dim=1)

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta)), infer={'enumerate': 'parallel'})
        pyro.sample(
            "obs",
            MultivariateNormal(
                mu[z],
                precision_matrix=tau * torch.eye(d)
                    + Vindex(w)[z].unsqueeze(-1) * Vindex(w)[z].unsqueeze(-2)),
            obs=data)
``````

It is a CRP mixture model whose covariance is modeled with the PPCA framework. In other words, its covariance is approximated by isotropic noise plus a low-rank matrix (here simply rank 1: a single vector with a normal prior). The error is:

``````     25     pyro.sample(
26         "obs",
27         MultivariateNormal(
28             mu[z],
---> 29             precision_matrix=tau * torch.eye(d) + Vindex(w)[z].unsqueeze(-1)*Vindex(w)[z].unsqueeze(-2)),
30         obs=data)
IndexError: index 9 is out of bounds for dimension 0 with size 9
``````

And I found the shape issue with `w` mentioned before: its shape changes when running SVI, picking up extra enumeration dimensions.
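For what it's worth, the `IndexError` can be reproduced in isolation with plain tensors (shapes taken from the trace; the values are made up):

``````python
import torch

w = torch.randn(9, 1, 2)     # the shape pc[u] takes under parallel enumeration
z = torch.tensor([0, 3, 9])  # z indexes mixture components, so it can reach T - 1 = 49
try:
    w[z]                     # plain indexing hits dim 0, which only has size 9
except IndexError as e:
    print(e)                 # index 9 is out of bounds for dimension 0 with size 9
``````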

If you print out the shapes of `w` and `z` you will probably find out why you get the `IndexError`. It looks like there is a lot of math in your code; just be careful to make sure the shapes are correct and work with enumeration.

Yes, I just don’t understand how enumeration is handled in downstream sites: here `MultivariateNormal` seems to work well with an enumerated `loc` parameter, while `deterministic` seems to simply absorb the enum shape into its batch shape.

What is the shape of `w` when you use `MultivariateNormal` (with enumeration)?

`deterministic` should just return the value of `pc[u]`.

Can you clarify this more?

I used `print` to check the shapes:

• When I directly print `pc[u].shape`, I get `torch.Size([9, 1, 2])`.
• When using `MultivariateNormal` with `pc[u]` as its loc parameter, I get a `torch.Size([50, 2])` tensor, with the size at dim 0 consistent with the plate's claim.
• When using `deterministic`, I find the shape of the variable identical to that of `pc[u]`: `torch.Size([9, 1, 2])`.

Setting `event_dim=1` helps mark the last dimension as `event_shape`, but I don't know how to properly handle the other dimensions.

I think the difference in the shape must be due to the fact that when you use `pyro.sample` the value of `w` is sampled by the guide (which I suppose has the shape of `torch.Size([50, 2])`) and then replayed into the model. `MultivariateNormal(pc[u], ...)` distribution in the model is only used to calculate the log density. If you print out the trace shapes you should see that `w` has `dist` shape of `MultivariateNormal(pc[u], ...)` (something like `(9, 1, 2)`) and `value` shape of `(50, 2)` (note that this value is sampled by the guide), and `log_prob` shape which is the broadcast of the two.

Oh thanks! I wasn’t fully aware of these differences. I think I’m getting an intuitive understanding now.
Just one more question about this: how can I perform the similar sample-then-replay process with `deterministic` sites? For now, I’m using a noisy version of this expression (which is somewhat okay, at least it runs without errors, but not quite elegant I think)

``````w = pyro.sample("w", MultivariateNormal(pc[u], noise_ub * torch.eye(d)))
``````

(I’ve tried this out, and it doesn't cause much trouble for my model: it eventually converges to a good posterior. But I thought using `deterministic` would be more faithful to my original claim about the model.)

Maybe you can expand (repeat) `pc[u]` along `dim=-2` so that it has the shape of `(9, 50, 2)`. Then make sure that when you use `Vindex(w)[z]` you are vindexing along that dim (since it is no longer the first dim).
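A shape-only sketch of that expand idea, with `T = 50`, `T_pc = 9`, `d = 2` as above:

``````python
import torch

T, T_pc, d = 50, 9, 2
pc = torch.randn(T_pc, d)

# Under parallel enumeration, u shows up with shape (T_pc, 1) rather than (T,).
u_enum = torch.arange(T_pc).reshape(T_pc, 1)
w = pc[u_enum]                 # torch.Size([9, 1, 2])
w = w.expand(T_pc, T, d)       # torch.Size([9, 50, 2]): repeated along dim=-2
print(w.shape)
``````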

You can also just forget that you are using `deterministic` since it is not affecting inference in any way.