Thanks. So in pseudo-code to avoid writing hundreds of lines:
I’m making a DMM (similar to the tutorial) that learns its parameters from supervised data of the true hidden states (and observed states); once trained, it will infer the hidden states from the observed states alone. Would the best approach be:
class DMM:
    # some networks inheriting from nn.Module
    self.trans = some_network
    self.emitter = some_other_network
    self.combiner = network_that_combines_previous_hidden_state_and_all_observations

    def model(self, data_args):
        z_prev = some_trainable_parameter
        z_out = []
        if data_args_hidden_labels is not None:
            with pyro.plate("z_minibatch", len(mini_batch)) as batch:
                for t in pyro.markov(the_range_of_t):
                    z_probs = self.trans(z_prev)
                    z = pyro.sample(
                        "z_%d" % t,
                        dist.OneHotCategorical(z_probs),
                        obs=data_args_hidden_labels[batch, t - 1, :],
                    )
                    z_prev = z
                    z_out.append(z)
                    emission_probs_t = self.emitter(z)
                    x = pyro.sample(
                        "obs_x_%d" % t,
                        dist.RelaxedOneHotCategorical(torch.tensor(2.2), emission_probs_t),
                        obs=data_args_observed[batch, t - 1, :],
                    )
            return z_out
        else:
            with pyro.plate("z_minibatch", len(mini_batch)) as batch:
                for t in pyro.markov(the_range_of_t):
                    z_probs = self.trans(z_prev)
                    z = pyro.sample(
                        "z_%d" % t,
                        dist.OneHotCategorical(z_probs),
                        infer={"enumerate": "parallel"},
                    )
                    emission_probs_t = self.emitter(z)
                    x = pyro.sample(
                        "obs_x_%d" % t,
                        dist.RelaxedOneHotCategorical(torch.tensor(2.2), emission_probs_t),
                        obs=mini_batch[batch, t - 1, :],
                    )
                    z_prev = z
                    z_out.append(z)
            return z_out
    def pass_guide(self, data_args):
        pass

    @config_enumerate
    def inference_guide(self, data_args):
        z_prev = some_trainable_parameter  # must be initialised here too, not only in model()
        z_out = []
        with pyro.plate("z_minibatch", len(mini_batch)) as batch:
            for t in pyro.markov(the_range_of_t):
                # probabilities, not a location: OneHotCategorical expects probs
                z_probs = self.combiner(z_prev, rnn_output(data_args_observed_data))
                z = pyro.sample(
                    "z_%d" % t,
                    dist.OneHotCategorical(z_probs),
                    infer={"enumerate": "parallel"},
                )
                z_out.append(z)
                z_prev = z
        return z_out
# 1. SVI-train the model with Trace_ELBO, using pass_guide
# 2. freeze all parameters except the combiner network and the guide's initial-state parameter
# 3. SVI-train the posterior (guide) with TraceEnum_ELBO, enumerating the sites in the guide
# 4. run inference using the now-trained model and infer_discrete
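The freezing in step 2 can be done at the torch level, since the networks are ordinary nn.Modules whose gradients Pyro's optimisers respect. A minimal sketch, assuming placeholder layer sizes (all names and dimensions here are invented for illustration):

```python
import torch.nn as nn

# Placeholder stand-ins for the DMM's three networks (sizes are made up).
trans = nn.Linear(4, 4)
emitter = nn.Linear(4, 3)
combiner = nn.Linear(8, 4)

# Freeze the already-trained model networks so that stage-2 SVI only
# updates the combiner (and any guide-side initial-state parameter).
for net in (trans, emitter):
    for p in net.parameters():
        p.requires_grad_(False)

# Only the combiner's parameters (weight and bias) remain trainable.
trainable = [p for p in combiner.parameters() if p.requires_grad]
```

Since frozen parameters receive no gradients, they stay fixed even though they remain registered in Pyro's param store via pyro.module.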
Details omitted for brevity. In the first stage of training (the model) I’m maximising the likelihood of the observed outcomes by adjusting the model’s parameters; in the second stage (the guide) I’m fitting the closest posterior to the model that the guide’s restrictions allow. Then I simply run inference using the model’s best guess. I guess this makes the guide superfluous in this use case? But could it then be used to train further on additional unlabelled examples (semi-supervised)?
For some reason I also keep getting KeyErrors from dim_to_symbol with this model. I’ll investigate, since it occurs in all models of this type (supervised and unsupervised) that I’ve been running. The model does seem to handle the enumeration correctly when running an enumerated trace, and it is clearly picking up the enumeration instruction, since the trace reports an enumeration dimension for the relevant nodes.
On debugging, it looks like an unenumerated tensor that should be enumerated has ended up with an extra dimension.