Hi everyone,
I was wondering if you could help me with the issue below. I am preparing a tutorial on Pyro and want to demonstrate a logistic regression example with hyperpriors, to introduce how PyroModule works.
However, I can’t get it to behave properly. It seems to ignore the variable ‘lam’ in the trace, and an autoguide also ignores that site. I suspect I am missing something.
I have this model working without the PyroModule approach, but I would like to show it both ways for comparison. What am I missing? I have also tried specifying ‘lam’ with PyroSample.
import pyro
from pyro.distributions import Normal, Bernoulli, HalfCauchy
from pyro.nn import PyroModule, PyroSample
from torch.nn.modules import Linear
from torch.nn import ELU


class Model(PyroModule):
    def __init__(self, in_dim, w_prior=1.):
        super().__init__()
        self.l1 = PyroModule[Linear](in_dim, 1)
        # Hyperprior on the scale of the weight prior
        self.lam = pyro.sample('lam', HalfCauchy(scale=1.))
        # Replace the layer's nn.Parameters with priors
        self.l1.weight = PyroSample(Normal(0., self.lam).expand([1, in_dim]).to_event(2))
        self.l1.bias = PyroSample(Normal(0., 10.))

    def forward(self, x, y_obs=None):
        alpha = self.l1(x)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", Bernoulli(logits=alpha.reshape(-1)), obs=y_obs)
        return obs