We try to train the following model and print out the validation loss during training. Unfortunately we were not able to get this to work unless we have a training and a test-dataset of exactly the same size. We would optimally want to aim for an 80%/20% test-train-split. Is there a smart way to do this. Does anyone maybe have a code example?

Our Model:

```
class FA(PyroModule):
def __init__(self, train_data, test_data, n_features1, n_features2, K):
"""
Args:
Y: Tensor (Samples x Features)
K: Number of Latent Factors
"""
super().__init__()
pyro.clear_param_store()
# data
self.num_features1 = n_features1
self.num_features2 = n_features2
self.Y1 = train_data[:,:n_features1]
self.Y2 = train_data[:,n_features1:]
self.K = K
self.train_data = train_data
self.test_data = test_data
self.test_Y1 = test_data[:,:n_features1]
self.test_Y2 = test_data[:,n_features1:]
self.num_samples = self.Y1.shape[0]
self.sample_plate = pyro.plate("sample", self.num_samples)
self.feature_plate1 = pyro.plate("feature1", self.num_features1)
self.feature_plate2 = pyro.plate("feature2", self.num_features2)
self.latent_factor_plate = pyro.plate("latent factors", self.K)
print(self.test_Y1.shape)
print(self.test_Y2.shape)
def model(self, Y1, Y2):
"""
how to generate a matrix
"""
with self.latent_factor_plate:
with self.feature_plate1:
# sample weight matrix with Normal prior distribution
W1 = pyro.sample("W1", pyro.distributions.Normal(0., 1.))
with self.feature_plate2:
# sample weight matrix with Normal prior distribution
W2 = pyro.sample("W2", pyro.distributions.Normal(0., 1.))
with self.sample_plate:
# sample factor matrix with Normal prior distribution
Z = pyro.sample("Z", pyro.distributions.Normal(0., 1.))
# estimate for Y
Y1_hat = torch.matmul(Z, W1.t())
Y2_hat = torch.matmul(Z, W2.t())
with pyro.plate("feature1_", Y1.shape[1]), pyro.plate("sample_", Y1.shape[0]):
# masking the NA values such that they are not considered in the distributions
obs_mask = torch.ones_like(Y1, dtype=torch.bool)
if data is not None:
obs_mask = torch.logical_not(torch.isnan(Y1))
with pyro.poutine.mask(mask=obs_mask):
if data is not None:
# a valid value for the NAs has to be defined even though these samples will be ignored later
Y1 = torch.nan_to_num(Y1, nan=0)
# sample scale parameter for each feature-sample pair with LogNormal prior (has to be positive)
scale = pyro.sample("scale", pyro.distributions.LogNormal(0., 1.))
# compare sampled estimation to the true observation Y
pyro.sample("obs1", pyro.distributions.Normal(Y1_hat, scale), obs=Y1)
with pyro.plate("feature2_", Y2.shape[1]), pyro.plate("sample2_", Y2.shape[0]):
# masking the NA values such that they are not considered in the distributions
obs_mask = torch.ones_like(Y2, dtype=torch.bool)
if data is not None:
obs_mask = torch.logical_not(torch.isnan(Y2))
with pyro.poutine.mask(mask=obs_mask):
if data is not None:
# a valid value for the NAs has to be defined even though these samples will be ignored later
Y2 = torch.nan_to_num(Y2, nan=0)
# sample scale parameter for each feature-sample pair with LogNormal prior (has to be positive)
scale = pyro.sample("scale2", pyro.distributions.LogNormal(0., 1.))
# compare sampled estimation to the true observation Y
pyro.sample("obs2", pyro.distributions.Normal(Y2_hat, scale), obs=Y2)
def train(self):
# set training parameters
optimizer = pyro.optim.Adam({"lr": 0.02})
elbo = Trace_ELBO()
guide = autoguide.AutoDelta(self.model)
# initialize stochastic variational inference
svi = SVI(
model = self.model,
guide = guide,
optim = optimizer,
loss = elbo
)
num_iterations = 2000
train_loss = []
test_loss = []
for j in range(num_iterations):
#for j in enumerate(self.train_dataloader):
# calculate the loss and take a gradient step
loss = svi.step(self.Y1.T, self.Y2)
train_loss.append(loss/self.Y1.shape[0])
# test_loss.append(elbo.loss(self.model, guide, test_data))
if j % 200 == 0:
print("[iteration %04d] loss: %.4f" % (j + 1, loss / self.Y1.shape[0]))
with torch.no_grad(): # for logging only
train_loss2 = elbo.loss(self.model, guide, self.Y1, self.Y2) # or average over batch_loss
test_loss = elbo.loss(self.model, guide, self.test_Y1, self.test_Y2)
print(train_loss2, test_loss)
# Obtain maximum a posteriori estimates for W and Z
map_estimates = guide(Y)
return train_loss, map_estimates
```

Error we get:

```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[70], line 2
1 FA_model = FA(train_data, test_data,20, 30, 5)
----> 2 losses, estimates = FA_model.train()
Cell In[69], line 114
112 with torch.no_grad(): # for logging only
113 train_loss2 = elbo.loss(self.model, guide, self.Y1, self.Y2) # or average over batch_loss
--> 114 test_loss = elbo.loss(self.model, guide, self.test_Y1, self.test_Y2)
115 print(train_loss2, test_loss)
117 # Obtain maximum a posteriori estimates for W and Z
File ~/.local/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:72, in Trace_ELBO.loss(self, model, guide, *args, **kwargs)
65 """
66 :returns: returns an estimate of the ELBO
67 :rtype: float
68
69 Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
70 """
71 elbo = 0.0
---> 72 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
73 elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(
74 guide_trace.log_prob_sum()
75 )
76 elbo += elbo_particle / self.num_particles
File ~/.local/lib/python3.8/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs)
235 else:
236 for i in range(self.num_particles):
--> 237 yield self._get_trace(model, guide, args, kwargs)
File ~/.local/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53 """
54 Returns a single trace from the guide, and the model that is run
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
60 if is_validation_enabled():
61 check_if_enumerated(guide_trace)
File ~/.local/lib/python3.8/site-packages/pyro/infer/enum.py:75, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
72 guide_trace = prune_subsample_sites(guide_trace)
73 model_trace = prune_subsample_sites(model_trace)
---> 75 model_trace.compute_log_prob()
76 guide_trace.compute_score_parts()
77 if is_validation_enabled():
File ~/.local/lib/python3.8/site-packages/pyro/poutine/trace_struct.py:276, in Trace.compute_log_prob(self, site_filter)
270 raise ValueError(
271 "Error while computing log_prob at site '{}':\n{}\n{}".format(
272 name, exc_value, shapes
273 )
274 ).with_traceback(traceback) from e
275 site["unscaled_log_prob"] = log_p
--> 276 log_p = scale_and_mask(log_p, site["scale"], site["mask"])
277 site["log_prob"] = log_p
278 site["log_prob_sum"] = log_p.sum()
File ~/.local/lib/python3.8/site-packages/pyro/distributions/util.py:328, in scale_and_mask(tensor, scale, mask)
326 if mask is False:
327 return torch.zeros_like(tensor)
--> 328 return torch.where(mask, tensor * scale, tensor.new_zeros(()))
RuntimeError: The size of tensor a (80) must match the size of tensor b (20) at non-singleton dimension 0
```

Thank you, any help is greatly appreciated.