My guess is that you can pass a dictionary like this one
options = dict(dtype=input.dtype, device=input.device)
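to the tensor constructors that feed the sample statements in your model function, so that the prior parameters share the dtype and device of the input tensor: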
prior_loc = torch.zeros(batch_size, dim1, **options)
prior_scale = torch.ones(batch_size, dim1, **options)
zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))