Hi,
Why and when would one pass a lambda function as an argument to pyro.param()? For example, one finds the following:
```python
import torch
import pyro
from pyro import poutine
from torch.distributions import constraints

@poutine.broadcast
def guide(args, observations):
    # Initialize states randomly from the prior.
    states_loc = pyro.param("states_loc", lambda: torch.randn(args.max_num_objects, 2))
    states_scale = pyro.param("states_scale",
                              lambda: torch.ones(states_loc.shape) * args.emission_noise_scale,
                              constraint=constraints.positive)
```
This is from the "Tracking an Unknown Number of Objects" tutorial in the docs. Why not simply write:
```python
states_loc = pyro.param("states_loc", torch.randn(args.max_num_objects, 2))
```
Isn’t the second argument of pyro.param supposed to be the initial value the parameter is set to?
Thanks.
Gordon