How to define a model for NUTS/HMC when samples from the posterior over network parameters are required

Hi,

I am new to Pyro, so apologies if this is a basic question. My setup is as follows: I have a neural network whose outputs, combined with the cross-entropy loss, give the log-likelihood of the observed data given the network parameters (the negative cross-entropy is the log-likelihood). I also place a standard normal prior over the network parameters. The goal is to draw samples from the posterior over the network parameters $\theta$ given the observed data. By Bayes' rule,

$$P(\theta \mid \text{Data}) \propto P(\text{Data} \mid \theta)\, P(\theta)$$
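In log form this becomes the following, assuming the data are $N$ i.i.d. pairs $(x_i, y_i)$, $f_\theta(x_i)$ denotes the network's logits, $\mathrm{CE}$ is the per-example cross-entropy, and $d$ is the number of network parameters (`net.num_params`). The last line is the potential energy $U(\theta)$ that HMC/NUTS works with; its additive constant does not affect sampling.

$$
\begin{aligned}
\log P(\theta \mid \text{Data}) &= \sum_{i=1}^{N} \log P(y_i \mid x_i, \theta) + \log \mathcal{N}(\theta \mid 0, I_d) - \log P(\text{Data}) \\
&= -\sum_{i=1}^{N} \mathrm{CE}\big(f_\theta(x_i), y_i\big) - \tfrac{1}{2}\lVert \theta \rVert_2^2 + \text{const}, \\
U(\theta) &= -\log P(\text{Data} \mid \theta) - \log P(\theta) = \sum_{i=1}^{N} \mathrm{CE}\big(f_\theta(x_i), y_i\big) + \tfrac{1}{2}\lVert \theta \rVert_2^2 + \tfrac{d}{2}\log 2\pi.
\end{aligned}
$$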

If I could pass a potential function to the HMC/NUTS sampler, this would be straightforward, since I have access to both the likelihood and the prior. However, passing a `potential_fn` fails with the following error:

ValueError: Must provide valid initial parameters to begin sampling when using `potential_fn` in HMC/NUTS kernel.

I looked through this post on the forum, which suggested providing a model as the input to the NUTS/HMC sampler, but I am confused about how to express my problem as a model. Also, I have written my code so that the PyTorch network has no internal parameters: given some data and the flattened network parameters, it performs a forward pass, splitting the flattened tensor sequentially into all the parameters the network requires. My current thought process is as follows.

In the code below, `net` is the network and `dist` is a class that computes the log-likelihood of the observed data under the parameterized network.

```python
import pyro
import pyro.distributions as pdist  # aliased so it is not shadowed by the `dist` argument


def probabilistic_model(net, dist, data_inputs, data_targets):
    # net.num_params is the number of parameters this network expects (int)
    prior_dist = pdist.Normal(0., 1.).expand([net.num_params]).to_event(1)
    params = pyro.sample("params", prior_dist)
    log_prior = prior_dist.log_prob(params).sum()
    data_preds = net(data_inputs, params)
    log_likelihood = dist.log_prob(data_preds, data_targets)
    log_unnorm_posterior = log_likelihood + log_prior
    # ........ What to have here? ...............
```

How should I structure my model so that I can sample from the posterior over the network parameters? I am not returning `log_unnorm_posterior` (or its exponential), since at no point in this code have I conditioned on the data. I assume it should be similar to the following snippet from a Pyro tutorial, with the conditioning and the distribution changed:

```python
return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)
```

Please let me know whether this is the right way to do it, or whether I am making any mistakes in the current formulation.

Thanks