Hello, I am looking for a function f that transforms my 2-channel softmax output into a binary vector in a differentiable manner (from a different perspective, a differentiable round operation). For example, in pseudocode:

    f([0.15, 0.85]) = [0, 1]
    f([0.6, 0.4])   = [1, 0]

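One common way to get exactly this behaviour without any sampling is the straight-through estimator: the forward pass returns the hard one-hot vector, while the backward pass treats the output as if it were the softmax itself. A minimal JAX sketch, assuming the two probabilities sit on the last axis; `straight_through_round` is a hypothetical name, not a library function:

```python
import jax
import jax.numpy as jnp

def straight_through_round(probs):
    """Hard one-hot in the forward pass, identity gradient w.r.t. `probs` in the backward pass.

    probs: array of shape (..., 2) whose last axis is a softmax output.
    """
    # Exact one-hot vector of the larger channel (argmax alone has no useful gradient).
    hard = jax.nn.one_hot(jnp.argmax(probs, axis=-1), probs.shape[-1], dtype=probs.dtype)
    # Forward value equals `hard`; stop_gradient hides (hard - probs) from autodiff,
    # so gradients flow as if the function returned `probs` unchanged.
    return probs + jax.lax.stop_gradient(hard - probs)

p = jnp.array([[0.15, 0.85], [0.6, 0.4]])
out = straight_through_round(p)  # forward value: [[0, 1], [1, 0]]

# The gradient of any downstream loss is passed straight to `probs`:
w = jnp.array([[1.0, 2.0], [3.0, 4.0]])
g = jax.grad(lambda q: jnp.sum(straight_through_round(q) * w))(p)  # equals w
```

The gradient is a biased surrogate (it ignores the rounding step), which is often acceptable in practice but worth keeping in mind.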
Next, I would calculate a loss on these binarized vectors: 1 indicates that the pixel belongs to the area of interest in the image, and I then calculate the texture properties of that area.
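To check that gradients actually reach the network through such a binarized mask, here is a small sketch of the whole pipeline for an image-shaped 2-channel output. It assumes channel 1 means "pixel is in the area" and uses the intensity variance inside the mask as a hypothetical stand-in for a texture property; the helper names and the statistic are illustrative only:

```python
import jax
import jax.numpy as jnp

def hard_mask_from_logits(logits):
    # logits: (H, W, 2); channel 1 = "pixel is in the area" (assumption)
    soft = jax.nn.softmax(logits, axis=-1)[..., 1]
    hard = (soft > 0.5).astype(soft.dtype)
    # Exact 0/1 mask forward, softmax gradient backward (straight-through)
    return soft + jax.lax.stop_gradient(hard - soft)

def masked_variance(image, mask, eps=1e-6):
    # Hypothetical "texture property": intensity variance inside the mask
    n = jnp.sum(mask) + eps
    mean = jnp.sum(mask * image) / n
    return jnp.sum(mask * (image - mean) ** 2) / n

def loss_fn(logits, image):
    return masked_variance(image, hard_mask_from_logits(logits))

key_l, key_i = jax.random.split(jax.random.PRNGKey(0))
logits = jax.random.normal(key_l, (8, 8, 2))   # stand-in for the network output
image = jax.random.uniform(key_i, (8, 8))      # stand-in for a grayscale image
mask = hard_mask_from_logits(logits)
grads = jax.grad(loss_fn)(logits, image)       # same shape as logits
```

Because the mask is exactly 0/1 in the forward pass, the statistic is computed on the hard area, while `grads` is obtained through the softmax probabilities.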
I have found that the Gumbel-softmax could probably help me. In Pyro, RelaxedBernoulli seems even better, and RelaxedBernoulliStraightThrough seems perfect.
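For two categories the Gumbel-softmax reduces to a relaxed Bernoulli (binary Concrete) distribution, so these options are closely related. A minimal plain-JAX sketch of the Gumbel-softmax over the 2 channels, with an optional straight-through `hard` mode; `gumbel_softmax` is a hypothetical helper and the temperature of 0.5 is an arbitrary choice:

```python
import jax
import jax.numpy as jnp

def gumbel_softmax(key, logits, temperature=0.5, hard=False):
    """Gumbel-softmax sample over the last axis of `logits`.

    With hard=True the forward value is one-hot, while gradients come from the
    relaxed sample (straight-through Gumbel-softmax).
    """
    gumbels = jax.random.gumbel(key, logits.shape, dtype=logits.dtype)
    y_soft = jax.nn.softmax((logits + gumbels) / temperature, axis=-1)
    if not hard:
        return y_soft
    y_hard = jax.nn.one_hot(jnp.argmax(y_soft, axis=-1), logits.shape[-1], dtype=y_soft.dtype)
    return y_soft + jax.lax.stop_gradient(y_hard - y_soft)

probs = jnp.array([0.15, 0.85])
key = jax.random.PRNGKey(0)
sample = gumbel_softmax(key, jnp.log(probs), hard=True)  # a random one-hot vector

# Gradient of the second component with respect to the logits
g = jax.grad(lambda l: gumbel_softmax(key, l, hard=True)[1])(jnp.log(probs))
```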
However, first of all, I do not know whether sampling from, for example, RelaxedBernoulli is differentiable out of the box or requires some special treatment.
I am using NumPyro, as the main part of the model is in JAX.
Here is an example in code, using a single number (0.9) for simplicity:

    import jax.numpy as jnp
    from jax import random
    import numpyro.distributions as dist

    def differentiable_round(number, key):
        return dist.RelaxedBernoulli(temperature=0.05, probs=jnp.array([number])).sample(key, (1,))

    differentiable_round(0.9, random.PRNGKey(2))  # gives 0.9999999

Thanks for your help!