Thank you. Indeed, `lambda: ...` is a little bit cheaper. There is, however, a size (or value) mismatch when using `torch.zeros_like(data)` and `torch.ones_like(data)`, causing `ValueError: The parameter covariance_matrix has invalid values`. Perhaps the scales are also missing `torch.diag`. For convenience, here's the full model:
@config_enumerate(default='parallel')
@poutine.broadcast
def model(data):
    """Two-component Gaussian mixture model over ``data``.

    Args:
        data: Tensor of shape ``[num_samples, num_features]`` — TODO confirm
            against the caller; the code reads ``data.size(0)`` and
            ``data.shape[1]`` consistently with that layout.

    Side effects:
        Registers the pyro params ``'weights'``, ``'locs'``, ``'scales'`` and
        the sample sites ``'assignment'`` and ``'obs'``.
    """
    # Global variables.
    weights = pyro.param('weights', torch.FloatTensor([0.5]), constraint=constraints.unit_interval)
    # Alternatives: max(torch.std(data, axis=1)) or (data.var() / 2).sqrt()
    scale_factor = 128
    number_of_samples = data.size(0)

    # Lazy (lambda) initializers are evaluated only on the first call, which
    # measured ~5% faster than building the init tensors eagerly every call.
    # locs: K component means drawn from random data points — initializing
    # them to zeros was observed to give bad results.
    locs = pyro.param('locs', lambda: data[np.random.randint(0, number_of_samples, K)])
    # scales: K copies of a scaled identity covariance (unit variance times
    # scale_factor); think of data.shape[1] as num_features.
    # NOTE(review): constraints.positive enforces only elementwise positivity,
    # not positive-definiteness; for full covariances consider parameterizing
    # a Cholesky factor with constraints.lower_cholesky — verify.
    scales = pyro.param(
        'scales',
        lambda: torch.stack((scale_factor * torch.diag(torch.ones(data.shape[1])),) * K),
        constraint=constraints.positive)
    # NOTE: initializing with torch.zeros_like(data) / torch.ones_like(data)
    # fails with a size mismatch: those tensors are [num_samples, num_features],
    # whereas locs must be [K, num_features] and scales must be
    # [K, num_features, num_features] (and all-ones is not a valid covariance).

    # Assuming use_cuda=True is configured elsewhere (or GPU auto-detection).
    with pyro.iarange('data', number_of_samples):
        # Local variables: per-datapoint mixture assignment, enumerated in
        # parallel via @config_enumerate; cast to int64 for tensor indexing.
        assignment = pyro.sample(
            'assignment',
            dist.Bernoulli(torch.ones(number_of_samples) * weights)).to(torch.int64)
        # NOTE(review): the trailing [0] on locs[assignment][0] drops a leading
        # broadcast/enumeration dim — confirm it matches the intended shapes.
        pyro.sample('obs', dist.MultivariateNormal(locs[assignment][0], scales[assignment]), obs=data)