Strange behaviour of NUTS sampling with Neutra reparameterisation and Factor-based model

In short:

I am testing sampling through a normalising flow with the Neutra
reparameterisation, using a model that employs a factor statement to include a
custom likelihood. I'm finding that the sampler gets stuck and does not update
the parameters correctly, whereas a model that uses a torch distribution as the
likelihood handles this just fine.

In more detail:

The relevant functions are:

from functools import partial

import numpy as np
import torch

import pyro
import pyro.distributions as dist


def general_torch_based_MVN_likelihood_function(
    mvn_configuration, x, y, register_info_dict=None
):
    # handle call counter (handle_registry_increment is a project-local
    # helper, not shown here)
    handle_registry_increment(register_info_dict=register_info_dict)

    # Construct array of the arguments
    symbol_array = torch.zeros(2)
    symbol_array[0] = x
    symbol_array[1] = y

    # unpack data
    mu = torch.from_numpy(mvn_configuration["mean"].astype(np.float32))
    cov = torch.from_numpy(mvn_configuration["covariance"].astype(np.float32))

    # get the dimension of the MVN (the number of rows of the covariance
    # matrix; cov.ndim would give the tensor rank, which only coincides with
    # the MVN dimension for a 2x2 covariance)
    dim = cov.shape[-1]

    # calculate some covariance-related values
    inv_cov = torch.linalg.inv(cov)
    abs_det = torch.abs(torch.linalg.det(cov))

    # calculate the MVN log-probability; the normalisation constant is
    # 1 / ((2 * pi) ** (dim / 2) * sqrt(|det(cov)|))
    logpdf = torch.log(1 / (np.sqrt(2 * np.pi) ** dim * torch.sqrt(abs_det))) - 0.5 * (
        torch.inner((symbol_array - mu), torch.inner(inv_cov, (symbol_array - mu)))
    )

    return logpdf

Here general_torch_based_MVN_likelihood_function is my not-yet-configured
likelihood function.
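The intent is for this manual formula to reproduce torch's MultivariateNormal
log-density. A standalone sanity check of that (a sketch with made-up test
values, not code from the thread) would be:

import numpy as np
import torch

# hypothetical test values
mean = torch.tensor([60.0, 20.0])
cov = torch.tensor([[4.0, 1.0], [1.0, 9.0]])
value = torch.tensor([61.0, 19.0])

# manual log-density, mirroring general_torch_based_MVN_likelihood_function
dim = cov.shape[-1]
inv_cov = torch.linalg.inv(cov)
abs_det = torch.abs(torch.linalg.det(cov))
diff = value - mean
manual = torch.log(1 / (np.sqrt(2 * np.pi) ** dim * torch.sqrt(abs_det))) - 0.5 * (
    torch.inner(diff, torch.inner(inv_cov, diff))
)

# battle-tested reference
reference = torch.distributions.MultivariateNormal(mean, cov).log_prob(value)
assert torch.allclose(manual, reference, atol=1e-5)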

def return_NewNormalPriorScaledTorchMVNFactorModel(
    mvn_configuration, register_info_dict=None
):

    specific_torch_based_MVN_likelihood_function = partial(
        general_torch_based_MVN_likelihood_function,
        mvn_configuration=mvn_configuration,
        register_info_dict=register_info_dict,
    )

    def NewNormalPriorScaledTorchMVNFactorModel():
        # Return samples whose prior contribution is scaled down to
        # (effectively) nothing, so the factor term dominates the log-density
        with pyro.poutine.scale(scale=1e-10):
            x = pyro.sample("x", dist.Normal(0, 1))
            y = pyro.sample("y", dist.Normal(0, 1))

        # Add the custom likelihood term to the model's log-density
        pyro.factor(
            "custom_likelihood", specific_torch_based_MVN_likelihood_function(x=x, y=y)
        )

    return NewNormalPriorScaledTorchMVNFactorModel

Here return_NewNormalPriorScaledTorchMVNFactorModel is a function that handles
configuring the likelihood and returning an actual model.

NewNormalPriorScaledTorchMVNFactorModel = (
    return_NewNormalPriorScaledTorchMVNFactorModel(
        mvn_configuration=mvn_distribution_specifications,
        register_info_dict=config["register_info_dict"],
    )
)

The model that I use is NewNormalPriorScaledTorchMVNFactorModel. The
mvn_configuration contains an arbitrarily chosen mean and covariance matrix
that I can change for testing purposes. The register_info_dict allows me to
track the number of likelihood function calls, but is not relevant to this
problem.

When I use this model to sample directly with the NUTS sampler (i.e. no
normalising flow and no reparameterisation), the results are as expected: the
sampling behaves as intended and recovers the desired distribution. I now want
to train a normalising flow and use the Neutra reparameterisation to sample
through it.
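For context, the direct run is wired up along these lines (a minimal sketch; my
actual sampler wrapper and settings are not shown here):

from pyro.infer import MCMC, NUTS

# plain NUTS on the factor-based model, no flow, no reparameterisation
kernel = NUTS(NewNormalPriorScaledTorchMVNFactorModel)
mcmc = MCMC(kernel, num_samples=100, warmup_steps=10)
mcmc.run()
samples = mcmc.get_samples()  # {"x": ..., "y": ...}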

This process has two steps:

  • training the normalising flow (using SVI)
  • sampling
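In code, the wiring follows Pyro's NeuTraReparam docstring; my fit_guide and
sampler wrappers (not shown) do roughly the following, where the hyperparameter
values are illustrative:

from functools import partial

import pyro
import pyro.optim
from pyro import poutine
from pyro.distributions.transforms import block_autoregressive, iterated
from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormalizingFlow
from pyro.infer.reparam import NeuTraReparam

model = NewNormalPriorScaledTorchMVNFactorModel

# step 1: fit a block-autoregressive normalising flow as the guide (SVI)
guide = AutoNormalizingFlow(model, partial(iterated, 1, block_autoregressive))
svi = SVI(model, guide, pyro.optim.Adam({"lr": 1e-3}), Trace_ELBO())
for step in range(2000):
    svi.step()

# step 2: reparameterise the model through the fitted flow and run NUTS
neutra = NeuTraReparam(guide)
neutra_model = poutine.reparam(model, config=lambda msg: neutra)
mcmc = MCMC(NUTS(neutra_model), num_samples=2000, warmup_steps=500)
mcmc.run()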

The map-building works well. An abridged version of the output is:

[utility_functions.py:107 -            fit_guide ] 2023-03-10 12:22:46,146: [0] Elbo loss = 402.07
tensor([-0.2323,  0.4446], grad_fn=<CopySlices>)
tensor(-406.0912, grad_fn=<SubBackward0>)
tensor([ 0.8584, -0.0169], grad_fn=<CopySlices>)
tensor(-394.8972, grad_fn=<SubBackward0>)
tensor([ 1.3832, -0.0082], grad_fn=<CopySlices>)
tensor(-388.6822, grad_fn=<SubBackward0>)
tensor([1.9231, 0.9678], grad_fn=<CopySlices>)
tensor(-378.5714, grad_fn=<SubBackward0>)
tensor([-0.6429,  0.6282], grad_fn=<CopySlices>)
tensor(-410.3394, grad_fn=<SubBackward0>)
tensor([1.5488, 0.7696], grad_fn=<CopySlices>)
tensor(-383.6916, grad_fn=<SubBackward0>)
...
tensor(-297.9914, grad_fn=<SubBackward0>)
[utility_functions.py:107 -            fit_guide ] 2023-03-10 12:22:47,750: [100] Elbo loss = 297.70
tensor([7.7855, 6.1932], grad_fn=<CopySlices>)
...
tensor([39.9734, 20.2629], grad_fn=<CopySlices>)
tensor(-45.1703, grad_fn=<SubBackward0>)
tensor([39.9762, 16.7412], grad_fn=<CopySlices>)
tensor(-46.2139, grad_fn=<SubBackward0>)
tensor([39.3976, 20.7407], grad_fn=<CopySlices>)
tensor(-47.5574, grad_fn=<SubBackward0>)
tensor([40.0711, 20.9192], grad_fn=<CopySlices>)
tensor(-44.8575, grad_fn=<SubBackward0>)

Here we see the set of parameters (e.g. tensor([-0.2323, 0.4446], grad_fn=<CopySlices>)) and the log prob (e.g. tensor(-406.0912, grad_fn=<SubBackward0>)) change over time: the log prob increases, and the set
of parameters updates properly.

The sampling, however, does not seem to work well:

Warmup:   1%|          | 1/110 [00:00,  8.17it/s, step size=5.75e+01, acc. prob=1.000]
tensor([40.1143, 20.8328], grad_fn=<CopySlices>)
tensor(-44.6704, grad_fn=<SubBackward0>)
tensor([40.1147, 20.8331], grad_fn=<CopySlices>)
tensor(-44.6688, grad_fn=<SubBackward0>)
tensor([40.1147, 20.8332], grad_fn=<CopySlices>)
...
Warmup:   3%|▎         | 3/110 [00:01,  2.26it/s, step size=1.56e+01, acc. prob=0.682]
tensor([40.1147, 20.9702], grad_fn=<CopySlices>)
tensor(-44.6934, grad_fn=<SubBackward0>)
tensor([40.1147, 20.9702], grad_fn=<CopySlices>)
tensor(-44.6934, grad_fn=<SubBackward0>)
tensor([40.1147, 20.9702], grad_fn=<CopySlices>)
...
Sample:  47%|████▋     | 52/110 [06:54,  9.53s/it, step size=3.56e+01, acc. prob=0.986]
tensor([40.1147, 20.9702], grad_fn=<CopySlices>)
tensor(-44.6934, grad_fn=<SubBackward0>)
tensor([40.1147, 20.9702], grad_fn=<CopySlices>)
tensor(-44.6934, grad_fn=<SubBackward0>)
tensor([40.1147, 20.9702], grad_fn=<CopySlices>)
tensor(-44.6934, grad_fn=<SubBackward0>)
tensor([40.1147, 20.9702], grad_fn=<CopySlices>)
  • The acceptance probability seems okay.
  • Sampling takes far longer than in a “non factor-based” model (i.e. one that
    just uses a torch dist).
  • The set of parameters is not properly updated (the first parameter has been
    stuck since the start, and the second one does not seem to change much).
  • This will obviously not sample the posterior properly.

Some extra configuration that is relevant:

config["n_warmup"] = 10
config["sample_size"] = 100
config["num_flows_map"] = 1
config["learning_rate_map"] = 1e-2
config["num_steps_map"] = 200
config["guide_fitting_logging_steps"] = 100

When I use a non factor-based model (i.e. a torch dist), sampling with
normalising flows and the Neutra reparameterisation works fine.

My questions are:

  • Is there any obvious issue with my model / approach?
  • How can I find out more about what is actually going wrong?
  • Why might using a factor statement change the behaviour of the sampling
    through reparameterisation?

This post is related to two earlier posts:

i don’t really follow how different your custom log_prob is from the torch dist log_prob so hard to comment.

a lot of these hyperparameters seem suspect. instead try something like

config["n_warmup"] = 500
config["sample_size"] = 2000
config["num_flows_map"] = 1
config["learning_rate_map"] = 1e-3
config["num_steps_map"] = 2000

Hi Martin,

I don’t really follow how different your custom log_prob is from the torch dist log_prob so hard to comment.

My custom log_prob should give the same answer as a torch MultivariateNormal. I do this to test whether the pyro.factor version of a model behaves the same as one whose likelihood is described by a torch distribution (i.e. describes the same posterior and samples with comparable performance). With direct NUTS sampling, both approaches do seem to work the same way: they describe the same posterior, and they do so in comparable time. Later, the custom log_prob will of course be replaced by some science-case-dependent function.

a lot of these hyperparameters seem suspect

Right, these values are a bit dodgy indeed, although changing them to more sensible values does not change the situation.

Below is an abridged version of a run with your suggestions (I changed num_steps_map to 4000 to ensure a good mapping):

Training the map

[utility_functions.py:107 -            fit_guide ] 2023-03-12 11:04:18,980: [0] Elbo loss = 401.56
Parameters: tensor([-0.8701, -0.3759], grad_fn=<CopySlices>) logpdf: -417.0910949707031
Parameters: tensor([-0.1969, -0.8866], grad_fn=<CopySlices>) logpdf: -411.0481872558594
Parameters: tensor([-0.7899, -0.7184], grad_fn=<CopySlices>) logpdf: -417.52374267578125
Parameters: tensor([-0.5383,  0.4030], grad_fn=<CopySlices>) logpdf: -409.94989013671875
Parameters: tensor([-0.4523,  0.9954], grad_fn=<CopySlices>) logpdf: -406.62261962890625
Parameters: tensor([-0.6008, -0.6274], grad_fn=<CopySlices>) logpdf: -414.85101318359375
...
Parameters: tensor([37.9201, 21.0745], grad_fn=<CopySlices>) logpdf: -53.92456817626953
Parameters: tensor([38.2028, 20.7565], grad_fn=<CopySlices>) logpdf: -52.625579833984375
Parameters: tensor([38.1856, 20.9255], grad_fn=<CopySlices>) logpdf: -52.729034423828125
Parameters: tensor([37.7222, 21.0089], grad_fn=<CopySlices>) logpdf: -54.788536071777344
[utility_functions.py:107 -            fit_guide ] 2023-03-12 11:05:03,640: [2000] Elbo loss = 52.04
Parameters: tensor([38.3260, 21.0676], grad_fn=<CopySlices>) logpdf: -52.146934509277344
Parameters: tensor([38.3312, 21.2793], grad_fn=<CopySlices>) logpdf: -52.174278259277344
Parameters: tensor([34.5970, 20.0134], grad_fn=<CopySlices>) logpdf: -69.58824920654297
Parameters: tensor([38.4475, 18.6872], grad_fn=<CopySlices>) logpdf: -51.680320739746094
...
Parameters: tensor([57.8790, 19.2454], grad_fn=<CopySlices>) logpdf: -5.563545227050781
Parameters: tensor([60.6054, 22.9342], grad_fn=<CopySlices>) logpdf: -5.9543609619140625
Parameters: tensor([60.5216, 14.6580], grad_fn=<CopySlices>) logpdf: -7.937607765197754
[utility_functions.py:107 -            fit_guide ] 2023-03-12 11:05:38,309: [3900] Elbo loss = 3.08
Parameters: tensor([61.2317, 18.9709], grad_fn=<CopySlices>) logpdf: -5.314355850219727
Parameters: tensor([58.4747, 22.9687], grad_fn=<CopySlices>) logpdf: -6.1707377433776855

Sampling with NUTS through the map:

Warmup:   0%|          | 0/2500 [00:00, ?it/s]Parameters: tensor([57.1514, 21.7731]) logpdf: -6.1825947761535645
Parameters: tensor([60.3045, 16.8009]) logpdf: -6.089438438415527
Parameters: tensor([60.3045, 16.8009], grad_fn=<CopySlices>) logpdf: -6.089438438415527
Parameters: tensor([60.3045, 16.8009], grad_fn=<CopySlices>) logpdf: -6.089438438415527
Parameters: tensor([60.3045, 16.8009]) logpdf: -6.089438438415527
Parameters: tensor([60.3045, 16.8009], grad_fn=<CopySlices>) logpdf: -6.089438438415527
Parameters: tensor([60.9706, 13.8637], grad_fn=<CopySlices>) logpdf: -8.916431427001953
Parameters: tensor([60.3045, 16.8009], grad_fn=<CopySlices>) logpdf: -6.089438438415527
Parameters: tensor([59.8281, 16.2846], grad_fn=<CopySlices>) logpdf: -6.440091609954834
Parameters: tensor([61.3999, 15.5183], grad_fn=<CopySlices>) logpdf: -7.2612762451171875
Parameters: tensor([58.1171, 19.0009], grad_fn=<CopySlices>) logpdf: -5.5111236572265625
Parameters: tensor([55.7143, 20.7165], grad_fn=<CopySlices>) logpdf: -6.944823741912842
Parameters: tensor([61.9483, 15.7273], grad_fn=<CopySlices>) logpdf: -7.261930465698242
Parameters: tensor([62.2521, 17.4760], grad_fn=<CopySlices>) logpdf: -6.201027870178223
Parameters: tensor([62.4404, 20.1765], grad_fn=<CopySlices>) logpdf: -5.655401229858398
Parameters: tensor([62.5722, 22.0135], grad_fn=<CopySlices>) logpdf: -6.123801231384277

Warmup:   0%|          | 1/2500 [00:00,  7.05it/s, step size=6.27e+00, acc. prob=0.925]Parameters: tensor([-33.6092,   5.2560], grad_fn=<CopySlices>) logpdf: -903.0630493164062
Parameters: tensor([62.3105, 18.4675], grad_fn=<CopySlices>) logpdf: -5.825458526611328
Parameters: tensor([61.9386, 13.4231], grad_fn=<CopySlices>) logpdf: -9.758113861083984
Parameters: tensor([60.4027, 20.9009], grad_fn=<CopySlices>) logpdf: -5.15412712097168
Parameters: tensor([53.8168, 22.0862], grad_fn=<CopySlices>) logpdf: -9.315200805664062
Parameters: tensor([59.9154, 21.5189], grad_fn=<CopySlices>) logpdf: -5.288169860839844
Parameters: tensor([60.7725, 19.9802], grad_fn=<CopySlices>) logpdf: -5.116472244262695
Parameters: tensor([61.0477, 18.9062], grad_fn=<CopySlices>) logpdf: -5.2861528396606445
Parameters: tensor([59.3019, 21.8560], grad_fn=<CopySlices>) logpdf: -5.449966907501221
Parameters: tensor([58.6254, 21.9962], grad_fn=<CopySlices>) logpdf: -5.644172191619873
Parameters: tensor([58.0993, 22.0397], grad_fn=<CopySlices>) logpdf: -5.834046363830566
Parameters: tensor([58.0192, 22.0794], grad_fn=<CopySlices>) logpdf: -5.8814778327941895

Warmup:   0%|          | 5/2500 [00:00, 21.87it/s, step size=6.08e-01, acc. prob=0.659]Parameters: tensor([60.4149, 16.7628], grad_fn=<CopySlices>) logpdf: -6.121933460235596
Parameters: tensor([59.8736, 15.3839], grad_fn=<CopySlices>) logpdf: -7.189225196838379
Parameters: tensor([59.2717, 16.3354], grad_fn=<CopySlices>) logpdf: -6.452694892883301
Parameters: tensor([60.3831, 15.2356], grad_fn=<CopySlices>) logpdf: -7.341385364532471
Parameters: tensor([59.3250, 18.0749], grad_fn=<CopySlices>) logpdf: -5.472909450531006
Parameters: tensor([59.0927, 21.3579], grad_fn=<CopySlices>) logpdf: -5.3234734535217285
Parameters: tensor([60.3876, 18.5206], grad_fn=<CopySlices>) logpdf: -5.2906341552734375
Parameters: tensor([58.5527, 22.2232], grad_fn=<CopySlices>) logpdf: -5.760471820831299
Parameters: tensor([59.9282, 22.5524], grad_fn=<CopySlices>) logpdf: -5.7087202072143555
Parameters: tensor([60.3831, 18.6745], grad_fn=<CopySlices>) logpdf: -5.247136116027832
Parameters: tensor([60.9989, 22.8593], grad_fn=<CopySlices>) logpdf: -5.974071979522705
Parameters: tensor([61.2877, 20.1574], grad_fn=<CopySlices>) logpdf: -5.225040912628174

Warmup:   0%|          | 9/2500 [00:00, 27.10it/s, step size=2.89e+00, acc. prob=0.781]Parameters: tensor([12.4390, 12.8969], grad_fn=<CopySlices>) logpdf: -236.30734252929688
Parameters: tensor([61.0980, 19.2484], grad_fn=<CopySlices>) logpdf: -5.233799934387207
Parameters: tensor([61.5800, 19.8165], grad_fn=<CopySlices>) logpdf: -5.309767246246338
Parameters: tensor([61.9148, 20.3383], grad_fn=<CopySlices>) logpdf: -5.434859752655029
Parameters: tensor([59.3125, 18.1233], grad_fn=<CopySlices>) logpdf: -5.456225395202637
Parameters: tensor([57.7870, 17.6128], grad_fn=<CopySlices>) logpdf: -6.116353988647461
Parameters: tensor([56.0189, 17.2025], grad_fn=<CopySlices>) logpdf: -7.4242472648620605
Parameters: tensor([54.9788, 17.0848], grad_fn=<CopySlices>) logpdf: -8.427913665771484
Parameters: tensor([55.6384, 17.4468], grad_fn=<CopySlices>) logpdf: -7.610964775085449
Parameters: tensor([57.3395, 18.1507], grad_fn=<CopySlices>) logpdf: -6.106584072113037
Parameters: tensor([58.9591, 18.9433], grad_fn=<CopySlices>) logpdf: -5.276744365692139
Parameters: tensor([60.1335, 19.6954], grad_fn=<CopySlices>) logpdf: -5.067811012268066
Parameters: tensor([60.9240, 20.3548], grad_fn=<CopySlices>) logpdf: -5.154724597930908
Parameters: tensor([61.4568, 20.8947], grad_fn=<CopySlices>) logpdf: -5.349027633666992
Parameters: tensor([61.8257, 21.3061], grad_fn=<CopySlices>) logpdf: -5.5606369972229
Parameters: tensor([62.0895, 21.5957], grad_fn=<CopySlices>) logpdf: -5.747973442077637
Parameters: tensor([61.5651, 18.5663], grad_fn=<CopySlices>) logpdf: -5.5072784423828125
Parameters: tensor([61.9546, 16.9894], grad_fn=<CopySlices>) logpdf: -6.345172882080078
Parameters: tensor([62.2009, 16.2738], grad_fn=<CopySlices>) logpdf: -6.929638385772705
Parameters: tensor([62.3646, 16.5304], grad_fn=<CopySlices>) logpdf: -6.819708824157715
Parameters: tensor([62.4790, 17.7073], grad_fn=<CopySlices>) logpdf: -6.196922302246094
Parameters: tensor([62.5632, 19.4676], grad_fn=<CopySlices>) logpdf: -5.742120265960693
Parameters: tensor([62.6295, 21.0293], grad_fn=<CopySlices>) logpdf: -5.8541340827941895
Parameters: tensor([62.6851, 22.0048], grad_fn=<CopySlices>) logpdf: -6.179637432098389
Parameters: tensor([62.7336, 22.5538], grad_fn=<CopySlices>) logpdf: -6.456192493438721
Parameters: tensor([62.7769, 22.8654], grad_fn=<CopySlices>) logpdf: -6.648891448974609
Parameters: tensor([62.8156, 23.0416], grad_fn=<CopySlices>) logpdf: -6.7746663093566895
Parameters: tensor([62.8504, 23.1317], grad_fn=<CopySlices>) logpdf: -6.849977016448975
Parameters: tensor([62.8816, 23.1595], grad_fn=<CopySlices>) logpdf: -6.88533878326416
Parameters: tensor([62.9095, 23.1365], grad_fn=<CopySlices>) logpdf: -6.887068748474121
Parameters: tensor([62.9345, 23.0683], grad_fn=<CopySlices>) logpdf: -6.85927677154541
Parameters: tensor([62.9566, 22.9573], grad_fn=<CopySlices>) logpdf: -6.805470943450928
Parameters: tensor([62.9762, 22.8059], grad_fn=<CopySlices>) logpdf: -6.729840278625488
Parameters: tensor([62.9934, 22.6176], grad_fn=<CopySlices>) logpdf: -6.63796854019165
Parameters: tensor([63.0085, 22.3980], grad_fn=<CopySlices>) logpdf: -6.536852836608887
Parameters: tensor([63.0215, 22.1549], grad_fn=<CopySlices>) logpdf: -6.434106826782227
Parameters: tensor([63.0329, 21.8978], grad_fn=<CopySlices>) logpdf: -6.336756706237793
Parameters: tensor([63.0426, 21.6361], grad_fn=<CopySlices>) logpdf: -6.250203609466553
Parameters: tensor([63.0510, 21.3783], grad_fn=<CopySlices>) logpdf: -6.1775898933410645
Parameters: tensor([63.0582, 21.1302], grad_fn=<CopySlices>) logpdf: -6.119723796844482
Parameters: tensor([63.0643, 20.8937], grad_fn=<CopySlices>) logpdf: -6.075638771057129
Parameters: tensor([63.0696, 20.6674], grad_fn=<CopySlices>) logpdf: -6.04355001449585
Parameters: tensor([63.0742, 20.4470], grad_fn=<CopySlices>) logpdf: -6.021770477294922
Parameters: tensor([63.0781, 20.2271], grad_fn=<CopySlices>) logpdf: -6.009363174438477
Parameters: tensor([63.0815, 20.0032], grad_fn=<CopySlices>) logpdf: -6.006298065185547
...
Parameters: tensor([63.1067, 21.3486], grad_fn=<CopySlices>) logpdf: -6.203779220581055
Parameters: tensor([63.1067, 21.3584], grad_fn=<CopySlices>) logpdf: -6.206433296203613
Parameters: tensor([63.1067, 21.3613], grad_fn=<CopySlices>) logpdf: -6.207237720489502
Parameters: tensor([63.1067, 21.3648], grad_fn=<CopySlices>) logpdf: -6.208197593688965
Parameters: tensor([63.1067, 21.3809], grad_fn=<CopySlices>) logpdf: -6.212608814239502
Parameters: tensor([63.1067, 21.4843], grad_fn=<CopySlices>) logpdf: -6.242234230041504
Parameters: tensor([63.1067, 21.9276], grad_fn=<CopySlices>) logpdf: -6.393496513366699

Warmup:   1%|          | 20/2500 [00:02,  7.11it/s, step size=2.94e+00, acc. prob=0.791]Parameters: tensor([63.1063, 24.7035], grad_fn=<CopySlices>) logpdf: -8.233990669250488
Parameters: tensor([63.1055, 24.7031], grad_fn=<CopySlices>) logpdf: -8.233113288879395
Parameters: tensor([63.1027, 24.7022], grad_fn=<CopySlices>) logpdf: -8.230559349060059
Parameters: tensor([63.1067, 24.6969], grad_fn=<CopySlices>) logpdf: -8.227944374084473
Parameters: tensor([63.1067, 24.6268], grad_fn=<CopySlices>) logpdf: -8.16263484954834
Parameters: tensor([63.1067, 24.2199], grad_fn=<CopySlices>) logpdf: -7.80264949798584
Parameters: tensor([63.1067, 23.3164], grad_fn=<CopySlices>) logpdf: -7.121757984161377
Parameters: tensor([63.0928, 24.7001], grad_fn=<CopySlices>) logpdf: -8.222328186035156
Parameters: tensor([63.0488, 24.6922], grad_fn=<CopySlices>) logpdf: -8.187959671020508
Parameters: tensor([62.6789, 24.6165], grad_fn=<CopySlices>) logpdf: -7.905611038208008
Parameters: tensor([19.5507, 16.0417], grad_fn=<CopySlices>) logpdf: -170.23828125
Parameters: tensor([63.1067, 24.6826], grad_fn=<CopySlices>) logpdf: -8.214599609375
Parameters: tensor([63.1067, 23.0024], grad_fn=<CopySlices>) logpdf: -6.923335552215576
Parameters: tensor([63.1067, 22.0647], grad_fn=<CopySlices>) logpdf: -6.448219299316406
...
Parameters: tensor([63.1067, 14.8760], grad_fn=<CopySlices>) logpdf: -8.647482872009277
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605

Warmup:   1%|          | 31/2500 [00:07,  3.17it/s, step size=1.82e+01, acc. prob=0.815]Parameters: tensor([63.1067, 19.6790], grad_fn=<CopySlices>) logpdf: -6.032225131988525
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4853], grad_fn=<CopySlices>) logpdf: -6.04841423034668
Parameters: tensor([63.1067, 14.3368], grad_fn=<CopySlices>) logpdf: -9.229097366333008
Parameters: tensor([63.1067, 12.3104], grad_fn=<CopySlices>) logpdf: -11.934846878051758

Warmup:   1%|▏         | 32/2500 [00:07,  3.87it/s, step size=1.92e+01, acc. prob=0.816]Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
...
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875

Warmup:   3%|▎         | 79/2500 [03:57, 10.34s/it, step size=3.43e+01, acc. prob=0.812]Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
...
Parameters: tensor([63.1067, 11.3613], grad_fn=<CopySlices>) logpdf: -13.484560012817383
Parameters: tensor([63.1067,  8.2820], grad_fn=<CopySlices>) logpdf: -19.75299072265625
Parameters: tensor([63.1067,  8.2820], grad_fn=<CopySlices>) logpdf: -19.75299072265625
Parameters: tensor([63.1067,  8.2820], grad_fn=<CopySlices>) logpdf: -19.75299072265625
Parameters: tensor([63.1067,  8.2820], grad_fn=<CopySlices>) logpdf: -19.75299072265625
Parameters: tensor([63.1067,  8.2820], grad_fn=<CopySlices>) logpdf: -19.75299072265625
Parameters: tensor([-40.8808, -14.7086], grad_fn=<CopySlices>) logpdf: -1143.2193603515625

Warmup:   4%|▍         | 105/2500 [07:17,  2.98s/it, step size=1.95e+01, acc. prob=0.802]Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
...
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605

Warmup:   5%|▌         | 133/2500 [13:55, 13.41s/it, step size=2.12e+01, acc. prob=0.800]Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
...
Parameters: tensor([63.1067, 24.7042], grad_fn=<CopySlices>) logpdf: -8.234880447387695
Parameters: tensor([63.1067, 24.7042], grad_fn=<CopySlices>) logpdf: -8.234880447387695
Parameters: tensor([63.1067, 24.7042], grad_fn=<CopySlices>) logpdf: -8.234880447387695
Parameters: tensor([63.1067, 24.7042], grad_fn=<CopySlices>) logpdf: -8.234880447387695
Parameters: tensor([-40.8808,   3.8161], grad_fn=<CopySlices>) logpdf: -1048.9425048828125

Warmup:   9%|▉         | 233/2500 [24:50,  4.69s/it, step size=8.35e-03, acc. prob=0.785]Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
...
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605

Warmup:  19%|█▉        | 478/2500 [58:48,  8.80s/it, step size=1.26e-04, acc. prob=0.773]Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
Parameters: tensor([63.1067, 19.4855], grad_fn=<CopySlices>) logpdf: -6.0483927726745605
...
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875

Sample:  27%|██▋       | 678/2500 [1:28:36,  8.28s/it, step size=2.02e-04, acc. prob=0.871]Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
Parameters: tensor([63.1067, 21.3643], grad_fn=<CopySlices>) logpdf: -6.2080535888671875
...
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703

Sample:  33%|███▎      | 831/2500 [1:49:47,  7.81s/it, step size=2.02e-04, acc. prob=0.872]Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703
Parameters: tensor([63.1067, 22.9785], grad_fn=<CopySlices>) logpdf: -6.909046173095703

At first the sampling seems to work well, updating all the parameters, but quite quickly it gets stuck on the first parameter, while the second one varies (albeit slowly). The first parameter sometimes makes an excursion to a very different value, but quickly jumps back.

Sampling up to about 33% of the total samples, which includes the entire warmup, took about an hour. During this, a large number of likelihood evaluations were performed, more than I would expect. When I do the same sampling with my torch-distribution model, the entire run finishes quickly and accurately.

to put it bluntly, this is a bad way to do numerical linear algebra (unless cov is a very small matrix). MultivariateNormal uses cholesky decompositions under the hood, which are much more numerically stable. not sure why you don’t want to use a battle-tested implementation. if you want to “test” whether factor statements work as expected it’s enough to do something like pyro.factor('my_factor', mvn.log_prob(...)).
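for reference, a cholesky-based version of the same density looks roughly like this (a sketch assuming a recent pytorch; not code from the thread):

import math

import torch

def mvn_logpdf_cholesky(value, mu, cov):
    # cov = L @ L.T, so log|cov| = 2 * sum(log(diag(L)))
    L = torch.linalg.cholesky(cov)
    diff = (value - mu).unsqueeze(-1)
    # solve L z = diff instead of explicitly forming cov^{-1}
    z = torch.linalg.solve_triangular(L, diff, upper=False).squeeze(-1)
    half_log_det = L.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    dim = cov.shape[-1]
    return -0.5 * (z * z).sum(-1) - half_log_det - 0.5 * dim * math.log(2 * math.pi)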

in other words the difference in behavior probably has nothing to do with neutra per se, though note that neutra is not a panacea. there’s no guarantee that it’ll improve sampling. in fact it can make it worse if you don’t learn a good map, and learning a good map may be hard and/or require a fair amount of hyperparameter tuning.


Hello Martin,

Thanks for your reply.

to put it bluntly, this is a bad way to do numerical linear algebra (unless cov is a very small matrix). MultivariateNormal uses cholesky decompositions under the hood, which are much more numerically stable.

Thanks for the tip. I would have expected my runs without the Neutra mapping to fail as well if this were the issue, though.

not sure why you don’t want to use a battle-tested implementation

no specific reason, I’m fine with using your suggestion of

pyro.factor('my_factor', mvn.log_prob(...))

I have changed my setup to the following:

def return_NewNormalPriorScaledTorchMVNFactorModel(
    mvn_configuration, register_info_dict=None
):
    """
    https://docs.pyro.ai/en/stable/poutine.html#pyro.poutine.handlers.scale
    """

    specific_dist_based_MVN_function_for_factor = partial(
        dist_based_MVN_function_for_factor,
        mvn_configuration=mvn_configuration,
    )

    def NewNormalPriorScaledTorchMVNFactorModel():
        """
        PriorScaledTorchMVNFactorModel
        """

        # Return samples whose prior contribution is scaled down to
        # (effectively) nothing
        with pyro.poutine.scale(scale=1e-10):
            x = pyro.sample("x", dist.Normal(0, 1))
            y = pyro.sample("y", dist.Normal(0, 1))

        # Add the MVN log prob as a factor to the model's log-density
        pyro.factor(
            "custom_likelihood",
            specific_dist_based_MVN_function_for_factor(x=x, y=y),
        )

    return NewNormalPriorScaledTorchMVNFactorModel

where dist_based_MVN_function_for_factor is

def dist_based_MVN_function_for_factor(mvn_configuration, x, y):
    """
    Function that returns the log prob of a standard MVN
    """

    # Construct array of the arguments
    symbol_array = torch.zeros(2)
    symbol_array[0] = x
    symbol_array[1] = y

    # unpack the configured mean and covariance
    mean = mvn_configuration["mean"].astype(np.float32)
    covariance = mvn_configuration["covariance"].astype(np.float32)

    # build the torch MVN distribution
    mvn = dist.MultivariateNormal(
        loc=torch.tensor(mean),
        covariance_matrix=torch.tensor(covariance),
    )

    # evaluate its log prob at the current parameter values
    logpdf = mvn.log_prob(symbol_array)

    print("Parameters: {} logpdf: {}".format(symbol_array, logpdf))

    return logpdf

Using this model in my NUTS + Neutra run results in the following:

Training the map:

[utility_functions.py:107 -            fit_guide ] 2023-03-14 14:12:01,480: [0] Elbo loss = 405.67
Parameters: tensor([-0.3196,  0.7989], grad_fn=<CopySlices>) logpdf: -404.1603698730469
Parameters: tensor([-0.2514,  0.1700], grad_fn=<CopySlices>) logpdf: -405.7932434082031
Parameters: tensor([-0.3616,  0.4539], grad_fn=<CopySlices>) logpdf: -406.0044250488281
Parameters: tensor([-0.2574,  0.6932], grad_fn=<CopySlices>) logpdf: -403.81744384765625
Parameters: tensor([-0.3608,  0.9890], grad_fn=<CopySlices>) logpdf: -403.9312744140625
...
Parameters: tensor([0.0622, 0.7065], grad_fn=<CopySlices>) logpdf: -399.92596435546875
Parameters: tensor([-0.0417,  1.4829], grad_fn=<CopySlices>) logpdf: -398.2359313964844
Parameters: tensor([-0.0165,  0.3430], grad_fn=<CopySlices>) logpdf: -402.28497314453125
Parameters: tensor([0.0985, 1.2362], grad_fn=<CopySlices>) logpdf: -397.473876953125
Parameters: tensor([-0.0056,  1.5095], grad_fn=<CopySlices>) logpdf: -397.704345703125
Parameters: tensor([-0.0491,  0.2260], grad_fn=<CopySlices>) logpdf: -403.13812255859375
Parameters: tensor([-0.0275,  1.5524], grad_fn=<CopySlices>) logpdf: -397.80914306640625
Parameters: tensor([0.1252, 1.6779], grad_fn=<CopySlices>) logpdf: -395.51715087890625
...
Parameters: tensor([21.4227, 16.8705], grad_fn=<CopySlices>) logpdf: -153.2473907470703
Parameters: tensor([21.4516, 21.5541], grad_fn=<CopySlices>) logpdf: -152.28660583496094
Parameters: tensor([21.4737, 21.5994], grad_fn=<CopySlices>) logpdf: -152.13084411621094
[utility_functions.py:107 -            fit_guide ] 2023-03-14 14:13:02,443: [3300] Elbo loss = 152.81
Parameters: tensor([21.5125, 20.3793], grad_fn=<CopySlices>) logpdf: -151.5901336669922
Parameters: tensor([21.2324, 22.0906], grad_fn=<CopySlices>) logpdf: -154.17689514160156
...
Parameters: tensor([52.8258, 17.3588], grad_fn=<CopySlices>) logpdf: -9.291794776916504
Parameters: tensor([52.8513, 23.6773], grad_fn=<CopySlices>) logpdf: -9.910046577453613
Parameters: tensor([52.5936, 21.5193], grad_fn=<CopySlices>) logpdf: -9.163655281066895
Parameters: tensor([52.8088, 20.1649], grad_fn=<CopySlices>) logpdf: -8.621402740478516
Parameters: tensor([52.9370, 22.5628], grad_fn=<CopySlices>) logpdf: -9.092679023742676

Sampling:

[generate_samples.py:267 -     generate_samples ] 2023-03-14 14:13:14,288: Running NUTS sampler through map

Warmup:   0%|          | 0/2500 [00:00, ?it/s]
Parameters: tensor([52.8272, 19.7090]) logpdf: -8.600695610046387
Parameters: tensor([52.4673, 18.6567]) logpdf: -9.301974296569824
Parameters: tensor([52.4673, 18.6567], grad_fn=<CopySlices>) logpdf: -9.301974296569824
Parameters: tensor([52.4673, 18.6567], grad_fn=<CopySlices>) logpdf: -9.301974296569824
Parameters: tensor([52.4673, 18.6567]) logpdf: -9.301974296569824
...
Parameters: tensor([53.0709, 18.8446], grad_fn=<CopySlices>) logpdf: -8.382020950317383
Parameters: tensor([53.0714, 15.1107], grad_fn=<CopySlices>) logpdf: -10.63835620880127
Parameters: tensor([53.0720, 18.4996], grad_fn=<CopySlices>) logpdf: -8.472179412841797

Warmup:   0%|          | 4/2500 [00:00,  8.53it/s, step size=1.66e+00, acc. prob=0.582]
Parameters: tensor([53.0790, 19.4717], grad_fn=<CopySlices>) logpdf: -8.26523208618164
Parameters: tensor([53.0790, 22.5858], grad_fn=<CopySlices>) logpdf: -8.905948638916016
...
Parameters: tensor([53.0808, 20.3306], grad_fn=<CopySlices>) logpdf: -8.245712280273438
Parameters: tensor([53.0808, 17.8235], grad_fn=<CopySlices>) logpdf: -8.708516120910645
Parameters: tensor([53.0808, 17.1348], grad_fn=<CopySlices>) logpdf: -9.055747985839844
Parameters: tensor([53.0808, 17.2660], grad_fn=<CopySlices>) logpdf: -8.982248306274414
Parameters: tensor([53.0808, 19.4713], grad_fn=<CopySlices>) logpdf: -8.262737274169922

Warmup:   0%|          | 10/2500 [00:03,  2.31it/s, step size=1.74e+01, acc. prob=0.796]
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
...
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883

Warmup:   1%|▏         | 35/2500 [01:03,  4.29s/it, step size=3.82e+01, acc. prob=0.807]
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
...
Parameters: tensor([53.0808, 20.2891], grad_fn=<CopySlices>) logpdf: -8.243144035339355
Parameters: tensor([53.0808, 20.2891], grad_fn=<CopySlices>) logpdf: -8.243144035339355
Parameters: tensor([53.0808, 20.2891], grad_fn=<CopySlices>) logpdf: -8.243144035339355

Warmup:   3%|▎         | 80/2500 [05:28,  5.34s/it, step size=1.23e+01, acc. prob=0.797]
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
...
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883

Warmup:   4%|▍         | 109/2500 [08:47,  8.25s/it, step size=5.08e-02, acc. prob=0.782]
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
...
Parameters: tensor([-34.2186,   4.5586], grad_fn=<CopySlices>) logpdf: -915.0059204101562
Parameters: tensor([-34.2186,   4.5586], grad_fn=<CopySlices>) logpdf: -915.0059204101562
Parameters: tensor([-34.2186,   4.5586], grad_fn=<CopySlices>) logpdf: -915.0059204101562
Parameters: tensor([-34.2186,   4.5586], grad_fn=<CopySlices>) logpdf: -915.0059204101562

Warmup:   9%|▉         | 225/2500 [24:44,  8.01s/it, step size=9.99e-03, acc. prob=0.771]
Parameters: tensor([53.0808, 24.7315], grad_fn=<CopySlices>) logpdf: -10.473495483398438
Parameters: tensor([53.0808, 24.7315], grad_fn=<CopySlices>) logpdf: -10.473495483398438
Parameters: tensor([53.0808, 24.7315], grad_fn=<CopySlices>) logpdf: -10.473495483398438
Parameters: tensor([53.0808, 24.7315], grad_fn=<CopySlices>) logpdf: -10.473495483398438
Parameters: tensor([53.0808, 24.7315], grad_fn=<CopySlices>) logpdf: -10.473495483398438
Parameters: tensor([53.0808, 24.7315], grad_fn=<CopySlices>) logpdf: -10.473495483398438
...
Parameters: tensor([53.0808,  7.6765], grad_fn=<CopySlices>) logpdf: -23.421594619750977
Parameters: tensor([53.0808,  7.6765], grad_fn=<CopySlices>) logpdf: -23.421594619750977
Parameters: tensor([53.0808,  7.6765], grad_fn=<CopySlices>) logpdf: -23.421594619750977
Parameters: tensor([53.0808,  7.6765], grad_fn=<CopySlices>) logpdf: -23.421594619750977

Warmup:  12%|█▏        | 291/2500 [32:45,  8.13s/it, step size=4.71e-03, acc. prob=0.771]
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
Parameters: tensor([53.0808, 24.0231], grad_fn=<CopySlices>) logpdf: -9.853334426879883
...
Parameters: tensor([53.0808,  7.6765], grad_fn=<CopySlices>) logpdf: -23.421594619750977
Parameters: tensor([53.0808,  7.6765], grad_fn=<CopySlices>) logpdf: -23.421594619750977
Parameters: tensor([53.0808,  7.6765], grad_fn=<CopySlices>) logpdf: -23.421594619750977
Parameters: tensor([53.0808,  7.6765], grad_fn=<CopySlices>) logpdf: -23.421594619750977

Warmup:  18%|█▊        | 445/2500 [53:45,  8.11s/it, step size=7.19e-03, acc. prob=0.776]
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
...
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
...
Sample:  25%|██▍       | 618/2500 [1:17:33,  9.35s/it, step size=1.35e-03, acc. prob=0.869]
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
Parameters: tensor([53.0808, 16.8731], grad_fn=<CopySlices>) logpdf: -9.212555885314941
...


Again, the same behaviour: the chain is generally stuck on (at least) one parameter, making a very occasional jump to another value before quickly returning to the previous one.

Running this took about an hour, due to the very inefficient sampling.

not really sure why you’re so keen on using neutra but before you do i’d suggest you get your model to work without it. in particular you might try some of the suggestions in this tutorial. also generally make sure you’re doing everything in 64-bit precision using this utility function.

Hi Martin,

I’m doing a science project where I want to combine sampling techniques with elements from classical mechanics. I need to be able to rely on a fixed base distribution, and the Neutra AutoNormalizingFlow reparameterisation would, in principle, allow me to do that.

My eventual purpose should not really matter at this point, though. The problem I am facing now is that sampling through a map, using a model that contains a factor statement to supply the likelihood, does not seem to work well, while sampling through a map with a model without a factor statement works fine.

before you do i’d suggest you get your model to work without it.

It does, as I have mentioned already. I can use the model with the factor statement in direct NUTS sampling, and it behaves as it should:

  • The distribution that the sampling recovers matches the MVN that I want it to uncover.
  • The sampling efficiency / performance is on par with using pyro.sample("obs", dist.MultivariateNormal(...)).

in particular you might try some of the suggestions in this tutorial.

While this tutorial contains many useful suggestions, I do currently rely on the normalising flow and reparameterisation approach.

also generally make sure you’re doing everything in 64-bit precision using this utility function.

Is this available for conventional Pyro? It looks like a NumPyro function, which I am currently not using.

Are there any diagnostics that might help us get some insight into the actual issue here? Model traces? Grads during sampling?
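For instance, would tracing the reparameterised model and printing per-site log probs be a sensible starting point? Something like (a sketch; neutra_model stands for the poutine.reparam-wrapped model):

from pyro import poutine

# run the reparameterised model once under a trace and inspect each site
trace = poutine.trace(neutra_model).get_trace()
trace.compute_log_prob()
for name, site in trace.nodes.items():
    if site["type"] == "sample":
        print(name, site["value"], site["log_prob_sum"])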

oh i see you’re using pyro. in that case just make sure all your parameters and data are .double()'d and do everything in 64 bit precision that way.

afaik the only difference between factor(..., mvn.log_prob) and sample(..., mvn) might be in initialization logic. so you might try different initialization strategies including fixing the same initial value with init_to_value (example).
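e.g. something like this (a sketch; note that under the neutra reparameterisation the site to initialise is the flow's shared latent, x_shared_latent, rather than x and y themselves):

import torch

from pyro.infer import MCMC, NUTS
from pyro.infer.autoguide.initialization import init_to_value

# pin the chain's starting point so both model variants start identically
init = init_to_value(values={"x_shared_latent": torch.zeros(2)})
kernel = NUTS(neutra_model, init_strategy=init)
mcmc = MCMC(kernel, num_samples=2000, warmup_steps=500)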


Hello Martin,

I have changed my configuration and tensor creation to use .double().

I get the following error message:

Traceback (most recent call last):
  File "/home/david/projects/hmc_project/repo/hmc_project_code/tests/2d_normal_tests/NUTS_MAP_sampling_test.py", line 218, in <module>
    generate_samples(
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/sampling/generate_samples.py", line 270, in generate_samples
    nuts_through_map_results = sampler_wrapper(
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/sampling/generate_samples.py", line 37, in sampler_wrapper
    sampler_results = sampler_function(config=config, **sampler_args)
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/sampling/nuts_aahmc_sampler.py", line 132, in nuts_aahmc_sampler
    mcmc.run()
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 563, in run
    for x, chain_id in self.sampler.run(*args, **kwargs):
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 223, in run
    for sample in _gen_samples(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 144, in _gen_samples
    kernel.setup(warmup_steps, *args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/hmc.py", line 333, in setup
    self._initialize_adapter()
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/hmc.py", line 320, in _initialize_adapter
    self._adapter.reset_step_size_adaptation(self._initial_params)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/adaptation.py", line 111, in reset_step_size_adaptation
    self.step_size = self._find_reasonable_step_size(z)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/hmc.py", line 175, in _find_reasonable_step_size
    z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/ops/integrator.py", line 32, in velocity_verlet
    z_next, r_next, z_grads, potential_energy = _single_step_verlet(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/ops/integrator.py", line 54, in _single_step_verlet
    z_grads, potential_energy = potential_grad(potential_fn, z)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/ops/integrator.py", line 83, in potential_grad
    raise e
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/ops/integrator.py", line 76, in potential_grad
    potential_energy = potential_fn(z)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 278, in _potential_fn
    model_trace = poutine.trace(cond_model).get_trace(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/reparam_messenger.py", line 139, in __call__
    return self.fn(*args, **kwargs)
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/analysis/model_analysis/model_defs.py", line 648, in NewNormalPriorScaledTorchMVNFactorModel
    x = pyro.sample("x", dist.Normal(0, 1))
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/primitives.py", line 163, in sample
    apply_stack(msg)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/runtime.py", line 213, in apply_stack
    frame._process_message(msg)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 154, in _process_message
    return method(msg)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/reparam_messenger.py", line 88, in _pyro_sample
    new_msg = reparam.apply(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/reparam/neutra.py", line 101, in apply
    x_unconstrained = self.transform(z_unconstrained)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/torch/distributions/transforms.py", line 349, in __call__
    x = part(x)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/torch/distributions/transforms.py", line 155, in __call__
    y = self._call(x)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/distributions/transforms/block_autoregressive.py", line 136, in _call
    pre_activation, dy_dx = self.layers[idx](y.unsqueeze(-1))
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/distributions/transforms/block_autoregressive.py", line 283, in forward
    return (torch.matmul(w, x) + self.bias.unsqueeze(-1)).squeeze(-1), wpl
RuntimeError: expected scalar type Float but found Double
       Trace Shapes:    
        Param Sites:    
       Sample Sites:    
x_shared_latent dist | 2
               value | 2

This happens after the guide training.

Also to rule out some things:

  • Sampling without the neutra reparameterisation does not give this issue, for either my factor-based model or my dist-based model.
  • Sampling with the neutra reparameterisation does give this issue, for both my factor-based model and my dist-based model.

well this only works if you consistently make everything a double.

if you’re not sure how to do that you can also just do

torch.set_default_tensor_type(torch.DoubleTensor)

is it enough to do this at the top (after import torch) in my entry point script or does every file that imports torch need to do this?

you just need to do it once before you create any tensors. subsequently created numerical tensors will be double’d
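e.g. at the top of the entry point script (sketch):

# entry_point.py -- set the default once, before any tensors are created
import torch

torch.set_default_tensor_type(torch.DoubleTensor)

import pyro  # noqa: E402  (everything created from here on defaults to float64)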