Binary differentiable sampling - Gumbel?

Hello, I am looking for a function f that transforms my 2-channel softmax output into a binary vector in a differentiable manner (from a different perspective, a differentiable round operation). For example, in pseudocode:

f([0.15, 0.85]) = [0, 1]
f([0.6, 0.4]) = [1, 0]
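
A minimal sketch of the deterministic "differentiable round" described above, assuming the straight-through estimator is acceptable: the forward pass returns the exact one-hot vector, while the backward pass uses the gradient of the softmax probabilities as if no rounding had happened. `straight_through_round` is a hypothetical helper name; `jax.lax.stop_gradient` is the standard JAX way to hide a term from autodiff.

```python
import jax
import jax.numpy as jnp


def straight_through_round(probs):
    """Hypothetical helper. Forward: one-hot of the argmax channel. Backward: gradient w.r.t. `probs`."""
    hard = jax.nn.one_hot(jnp.argmax(probs, axis=-1), probs.shape[-1], dtype=probs.dtype)
    # (hard - probs) is treated as a constant, so the output equals `hard`
    # numerically, but its gradient is the identity on `probs`.
    return probs + jax.lax.stop_gradient(hard - probs)


print(straight_through_round(jnp.array([0.15, 0.85])))  # [0. 1.]
print(straight_through_round(jnp.array([0.6, 0.4])))    # [1. 0.]
```

The gradient is biased (it pretends the rounding is the identity), but in practice this is the usual trade-off for getting exact 0/1 values in the forward pass.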

Next, I would calculate a loss on these binarized vectors (1 indicates that the pixel is in the selected area of the image, and I then calculate the texture properties of this area).
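The texture properties are not specified, so the sketch below uses a hypothetical placeholder statistic (the intensity variance inside the mask, compared against a hypothetical `target_var`). It only illustrates that, as long as the loss is a smooth function of the mask values, `jax.grad` reaches the mask; a straight-through or relaxed mask then carries that gradient back to the network. A plain `jnp.round` in between would not, because JAX defines its gradient as zero.

```python
import jax
import jax.numpy as jnp


def masked_texture_loss(mask, image, target_var):
    """Hypothetical texture loss: squared error between the intensity variance
    inside the masked area and a target variance (placeholder statistic)."""
    eps = 1e-6  # avoids division by zero for an empty mask
    area = jnp.sum(mask) + eps
    mean = jnp.sum(mask * image) / area
    var = jnp.sum(mask * (image - mean) ** 2) / area
    return (var - target_var) ** 2


image = jnp.array([[0.1, 0.9], [0.4, 0.6]])
mask = jnp.array([[1.0, 0.0], [1.0, 1.0]])  # e.g. channel 1 of the binarized vectors
grad_mask = jax.grad(masked_texture_loss)(mask, image, 0.05)
print(grad_mask.shape)  # (2, 2)
```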

I have found that something that could probably help me is the Gumbel-softmax. In Pyro, RelaxedBernoulli seems even better, and RelaxedBernoulliStraightThrough seems perfect.
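For the 2-channel case, a Gumbel-softmax can be sketched directly in JAX (this is a hand-written version, not NumPyro's API; `gumbel_softmax` and its `hard` flag are hypothetical names). With `hard=True` it applies the same straight-through trick as Pyro's straight-through distributions: exact one-hot samples forward, relaxed gradients backward. Passing the log of the softmax probabilities is equivalent to passing the pre-softmax logits, since the two differ only by a per-pixel constant.

```python
import jax
import jax.numpy as jnp


def gumbel_softmax(key, logits, temperature=0.5, hard=False):
    """Gumbel-softmax (Concrete) sample over the last axis of `logits`."""
    gumbel = jax.random.gumbel(key, logits.shape, dtype=logits.dtype)
    y_soft = jax.nn.softmax((logits + gumbel) / temperature, axis=-1)
    if not hard:
        return y_soft
    # Straight-through: exact one-hot forward, gradient of y_soft backward.
    y_hard = jax.nn.one_hot(jnp.argmax(y_soft, axis=-1), logits.shape[-1], dtype=logits.dtype)
    return y_soft + jax.lax.stop_gradient(y_hard - y_soft)


logits = jnp.log(jnp.array([[0.15, 0.85], [0.6, 0.4]]))  # log of the 2-channel softmax
key = jax.random.PRNGKey(0)
print(gumbel_softmax(key, logits, hard=True))  # rows are one-hot; which channel is random
```

Unlike the deterministic round, each row is a random sample: channel 1 of the first row comes out as 1 with probability about 0.85.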

However, first of all, I do not know whether sampling from, for example, RelaxedBernoulli is immediately differentiable or whether it requires some special treatment.
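A sketch of why a relaxed Bernoulli sample can be differentiated directly (the reparameterization trick): for a fixed PRNG key, the sample is a deterministic, smooth function of `probs`, namely sigmoid((logit(p) + L) / T) with logistic noise L. As far as I can tell, NumPyro's RelaxedBernoulli is built the same way (logistic base distribution plus a sigmoid transform), so plain `jax.grad` should work on it without special treatment. The function below is a hypothetical hand-written equivalent.

```python
import jax
import jax.numpy as jnp


def relaxed_bernoulli_sample(key, probs, temperature):
    """Binary Concrete sample: sigmoid((logit(p) + L) / T) with L ~ Logistic(0, 1)."""
    # All randomness comes from the key; the sample is a smooth function of probs.
    noise = jax.random.logistic(key, jnp.shape(probs))
    logits = jnp.log(probs) - jnp.log1p(-probs)
    return jax.nn.sigmoid((logits + noise) / temperature)


key = jax.random.PRNGKey(2)
grad_fn = jax.grad(lambda p: relaxed_bernoulli_sample(key, p, 0.5))
print(grad_fn(0.9))  # a finite, non-negative number
```

One caveat: at a very low temperature such as 0.05 the sigmoid saturates, so samples sit very close to 0 or 1 and the gradient becomes tiny. A moderate temperature, or annealing it during training, is a common compromise.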

I am using NumPyro, as the main part of the model is in JAX. So, in code, using just a single number (0.9) for simplicity:

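import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

def differentiable_round(number, key):
    # A low temperature pushes samples toward 0 or 1, but they stay continuous in (0, 1)
    return dist.RelaxedBernoulli(temperature=0.05, probs=jnp.array([number])).sample(key, (1,))

differentiable_round(0.9, random.PRNGKey(2))
# gives 0.9999999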

Thanks for the help!