What is the difference between the following two sampling methods?

Hi devs and community,

What is the difference between the following two sampling techniques used in a NumPyro model?

Suppose that b_mean and b_scale are vectors of length 2, i.e. of shape (2,):

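import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Option 1: multivariate normal with a diagonal covariance matrix
b = numpyro.sample(
    "b",
    dist.MultivariateNormal(b_mean, jnp.diag(b_scale)),
    sample_shape=(10,)
)

# Option 2: independent univariate normals
b = numpyro.sample(
    "b",
    dist.Normal(b_mean, b_scale),
    sample_shape=(10,)
)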

Both result in b being a 10 x 2 array. However, I get slightly different results with the two approaches, and the number of divergences decreases a lot when using the MultivariateNormal.

Could you please explain the reason for this?


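They should be approximately the same if you use dist.MultivariateNormal(b_mean, jnp.diag(b_scale ** 2)) (though the MVN will be slower). Note that the scale parameter of a Normal distribution is a standard deviation (the square root of the variance), whereas the second argument of MultivariateNormal is a covariance matrix, so its diagonal must hold variances.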

Also note that it’s not recommended to pass sample_shape to sample statements; none of the examples or documentation do this. Instead, use plate, or expand the distribution (or its parameters) to the shape you need.


Hi @martinjankowiak, thanks for clarifying that. I found sample_shape listed as a parameter in NumPyro’s Primitives documentation, and since PyMC3 has a similar parameter called shape, it felt intuitive to use it. Is there a reason it is discouraged in favor of expand?

@mathlad I’m not entirely sure whether it’ll give the correct behavior in all cases.

@fehiepsi would know better.

In any case, it’s unidiomatic (at least according to our standard usage), so it might be confusing to anyone reading it. Kind of like how import torch as numpy would work but would be confusing ;)


Currently, keyword arguments given to the numpyro.sample primitive (such as sample_shape) are simply passed through to the Distribution.sample method. Using plate is better because many inference algorithms use the information it provides about which dimensions are conditionally independent.

sample_shape is different from expand. The former adds extra dimensions on the left; the latter behaves more like broadcasting. Using dist.Normal(0, 1) with sample_shape=(3, 4) is equivalent to dist.Normal(0, 1).expand([3, 4]), or to dist.Normal(0, 1).expand([4]) with sample_shape=(3,). We typically use expand because it makes the shape of the distribution easier to see. Also, sample_shape might not work well with plate.


Thanks @fehiepsi

I want to define the following hierarchical model:

$$
\begin{aligned}
b[i, j] &\sim \mathcal{N}\big(b_\text{mean}[j],\ b_\text{scale}[j]\big), && \forall j \in \{0, 1\},\ \forall i \in \{0, 1, 2\} \\
b_\text{mean}[j] &\sim \mathcal{N}(0, 1), && \forall j \in \{0, 1\} \\
b_\text{scale}[j] &\sim \mathrm{HN}(2), && \forall j \in \{0, 1\}
\end{aligned}
$$
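For reference, the joint log-density implied by this specification, writing $\mathcal{N}(x; \mu, \sigma)$ for a normal density with mean $\mu$ and standard deviation $\sigma$ (as in NumPyro's Normal(loc, scale)) and $\mathrm{HN}(x; 2)$ for a half-normal density with scale 2, is:

$$
\log p(b, b_\text{mean}, b_\text{scale}) = \sum_{j=0}^{1} \Big[ \log \mathcal{N}\big(b_\text{mean}[j]; 0, 1\big) + \log \mathrm{HN}\big(b_\text{scale}[j]; 2\big) \Big] + \sum_{i=0}^{2} \sum_{j=0}^{1} \log \mathcal{N}\big(b[i, j]; b_\text{mean}[j], b_\text{scale}[j]\big)
$$

Roughly speaking, any implementation that produces exactly these 4 + 6 terms defines the same posterior, whichever shape mechanism it uses.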

Currently, I use the following code in my numpyro model:

b_mean = numpyro.sample('b_mean', dist.Normal(0, 1), sample_shape=(2,))
b_scale = numpyro.sample('b_scale', dist.HalfNormal(2), sample_shape=(2,))
b = numpyro.sample('b', dist.Normal(b_mean, b_scale), sample_shape=(3,))

Here, b_mean, b_scale, and b get the shapes (2,), (2,), and (3, 2), which satisfies my requirements.

But it seems from the replies that this is wrong or does not work as intended?

And instead, I should do it like this?

b_mean = numpyro.sample('b_mean', dist.Normal(0, 1).expand((2,)))
b_scale = numpyro.sample('b_scale', dist.HalfNormal(2).expand((2,)))
b = numpyro.sample('b', dist.Normal(b_mean, b_scale).expand((3,2)))

Could you please confirm whether the second code chunk is correct? Also, is there a way to see or test the difference between the two chunks of code? In my code, I get (exactly) the same estimates for these parameters with both versions, though.

@martinjankowiak I will also appreciate your input on this. Thank you both!

Also, I’m not sure what the equivalent code would be in plate notation.