How to set the scales and locs in pyro.param?

I have 3D data stored as:

tensor([[ 0.4762, 5.2231, 0.2345],
[ 0.3257, 3.0552, -0.1863],
[ 3.1229, 3.8950, 0.4617],
…,
[10.4281, 0.1647, 0.0941],
[ 7.7341, 1.5497, -0.4066],
[ 7.3274, 1.9379, 0.2568]])

I would like to model the data as a GMM, assuming the number of components is K.

How to set the scales and locs in pyro.param?

In the end, I think this will do:

    loc = torch.zeros(K, 3)
    scale = torch.diag(torch.ones(3))
    scale = torch.stack((scale,) * K)  # one 3x3 matrix per component, not a hard-coded two
    scales = pyro.param('scales', scale,
                        constraint=constraints.positive)
    locs = pyro.param('locs', loc)

I guess I would do something like this:

data = ...
locs = pyro.param('locs', lambda: torch.zeros_like(data))
scales = pyro.param('scales', lambda: torch.ones_like(data),
                    constraint=constraints.positive)

Note that the lambda: ... form is a little cheaper, since the initializer is only evaluated on the first SVI step.
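
To illustrate why the callable form is cheaper, here is a minimal sketch using a hypothetical `ToyParamStore`. It is a stand-in written for this example, not Pyro's actual implementation; it only assumes the behavior described above, namely that the initializer runs when the parameter is first created and is skipped afterwards.

```python
# Toy stand-in for a parameter store (hypothetical, NOT Pyro's implementation).
class ToyParamStore:
    def __init__(self):
        self._params = {}

    def param(self, name, init):
        # Only build the initial value the first time the name is seen.
        if name not in self._params:
            self._params[name] = init() if callable(init) else init
        return self._params[name]


calls = {"eager": 0, "lazy": 0}


def expensive_init(kind):
    calls[kind] += 1  # count how often the initializer really runs
    return [0.0] * 3


store = ToyParamStore()
for step in range(5):  # pretend these are five SVI steps
    # Eager: the initial value is rebuilt on every step, then ignored.
    store.param("eager_locs", expensive_init("eager"))
    # Lazy: the lambda is only called when the parameter does not exist yet.
    store.param("lazy_locs", lambda: expensive_init("lazy"))

print(calls)
```

With a plain tensor argument the initial value is constructed on every call even though it is discarded after the first one; the callable defers that work.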

Thank you. Indeed, lambda: ... is a little cheaper. However, there is a size (or value) mismatch when using torch.zeros_like(data) and torch.ones_like(data), which causes ValueError: The parameter covariance_matrix has invalid values. Perhaps scales is also missing torch.diag. For convenience, here’s the full model:

@config_enumerate(default='parallel')
@poutine.broadcast
def model(data):
    # Global variables.
    weights = pyro.param('weights', torch.FloatTensor([0.5]), constraint=constraints.unit_interval)        
    scale_factor = 128 # or, scale_factor = max(torch.std(data, axis=1)); or; scale_factor = (data.var() / 2).sqrt()     
    number_of_samples = data.size(0)
    
    # This works
    # locs = data[np.random.randint(0, number_of_samples , K)] # randomly select locs/means from the data; initializing them to 0 leads to bad results
    # scales = scale_factor*torch.diag(torch.ones(data.shape[1] )) # using unit variance, think of data.shape[1] as num_features    
    # scales = torch.stack((scales,)*K)     
    # scales = pyro.param('scales', scales, constraint=constraints.positive)    
    # locs = pyro.param('locs', locs)    
    
    # This also works; using lambda is a little bit faster (~ 5% faster)
    locs = pyro.param('locs', lambda: data[np.random.randint(0, number_of_samples , K)]) # randomly selecting locs/means from data
    scales = pyro.param('scales', lambda: torch.stack((scale_factor*torch.diag(torch.ones(data.shape[1] )),)*K),
                    constraint=constraints.positive) # using unit variance, think of data.shape[1] as num_features
        
    ## Not working, size mismatch 
    # locs = pyro.param('locs', lambda: torch.zeros_like(data))
    # scales = pyro.param('scales', lambda: torch.ones_like(data),
    #                 constraint=constraints.positive)
       

    with pyro.iarange('data', number_of_samples): # I'm assuming use_cuda=True (if it is None, but GPU is available)
        # Local variables.
        assignment = pyro.sample('assignment', dist.Bernoulli(torch.ones(number_of_samples) * weights)).to(torch.int64)
        pyro.sample('obs', dist.MultivariateNormal(locs[assignment], scales[assignment]), obs=data)  # index by assignment only; a trailing [0] would reuse one mean for every point
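
The size mismatch in the commented-out lines can be checked with plain torch. This is a rough sketch: `K`, `N`, `D`, and the random `data` are placeholder assumptions, not values from my real setup. The point is only that `torch.zeros_like(data)` has one row per data point rather than one per component, and that a covariance needs a trailing `(D, D)` shape.

```python
import torch
from torch.distributions import MultivariateNormal

K, N, D = 2, 100, 3        # assumed sizes, for illustration only
data = torch.randn(N, D)   # placeholder for the real data

# zeros_like / ones_like copy the data's shape (N, D), not the component shape.
bad_locs = torch.zeros_like(data)
bad_scales = torch.ones_like(data)
print(bad_locs.shape, bad_scales.shape)

# A GMM needs one mean vector and one D x D covariance per component.
locs = data[torch.randint(0, N, (K,))]        # (K, D), rows picked at random
covs = torch.eye(D).expand(K, D, D).clone()   # (K, D, D), identity covariances

mvn = MultivariateNormal(locs, covariance_matrix=covs)
print(mvn.batch_shape, mvn.event_shape)

# Score every data point under every component via broadcasting:
# value (N, 1, D) against batch shape (K,) gives log-probs of shape (N, K).
log_probs = mvn.log_prob(data.unsqueeze(1))
print(log_probs.shape)
```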

Hi @Deeply, I believe you’ll want to use torch.diag_embed() rather than the error-prone torch.diag() to construct batched diagonal matrices.
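
As a small sketch of that suggestion, `torch.diag_embed()` turns a batch of per-component variance vectors into a batch of diagonal matrices in one call. `K`, `D`, and the variance value below are assumptions chosen for illustration.

```python
import torch

K, D = 4, 3                          # assumed sizes
variances = torch.ones(K, D) * 2.0   # one variance vector per component

# The last dimension becomes the diagonal: (K, D) -> (K, D, D).
covs = torch.diag_embed(variances)
print(covs.shape)

# torch.diag, by contrast, changes meaning with the rank of its input:
print(torch.diag(torch.ones(D)).shape)  # 1-D input: builds a 2-D diagonal matrix
print(torch.diag(torch.eye(D)).shape)   # 2-D input: extracts the 1-D diagonal
```

One possible design, which I have not tested on the model above, is to register only the `(K, D)` vector as the positive-constrained parameter and call `torch.diag_embed()` inside the model, so the off-diagonal zeros are never subject to the positivity constraint.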
