@martinjankowiak here’s the code
Imports:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)
numpyro.enable_x64()
seed = random.PRNGKey(0)
Data simulation:
n_subject = 2
n_segment = 3
a = jnp.array([
[2.4, 3.2, 5.4],
[2.53, 3.5, 5.9]
])
b = jnp.array([
[3.4, 3.9, 3.1],
[4.1, 3.99, 4.2]
])
n_points = 40
x = jnp.linspace(0, 10, n_points)
df = None
for i in range(n_subject):
for j in range(n_segment):
subject = jnp.repeat(i, n_points)
segment = jnp.repeat(j, n_points)
if not i and j:
mean = jnp.minimum(jnp.maximum(
0, b[i, j] * (x - a[i, j])
), 9)
else:
mean = jax.nn.relu(
b[i, j] * (x - a[i, j])
)
y = dist.TruncatedNormal(mean, .5, low=0).sample(seed, (1,))
if df is None:
df = pd.DataFrame(
jnp.array([subject, segment, x, y.reshape(-1,)]).T,
columns=['subject', 'segment', 'x', 'y']
)
else:
temp = pd.DataFrame(
jnp.array([subject, segment, x, y.reshape(-1,)]).T,
columns=['subject', 'segment', 'x', 'y']
)
df = pd.concat([df, temp], ignore_index=True).copy()
df.subject = df.subject.astype(int)
df.segment = df.segment.astype(int)
Plot simulated data:
fig, ax = plt.subplots(n_subject, n_segment, figsize=(12,6), constrained_layout=True)
for i in range(n_subject):
for j in range(n_segment):
sns.scatterplot(x='x', y='y', ax=ax[i][j], data=df[(df.subject == i) & (df.segment == j)])
ax[i][j].set_ylabel('Y')
ax[i][j].set_xlabel('X')
ax[i][j].set_title(f'Subject {i}, Segment {j}')
Looks like this:
As you can see, the data only saturates when subject = 0 and segment = 1, 2. So I want to fit from two parametric families: one being Relu and other being Relu that saturates (kind of like sigmoid)
Curve fitting with parametric family set to: Relu that saturates
def model(x, subject, segment, y_obs=None):
n_subject = np.unique(subject).shape[0]
n_segment = np.unique(segment).shape[0]
with numpyro.plate("n_segment", n_segment, dim=-1):
with numpyro.plate("n_subject", n_subject, dim=-2):
a = numpyro.sample('a', dist.Normal(3, 1))
b = numpyro.sample('b', dist.Normal(3, 1))
c = numpyro.sample('c', dist.Normal(9, 1))
mean = jnp.minimum(jnp.maximum(
0, b[subject, segment] * (x - a[subject, segment])
), c[subject, segment])
with numpyro.plate("data", len(x)):
return numpyro.sample("obs", dist.TruncatedNormal(mean, .5, low=0), obs=y_obs)
x = df.x.to_numpy().reshape(-1,)
subject = df.subject.to_numpy().reshape(-1,)
segment = df.segment.to_numpy().reshape(-1,)
y = df.y.to_numpy().reshape(-1,)
# MCMC
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_chains=4, num_warmup=2000, num_samples=4000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, x, subject, segment, y)
posterior_samples = mcmc.get_samples()
Plot results
a = posterior_samples['a'].mean(axis=0)
b = posterior_samples['b'].mean(axis=0)
c = posterior_samples['c'].mean(axis=0)
fig, ax = plt.subplots(n_subject, n_segment, figsize=(12,6), constrained_layout=True)
for i in range(n_subject):
for j in range(n_segment):
sns.scatterplot(x='x', y='y', ax=ax[i][j], data=df[(df.subject == i) & (df.segment == j)])
ax[i][j].set_ylabel('Y')
ax[i][j].set_xlabel('X')
ax[i][j].set_title(f'Subject {i}, Segment {j}')
mean = jnp.minimum(jnp.maximum(
0, b[i, j] * (x - a[i, j])
), c[i, j])
sns.lineplot(x=x, y=mean, ax=ax[i][j], color = 'red')
which looks like this
Now, I will change my model to fit from both the parametric families
def model(x, subject, segment, y_obs=None):
n_subject = np.unique(subject).shape[0]
n_segment = np.unique(segment).shape[0]
with numpyro.plate("n_segment", n_segment, dim=-1):
with numpyro.plate("n_subject", n_subject, dim=-2):
a = numpyro.sample('a', dist.Normal(3, 1))
b = numpyro.sample('b', dist.Normal(3, 1))
c = numpyro.sample('c', dist.Normal(9, 1))
p = numpyro.sample('p', dist.Uniform(0, 1))
q = numpyro.sample('q', dist.Bernoulli(probs=p))
mean = \
q[subject, segment] * jnp.minimum(jnp.maximum(
0, b[subject, segment] * (x - a[subject, segment])
), c[subject, segment]) + \
(1 - q[subject, segment]) * jax.nn.relu(b[subject, segment] * (x - a[subject, segment]))
with numpyro.plate("data", len(x)):
return numpyro.sample("obs", dist.TruncatedNormal(mean, .5, low=0), obs=y_obs)
# MCMC
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_chains=4, num_warmup=2000, num_samples=4000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, x, subject, segment, y)
posterior_samples = mcmc.get_samples()
Gives the error
ValueError: Expected the joint log density is a scalar, but got (240,). There seems to be something wrong at the following sites: {'_pyro_dim_2'}.