Hi there, fairly new Pyro user here!
So far, I’ve created a `PyroModule`:
import torch.nn as nn
import pyro
from pyro.nn import PyroModule, PyroSample


class SurvivalModel(PyroModule):
    def __init__(self, in_features, out_features, subsample_size=64):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        # Priors on the linear layer's weight and bias
        self.linear.weight = PyroSample(
            pyro.distributions.Normal(0., 100.).expand([out_features, in_features]).to_event(2)
        )
        self.linear.bias = PyroSample(
            pyro.distributions.Normal(365., 100.).expand([out_features]).to_event(1)
        )
        self.subsample_size = subsample_size

    def forward(self, x, y=None, truncation_label=None):
        mean = self.linear(x).squeeze(-1)
        # With subsample_size set, the plate yields a random subset of that many row indices
        with pyro.plate("data", size=x.shape[0], subsample_size=self.subsample_size) as ind:
            ...
In the above, I’ve specified a `subsample_size` for the `"data"` plate.
However, when I tried to generate samples from the posterior predictive using `Predictive`, I ran into problems with the following code:
from pyro.infer import Predictive

predictive = Predictive(
    model,
    guide=guide,
    num_samples=800,
    return_sites=("linear.weight", "obs", "_RETURN")
)
samples = predictive(
    x=X_test,
    y=None,
    truncation_label=None
)
In the above, the model still drew a random subsample of rows from `X_test` instead of using all of `X_test`. Am I doing something wrong in my approach?
Any feedback would be appreciated!