Hi there, fairly new Pyro user here!
So far, I’ve created a model:

    import pyro
    import torch.nn as nn
    from pyro.nn import PyroModule, PyroSample

    class SurvivalModel(PyroModule):
        def __init__(self, in_features, out_features, subsample_size=64):
            super().__init__()
            self.linear = PyroModule[nn.Linear](in_features, out_features)
            # Priors over the linear layer's weight and bias
            self.linear.weight = PyroSample(
                pyro.distributions.Normal(0., 100.).expand([out_features, in_features]).to_event(2)
            )
            self.linear.bias = PyroSample(
                pyro.distributions.Normal(365., 100.).expand([out_features]).to_event(1)
            )
            self.subsample_size = subsample_size

        def forward(self, x, y=None, truncation_label=None):
            mean = self.linear(x).squeeze(-1)
            # The plate size is the number of rows in x, so use x.shape[0] rather than x.shape
            with pyro.plate("data", size=x.shape[0], subsample_size=self.subsample_size) as ind:
                ...
In the above, I’ve specified a subsample_size for the data plate.
However, when I tried to generate samples from the posterior predictive using Predictive, I started to run into problems with the following code:

    from pyro.infer import Predictive

    predictive = Predictive(
        model,
        guide=guide,
        num_samples=800,
        return_sites=("linear.weight", "obs", "_RETURN"),
    )
    samples = predictive(x=X_test, y=None, truncation_label=None)
In the above, the model was still subsampling from X_test instead of using the entire X_test. Am I doing something wrong in my approach?
Any feedback would be appreciated!