I have a model that includes two GP kernels, as follows:
```python
...
K_a = gp.kernels.RBF(input_dim=1)
K_a.lengthscale = PyroParam(dtensor(1.), constraint=constraints.interval(0.5, 5.0))
K_a.variance = PyroParam(dtensor(0.1), constraint=constraints.interval(0.001, 10.0))
cov_a = K_a(torch.arange(A, device=device)).contiguous()
with pyro.plate("ages", A):
    f_tilde_a = pyro.sample("f_tilde_a", dist.Normal(dtensor(0.0), dtensor(1.0)))
    f_age = pyro.deterministic(
        "f_age", torch.linalg.cholesky(cov_a + torch.eye(A, device=device) * jitter) @ f_tilde_a.squeeze()
    )
K_p = gp.kernels.RBF(input_dim=1)
K_p.lengthscale = PyroParam(dtensor(1.), constraint=constraints.interval(0.5, 5.0))
K_p.variance = PyroParam(dtensor(0.1), constraint=constraints.interval(0.001, 10.0))
with pyro.plate("players", P):
    cov_p = K_p(torch.arange(S, device=device)).contiguous()
    with pyro.plate("seasons", S):
        f_tilde_p = pyro.sample("f_tilde_p", dist.Normal(dtensor(0.0), dtensor(1.0)))
        f_stuff = pyro.deterministic(
            "f_stuff", stuff_0 + torch.linalg.cholesky(cov_p + torch.eye(S, device=device) * jitter) @ f_tilde_p.squeeze()
        )
```
However, when I fit the model and look at the parameter store, I see only one lengthscale and one variance, with no indication of which GP they belong to:
```
lengthscale tensor(0.5023, device='cuda:0', grad_fn=<AddBackward0>)
variance tensor(0.0010, device='cuda:0', grad_fn=<AddBackward0>)
AutoLowRankMultivariateNormal.loc Parameter containing:
tensor([-0.0150, -0.5806, -0.7582,  ...,  0.0737,  0.0947,  0.0789],
       device='cuda:0')
AutoLowRankMultivariateNormal.scale tensor([0.6778, 0.6832, 0.6049,  ..., 0.4657, 0.4634, 0.4914], device='cuda:0')
AutoLowRankMultivariateNormal.cov_factor Parameter containing:
tensor([[-0.0356,  0.0041, -0.0027,  ...,  0.0121,  0.0309, -0.0333],
        [-0.0033,  0.0542, -0.0080,  ..., -0.0382,  0.0399,  0.0459],
        [-0.0117,  0.0198, -0.0069,  ..., -0.0160,  0.0110,  0.0161],
        ...,
        [ 0.0417, -0.1485, -0.0494,  ...,  0.0274, -0.0070,  0.0321],
        [-0.0530, -0.0257,  0.0172,  ..., -0.0857,  0.1026, -0.0106],
        [-0.0333, -0.0978,  0.0684,  ..., -0.0060,  0.0424,  0.0953]],
       device='cuda:0')
```
How do I look at the parameter estimates for both components?