I made 2d state space model using Numypro
but I failed in making VarianceCovariance matrix. Please give me some idea
def SSM_md2(y):
T = y.shape[0]
k = y.shape[1]
#sigma_1 : 状態空間の標準偏差
sigma_11 = numpyro.sample("sigma_11", dist.Exponential(1.))
sigma_12 = numpyro.sample("sigma_12", dist.Exponential(1.))
#分散共分散行列
Sigma_1 = jnp.array([[sigma_11, 0], [0, sigma_12]])
#print(Sigma_1)
#sigma_2 : 観測モデルの標準偏差
sigma_21 = numpyro.sample("sigma_21", dist.Exponential(1.))
sigma_22 = numpyro.sample("sigma_22", dist.Exponential(1.))
#分散共分散行列
Sigma_2 = jnp.array([[sigma_21, 0], [0, sigma_22]])
x0 = numpyro.sample("x0", dist.MultivariateNormal(jnp.zeros(k), jnp.eye(k)))
#分散共分散行列
eps = numpyro.sample("eps", dist.MultivariateNormal(jnp.zeros(k), Sigma_1), sample_shape=(T,))
eta = numpyro.sample("eta", dist.MultivariateNormal(jnp.zeros(k), Sigma_2), sample_shape=(T,))
B = numpyro.sample("B", dist.Normal(0., 10.), sample_shape=(k, k))
#print(B)
#scanの中で使う関数を定義
def transition_fn(carry, t):
x_prev = carry
x_new = B @ x_prev + eps[t].reshape(k, 1)
print(x_new.shape)
y_ = numpyro.sample("y_t", dist.MultivariateNormal(x_new.reshape(k,), eta[t]))
print(y_.shape)
carry = x_new
return carry, y_
time_step = jnp.arange(T)
init = jnp.asarray(x0).reshape(k, 1)
with numpyro.handlers.condition(data={"y_t": y}):
scan(transition_fn, init, time_step)
I got this error
ValueError: MultivariateNormal distribution got invalid covariance_matrix parameter.