How to make 2d state space model

I made 2d state space model using Numypro

but I failed in making VarianceCovariance matrix. Please give me some idea

def SSM_md2(y):
    T = y.shape[0]
    k = y.shape[1]
    #sigma_1 : 状態空間の標準偏差
    sigma_11 = numpyro.sample("sigma_11", dist.Exponential(1.))
    sigma_12 = numpyro.sample("sigma_12", dist.Exponential(1.))
    #分散共分散行列
    Sigma_1 = jnp.array([[sigma_11, 0], [0, sigma_12]])

    #print(Sigma_1)
    #sigma_2 : 観測モデルの標準偏差
    sigma_21 = numpyro.sample("sigma_21", dist.Exponential(1.))
    sigma_22 = numpyro.sample("sigma_22", dist.Exponential(1.))
    #分散共分散行列
    Sigma_2 = jnp.array([[sigma_21, 0], [0, sigma_22]])
    x0 = numpyro.sample("x0", dist.MultivariateNormal(jnp.zeros(k), jnp.eye(k)))
    #分散共分散行列

    eps = numpyro.sample("eps", dist.MultivariateNormal(jnp.zeros(k), Sigma_1), sample_shape=(T,))
    eta = numpyro.sample("eta", dist.MultivariateNormal(jnp.zeros(k), Sigma_2), sample_shape=(T,))

    B = numpyro.sample("B", dist.Normal(0., 10.), sample_shape=(k, k))
    #print(B)
    #scanの中で使う関数を定義
    def transition_fn(carry, t):
        x_prev = carry
        x_new = B @ x_prev + eps[t].reshape(k, 1)
        print(x_new.shape)
        y_ = numpyro.sample("y_t", dist.MultivariateNormal(x_new.reshape(k,), eta[t]))
        print(y_.shape)
        carry = x_new
        return carry, y_
    
    time_step = jnp.arange(T)
    init = jnp.asarray(x0).reshape(k, 1)
    with numpyro.handlers.condition(data={"y_t": y}):
        scan(transition_fn, init, time_step)

I got this error

ValueError: MultivariateNormal distribution got invalid covariance_matrix parameter.

The covariance matrix of the MultivariateNormal needs to be positive-definite. I think you need to apply the following transforms: https://github.com/pyro-ppl/numpyro/blob/065aa4352dbc9ebeb1132b15b51d8b3149848640/numpyro/distributions/transforms.py#L1261-L1263

Thank you for your reply

Could you tell me how to implement it?

You can apply it like this:

ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv])(Sigma_1)