def semisupervised_loss_and_grads(model, guide, *args, **kwargs):
    """Compute a semi-supervised ELBO and backprop surrogate gradients.

    Draws ``NUM_PARTICLE`` particles from the guide, replays the model
    against each guide trace, and accumulates a supervised ELBO term
    (sites whose names start with ``'sup'``) and an unsupervised term
    (sites starting with ``'unsup'``).  Gradients for the model
    parameters (theta) and guide parameters (phi) are produced by
    calling ``.backward()`` on two hand-built surrogate objectives;
    the negative ELBO is returned as the loss value.

    NOTE(review): relies on module-level globals ``ALPHA``,
    ``NUM_PARTICLE``, ``unsupervised_dataloader`` and on ``poutine``
    (pyro.poutine) — none of which are defined in this chunk; confirm
    they exist at call time.

    :param model: Pyro model callable.
    :param guide: Pyro guide callable (same signature as ``model``).
    :returns: ``-elbo`` — the scalar loss (negative scaled ELBO).
    """
    alpha = ALPHA['alpha']
    print(alpha)  # debug output — consider removing or using logging
    num_particle = NUM_PARTICLE
    elbo_u, elbo_s = 0., 0.
    guide_traces, model_traces = [], []
    for _ in range(num_particle):
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
        guide_traces.append(guide_trace)
        model_traces.append(model_trace)
        # Split the model log-density by site-name prefix into
        # supervised ('sup*') and unsupervised ('unsup*') parts.
        model_sup_particle = model_trace.log_prob_sum(
            lambda name, site: site['type'] == 'sample' and name.startswith('sup'))
        model_unsup_particle = model_trace.log_prob_sum(
            lambda name, site: site['type'] == 'sample' and name.startswith('unsup'))
        guide_particle = guide_trace.log_prob_sum()
        # Supervised term carries no guide (entropy) contribution in
        # this objective; only the unsupervised term subtracts log q.
        elbo_u += (model_unsup_particle - guide_particle)
        elbo_s += model_sup_particle
    # Scale by N/M (number of minibatches in the unsupervised loader)
    # so the minibatch ELBO estimates the full-data ELBO.
    batch_constant = len(unsupervised_dataloader)
    elbo = batch_constant / num_particle * (elbo_u + alpha * elbo_s)

    surrogate_theta_particle = 0.
    surrogate_phi_particle = 0.
    for model_trace, guide_trace in zip(model_traces, guide_traces):
        guide_particle = guide_trace.log_prob_sum()
        # Theta (model) gradient expectation: only the '*_first_name'
        # latent sites contribute to the surrogate.
        model_z_sup_particle = model_trace.log_prob_sum(
            lambda name, site: site['type'] == 'sample'
            and name.startswith('sup_first_name'))
        model_z_unsup_particle = model_trace.log_prob_sum(
            lambda name, site: site['type'] == 'sample'
            and name.startswith('unsup_first_name'))
        surrogate_theta_particle += model_z_unsup_particle + (alpha * model_z_sup_particle)
        # Phi (guide) gradient via a score-function (REINFORCE) term with
        # a constant baseline of 1.  NOTE(review): for reparameterized
        # samples (e.g. RelaxedOneHotCategoricalStraightThrough) the
        # pathwise gradient already flows through elbo_u before the
        # detach; confirm this estimator is the intended one there.
        surrogate_phi_particle += (elbo_u - 1).detach() * guide_particle
    # Negate (we minimize) and apply the same N/M and particle scaling.
    surrogate_theta_particle = -batch_constant / num_particle * surrogate_theta_particle
    surrogate_phi_particle = -batch_constant / num_particle * surrogate_phi_particle
    # Backprop both surrogates; retain_graph=True is required because the
    # two objectives share parts of the same autograd graph.
    surrogate_theta_particle.backward(retain_graph=True)
    # Fixed: the original line was wrapped in stray '**' markdown bold
    # markers, which is a Python syntax error.
    surrogate_phi_particle.backward(retain_graph=True)
    return -elbo
I have this code, which samples from a OneHotCategorical distribution, but when I change the distribution to RelaxedOneHotCategoricalStraightThrough it breaks. What's going on — why does this error appear only with RelaxedOneHotCategoricalStraightThrough? The temperature tensor for RelaxedOneHotCategoricalStraightThrough is a FloatTensor placed on DEVICE, which is determined at startup based on the user's machine (it may be CPU or CUDA).
The error occurs at the line `surrogate_phi_particle.backward(retain_graph=True)`.