Hi there! I have a small question about how batched variables are handled when calling pyro.deterministic. As a toy example, the model below samples a batch of variances and correlations and computes the corresponding covariance factor. I believe the example is reproducible:
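import torch
import pyro
import pyro.poutine as poutine
from pyro.distributions import Chi2, LKJCholesky
from pyro.infer import config_enumerate

d = 2  # matrix dimension; assumed from the 2 x 2 shapes in the trace below

@config_enumerate
def toy_model(batch_size=100):
    with pyro.plate("component", batch_size):
        # components of the prior for the covariance
        theta = pyro.sample("theta", Chi2(df=torch.ones(d) * (d + 1)).to_event(1))
        omega = pyro.sample("omega", LKJCholesky(d, concentration=1))
        Omega = pyro.deterministic("Omega", torch.bmm(theta.sqrt().diag_embed(), omega))

trace = poutine.trace(toy_model).get_trace()
print(trace.format_shapes())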
Running the code to check the shapes of the variables gives:
   Trace Shapes:
    Param Sites:
   Sample Sites:
  component dist     |
           value 100 |
         nu dist 100 |
           value 100 |
      theta dist 100 | 2
           value 100 | 2
      omega dist 100 | 2 2
           value 100 | 2 2
      Omega dist 100 | 100 2 2
           value     | 100 2 2
I was expecting the shape of Omega to be the same as that of omega. However, the batch dimension appears to be duplicated. Am I using the deterministic primitive incorrectly? How should I deal with this? (As a workaround, I'm currently using a Normal distribution with a very small variance, which seems to avoid the problem above, although it introduces a small amount of noise.) Thanks for any advice!
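
One possible reading of the trace, offered as an assumption rather than a confirmed diagnosis: pyro.deterministic accepts an optional event_dim argument, and if it defaults to the number of dimensions of the value, then all three dimensions of the (100, 2, 2) tensor would be treated as event dimensions and the plate would add its own batch dimension of 100 on top. The pure-Python sketch below only illustrates that shape bookkeeping; split_shape is a hypothetical helper written for this post, not part of Pyro.

```python
def split_shape(value_shape, event_dim):
    """Split a value shape into (batch_shape, event_shape).

    The rightmost `event_dim` dimensions are treated as event dimensions;
    whatever remains on the left is the batch shape.
    """
    if not 0 <= event_dim <= len(value_shape):
        raise ValueError("event_dim must be between 0 and len(value_shape)")
    cut = len(value_shape) - event_dim
    return tuple(value_shape[:cut]), tuple(value_shape[cut:])


# Shape of torch.bmm(theta.sqrt().diag_embed(), omega) in the toy model.
omega_value_shape = (100, 2, 2)

# Assumed default: event_dim equals the number of dims of the value (3 here),
# so the size-100 plate dimension ends up on the event side.
default_split = split_shape(omega_value_shape, event_dim=len(omega_value_shape))

# Assumed fix: declare only the trailing 2 x 2 matrix as the event, i.e.
# pyro.deterministic("Omega", value, event_dim=2) in the model above.
explicit_split = split_shape(omega_value_shape, event_dim=2)

print("default event_dim :", default_split)
print("event_dim=2       :", explicit_split)
```

If this reading is right, passing event_dim=2 to pyro.deterministic in the toy model should make the Omega row of the trace match the omega row, without needing the small-variance Normal workaround.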