That Stan statement corresponds to an improper truncated distribution. To complete my previous suggestion, here is the corresponding code, using a properly normalized truncated normal:
import funsor
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
class TruncatedNormal(dist.Normal):
    """Normal(loc, scale) truncated to the positive half-line (0, inf).

    Subclasses ``dist.Normal`` so it keeps Normal's ``loc``/``scale``
    parameters (which is what ``funsor.distribution.make_dist`` is told about).
    It declares a positive support and delegates sampling and density
    evaluation to numpyro's ``dist.TruncatedNormal(..., low=0.)``, so
    ``log_prob`` is renormalized over the truncated support.

    ``mean`` and ``variance`` are overridden with the closed-form
    left-truncated moments; without this they would be inherited from the
    untruncated Normal (``mean == loc``, ``variance == scale**2``).

    NOTE(review): other methods not overridden here (e.g. ``cdf``/``icdf``)
    are still inherited from the untruncated Normal.
    """

    support = dist.constraints.positive

    def _truncated(self):
        """Build the equivalent numpyro distribution truncated below at 0."""
        # Keyword arguments keep this valid whether or not `low` is the
        # first positional parameter of dist.TruncatedNormal.
        return dist.TruncatedNormal(loc=self.loc, scale=self.scale, low=0.0)

    def sample(self, key, sample_shape=()):
        """Draw samples from the truncated distribution (all strictly positive)."""
        return self._truncated().sample(key, sample_shape=sample_shape)

    def log_prob(self, value):
        """Normalized log-density of the truncated distribution at ``value``."""
        return self._truncated().log_prob(value)

    def _alpha_and_ratio(self):
        """Return the standardized cut point and the inverse Mills ratio.

        alpha = (0 - loc) / scale
        lam = phi(alpha) / (1 - Phi(alpha))

        Uses 1 - Phi(alpha) = Phi(-alpha) and evaluates the ratio in log
        space, so it stays finite when the cut is far in the tail.
        """
        from jax.scipy.special import log_ndtr
        from jax.scipy.stats import norm

        alpha = -self.loc / self.scale
        lam = jnp.exp(norm.logpdf(alpha) - log_ndtr(-alpha))
        return alpha, lam

    @property
    def mean(self):
        """Mean of the left-truncated normal: loc + scale * lam."""
        _, lam = self._alpha_and_ratio()
        return self.loc + self.scale * lam

    @property
    def variance(self):
        """Variance of the left-truncated normal: scale^2 * (1 + alpha*lam - lam^2)."""
        alpha, lam = self._alpha_and_ratio()
        return self.scale**2 * (1.0 + alpha * lam - lam**2)
# Wrap the custom TruncatedNormal for funsor. The return value is discarded,
# so this is called only for its side effect: presumably registering the class
# so funsor-based enumeration of the discrete site "c" in `model` can handle it
# (TODO confirm against funsor's make_dist docs). `param_names` lists the
# constructor parameters inherited from dist.Normal.
funsor.distribution.make_dist(TruncatedNormal, param_names=("loc", "scale"))
def model():
    """Toy model: a parallel-enumerated Bernoulli site "c" and a positive-truncated normal "x".

    "c" does not feed into "x". It presumably exists only to exercise parallel
    enumeration alongside the custom TruncatedNormal distribution.
    """
    coin = dist.Bernoulli(probs=0.5)
    numpyro.sample("c", coin, infer={"enumerate": "parallel"})
    x_prior = TruncatedNormal(loc=2, scale=3)
    numpyro.sample("x", x_prior)
# Fit `model` with NUTS (100 warm-up steps, 10,000 kept draws, seed 0), then
# plot a histogram of the posterior draws of "x".
mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=10000)
mcmc.run(jax.random.PRNGKey(0))
posterior = mcmc.get_samples()
x = posterior["x"]
_ = plt.hist(x)  # binding the result suppresses notebook echo, like a trailing `;`
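You can also define an improper (unnormalized) distribution if you want. MCMC will give similar results, because the missing normalizing constant depends only on `loc` and `scale`, which are fixed here: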
class ImproperTruncatedNormal(dist.Normal):
    """Normal(loc, scale) restricted to the positive reals *without* renormalizing.

    Only ``support`` is overridden. ``log_prob`` is the untruncated Normal
    log-density, so it lacks the ``-log Phi(loc / scale)`` term of a proper
    truncated normal. That term depends only on ``loc`` and ``scale``, so it is
    a constant when they are fixed, and MCMC over the value targets the same
    distribution.

    NOTE(review): ``sample`` (and moments such as ``mean``) are also inherited
    from the untruncated Normal, so direct/prior-predictive draws can be
    negative, i.e. outside the declared support.
    """
    # Declaring a positive support is what restricts the sampled value to (0, inf).
    support = dist.constraints.positive
# Wrap the improper variant for funsor, exactly as for TruncatedNormal. The
# return value is discarded, so this is called for its side effect: presumably
# registration with funsor so enumeration in `model` can handle it (TODO confirm).
funsor.distribution.make_dist(ImproperTruncatedNormal, param_names=("loc", "scale"))
def model():
    """Same toy model as before, but "x" uses the unnormalized ImproperTruncatedNormal.

    This redefinition replaces the earlier `model`. The enumerated Bernoulli
    site "c" still does not influence "x".
    """
    coin = dist.Bernoulli(probs=0.5)
    numpyro.sample("c", coin, infer={"enumerate": "parallel"})
    x_prior = ImproperTruncatedNormal(loc=2, scale=3)
    numpyro.sample("x", x_prior)