Trouble Initializing Parameters for SVI

Like many others, I have run into “RuntimeError: Cannot find valid initial parameters. Please check your model again.”, and I’m trying to find out exactly which parameters are “invalid”. If I print the values of the parameters initialized by the “init_to_median” strategy, they are all reasonable and within the expected bounds. I’ve also confirmed that the error is not thrown when I fix one of the parameters to a constant value instead of putting a prior distribution on it. I suspect this is because the function below may not be differentiable with respect to that parameter. Is there a clean way to check where exactly this error is thrown, and for which parameters?
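NumPyro only accepts initial parameters if both the log density and its gradients are finite, so reasonable-looking parameter values can still fail when a gradient is NaN. One way to narrow it down is to evaluate the gradient per parameter at the values you printed. The sketch below assumes you can express the model's log density as a function of a dict of parameters (for example by wrapping `numpyro.infer.util.log_density(model, model_args, model_kwargs, params)` and taking its first output); `nonfinite_grad_sites` and `toy_log_density` are hypothetical names, and the toy function only imitates the pattern in this post.

```python
import jax
import jax.numpy as jnp


def nonfinite_grad_sites(log_density_fn, params):
    """Evaluate log_density_fn and its gradient at `params` (a dict of arrays).

    Returns (value_is_finite, names of parameters whose gradient contains NaN/inf).
    """
    value, grads = jax.value_and_grad(log_density_fn)(params)
    bad = [name for name, g in grads.items() if not bool(jnp.all(jnp.isfinite(g)))]
    return bool(jnp.isfinite(value)), bad


# Hypothetical toy log density: a 0/0 entry is masked out of the forward value,
# but it still poisons the gradient with respect to "scale".
def toy_log_density(params):
    d = jnp.array([0.0, 1.0, 2.0])
    ratio = jnp.sin(params["scale"] * d) / d           # entry 0 is 0/0 = NaN
    ratio = jnp.where(d == 0, params["scale"], ratio)  # forward value is now finite
    return -jnp.sum(ratio**2) - params["loc"] ** 2


params = {"loc": jnp.array(0.5), "scale": jnp.array(1.0)}
print(nonfinite_grad_sites(toy_log_density, params))  # (True, ['scale'])
```

A finite value with a non-finite gradient, as here, points at a differentiation problem rather than at a bad parameter value.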

```python
import jax
import jax.numpy as jnp
from numpyro.contrib.hsgp.laplacian import sqrt_eigenvalues  # assumed source of this helper


def make_convex_phi_prime(x, L, M=1):
    # only supports a single input dimension (single-dimension concavity)
    assert len(x.shape) == 1, f"shape was {x.shape}, expected len 1"
    # reshape instead of squeeze so that M = 1 still gives a 1-D array of shape (M,)
    eig_vals = jnp.reshape(sqrt_eigenvalues(L, M, 1), (-1,))
    # outer difference / sum: result[i, j] = eig_vals[i] -/+ eig_vals[j]
    broadcast_sub = jax.vmap(jax.vmap(jnp.subtract, (None, 0)), (0, None))
    broadcast_add = jax.vmap(jax.vmap(jnp.add, (None, 0)), (0, None))
    sum_eig_vals = broadcast_add(eig_vals, eig_vals)
    diff_eig_vals = broadcast_sub(eig_vals, eig_vals)  # zero on the diagonal
    x_shifted = x + L
    sin_pos = jnp.sin(jnp.einsum("t,m... -> tm...", x_shifted, sum_eig_vals)) / (2 * L * sum_eig_vals)
    # diagonal entries are sin(0) / 0 = NaN because diff_eig_vals is zero there
    sin_neg = jnp.sin(jnp.einsum("t,m... -> tm...", x_shifted, diff_eig_vals)) / (2 * L * diff_eig_vals)
    diagonal_elements = (x[..., None] - L) / (2 * L) - jnp.diagonal(sin_pos, axis1=1, axis2=2)
    other_elements = sin_neg - sin_pos
    # for each x[t], write diagonal_elements[t] onto the diagonal of other_elements[t]
    broadcast_fill_diag = jax.vmap(lambda mat, diag: jnp.fill_diagonal(mat, diag, inplace=False), in_axes=0)
    phi_prime = broadcast_fill_diag(other_elements, diagonal_elements)
    return phi_prime  # shape (t, m, m): t = len(x), m = number of eigenvalues (M)
```
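To locate the exact operation, JAX's `jax_debug_nans` config option raises a `FloatingPointError` at the first primitive that produces a NaN, even if that NaN is later overwritten (as the diagonal of `sin_neg` is by `fill_diagonal`). A minimal sketch on a standalone copy of the `sin_neg` computation; `sin_neg_term` and the eigenvalues used are hypothetical stand-ins, not the post's helpers. Because the flag also reports NaNs that turn out to be harmless, it is best used as a temporary debugging switch.

```python
import jax
import jax.numpy as jnp


def sin_neg_term(x, L, eig_vals):
    """Same computation as sin_neg above, using broadcasting instead of einsum."""
    diff = eig_vals[:, None] - eig_vals[None, :]  # zero on the diagonal
    return jnp.sin(x[:, None, None] * diff) / (2 * L * diff)


x = jnp.linspace(-0.5, 0.5, 4)
eig_vals = jnp.arange(1, 4) * jnp.pi / 2.0  # hypothetical eigenvalues, M = 3

# Without the flag, the NaNs appear silently on every diagonal.
out = sin_neg_term(x, 1.0, eig_vals)
print(bool(jnp.isnan(out).any()))  # True

# With the flag, JAX raises at the operation that created the first NaN (the division).
raised = False
jax.config.update("jax_debug_nans", True)
try:
    sin_neg_term(x, 1.0, eig_vals)
except FloatingPointError as err:
    raised = True
    print("caught:", err)
finally:
    jax.config.update("jax_debug_nans", False)  # restore the default
```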

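The issue was that this function was not differentiable in practice. diff_eig_vals is zero on its diagonal, so the diagonal of sin_neg is 0/0 = NaN. fill_diagonal overwrites those entries in the forward pass, but backpropagating through the division still produces NaN gradients, and NumPyro treats initial parameters as valid only if both the log density and its gradients are finite. The fix was to follow the guidance in “Best practice to handle division by zero in auto-differentiation” (jax-ml/jax issue #5039 on GitHub).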
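A minimal sketch of that kind of fix, applying the “double where” pattern to diff_eig_vals: the zero denominators are replaced before dividing, so no 0/0 reaches either the forward or the backward pass. The post does not show the author's exact change; this version assumes the 1-D HSGP eigenvalues j·π/(2L) for j = 1..M as a hypothetical stand-in for `sqrt_eigenvalues(L, M, 1)`, replaces the vmap/einsum helpers with plain broadcasting, and keeps a `safe=False` branch only for comparison.

```python
import jax
import jax.numpy as jnp


def phi_prime(x, L, M=3, safe=True):
    """Sketch of make_convex_phi_prime with an optional double-where guard.

    Assumes eigenvalues j * pi / (2 * L), j = 1..M (hypothetical stand-in).
    """
    eig_vals = jnp.arange(1, M + 1) * jnp.pi / (2 * L)
    sum_eig_vals = eig_vals[:, None] + eig_vals[None, :]
    diff_eig_vals = eig_vals[:, None] - eig_vals[None, :]  # zero on the diagonal
    x_shifted = x[:, None, None] + L                        # shape (t, 1, 1)

    if safe:
        # Double-where: swap zero denominators for 1 *before* dividing, then mask
        # the result, so the division never sees 0/0 in either pass.
        is_zero = diff_eig_vals == 0
        denom = jnp.where(is_zero, 1.0, diff_eig_vals)
        sin_neg = jnp.where(is_zero, 0.0, jnp.sin(x_shifted * denom) / (2 * L * denom))
    else:
        # Original form: diagonal entries are 0/0 = NaN.
        sin_neg = jnp.sin(x_shifted * diff_eig_vals) / (2 * L * diff_eig_vals)

    sin_pos = jnp.sin(x_shifted * sum_eig_vals) / (2 * L * sum_eig_vals)
    diagonal_elements = (x[:, None] - L) / (2 * L) - jnp.diagonal(sin_pos, axis1=1, axis2=2)
    fill = jax.vmap(lambda mat, diag: jnp.fill_diagonal(mat, diag, inplace=False))
    return fill(sin_neg - sin_pos, diagonal_elements)  # shape (t, M, M)


x = jnp.linspace(-0.9, 0.9, 5)
grad_naive = jax.grad(lambda L: jnp.sum(phi_prime(x, L, safe=False)))(1.0)
grad_safe = jax.grad(lambda L: jnp.sum(phi_prime(x, L, safe=True)))(1.0)
print(grad_naive, grad_safe)  # NaN for the original form, finite for the guarded one
```

Both versions return the same values, since fill_diagonal overwrites the diagonal either way; only the gradient changes.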