Lambda function in a param statement


#1

Hi,

Why and when would one use a lambda function as the initial-value argument of pyro.param()? For example, the "Tracking an Unknown Number of Objects" tutorial in the docs contains:

import pyro
import torch
from pyro import poutine
from torch.distributions import constraints

@poutine.broadcast
def guide(args, observations):
    # Initialize states randomly from the prior.
    states_loc = pyro.param("states_loc", lambda: torch.randn(args.max_num_objects, 2))
    states_scale = pyro.param("states_scale",
                              lambda: torch.ones(states_loc.shape) * args.emission_noise_scale,
                              constraint=constraints.positive)

Why not simply write:

    states_loc = pyro.param("states_loc", torch.randn(args.max_num_objects, 2))

Isn’t the second argument of pyro.param supposed to be the initial value the parameter is set to?

Thanks.

Gordon


#2

Isn’t the second argument of pyro.param supposed to be the initial value the parameter is set to?

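Yes, it is the initial value. The lambda is simply a small optimization for tensors that are expensive to create: pyro.param only calls it if the parameter is not already in the param store. With a plain tensor such as torch.randn(args.max_num_objects, 2), the tensor is built every time the guide runs (i.e. on every SVI step) and then discarded, because once the parameter exists pyro.param returns the stored value instead. With the lambda, the initializer runs only once, on the first call.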