Lambda function in a param statement

Hi,

Why, and when, would one pass a lambda function as the second argument of pyro.param()? For example, one finds

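import torch
import pyro
import pyro.poutine as poutine
from torch.distributions import constraints

@poutine.broadcast  # older Pyro API, as it appears in the tutorial
def guide(args, observations):
    # Initialize states randomly from the prior.
    states_loc = pyro.param("states_loc", lambda: torch.randn(args.max_num_objects, 2))
    states_scale = pyro.param("states_scale",
                              lambda: torch.ones(states_loc.shape) * args.emission_noise_scale,
                              constraint=constraints.positive)
    # ... (rest of the guide omitted)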

in the "Tracking an Unknown Number of Objects" page of the docs. Why not simply write:

    states_loc = pyro.param("states_loc", torch.randn(args.max_num_objects, 2))

Isn’t the second argument of pyro.param supposed to be the initial value the parameter is set to?

Thanks.

Gordon

Isn’t the second argument of pyro.param supposed to be the initial value the parameter is set to?

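Yes. The lambda is simply a small optimization (worthwhile for tensors that are expensive to construct) that avoids rebuilding the initial tensor when the parameter already exists in the param store: the lambda is only evaluated if the parameter has not been initialized yet. With a plain tensor argument, Python constructs the tensor on every call to the guide, and pyro.param then discards it because the value already in the param store takes precedence.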
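To make the difference concrete, here is a minimal, stdlib-only sketch. ToyParamStore, expensive_init and run are hypothetical names made up for illustration; the store only mimics the behaviour described above (the initial value is used only when the name is not yet in the store) and is not Pyro's actual implementation. In real code the calls would be pyro.param(...) with torch tensors.

```python
import random


class ToyParamStore:
    """Hypothetical stand-in for a param store (NOT Pyro's code).

    Mimics one behaviour of pyro.param: the initial value is only used
    the first time a name is seen; later calls return the stored value.
    """

    def __init__(self):
        self._params = {}

    def param(self, name, init):
        if name not in self._params:
            # A callable initializer is only evaluated when the name is new.
            self._params[name] = init() if callable(init) else init
        return self._params[name]


def expensive_init(counter):
    """Hypothetical stand-in for an expensive tensor constructor."""
    counter["calls"] += 1
    return [random.gauss(0.0, 1.0) for _ in range(4)]


def run(num_steps, lazy):
    """Call the 'guide' num_steps times and count how often init runs."""
    store = ToyParamStore()
    counter = {"calls": 0}
    for _ in range(num_steps):
        if lazy:
            # Like pyro.param("states_loc", lambda: torch.randn(...)).
            store.param("states_loc", lambda: expensive_init(counter))
        else:
            # Like pyro.param("states_loc", torch.randn(...)): the argument
            # is built on every call, then ignored after the first one.
            store.param("states_loc", expensive_init(counter))
    return counter["calls"]


print(run(100, lazy=False))  # 100
print(run(100, lazy=True))   # 1
```

The key point is that Python evaluates call arguments before the call happens, so an eager torch.randn(...) argument is paid for on every SVI step even though only the first result is ever stored; wrapping it in a lambda lets the store decide whether a value is needed at all.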