Hi!!
I need a little help understanding how to use Vindex for tensor slicing in my HMM variant. I have read the documentation (Mechanics of enumeration, and the GMM and HMM tutorials). I am following model_1 from the HMM example, but with a small difference.
pytorch 1.8, cuda 10.2 (although my system CUDA library is 11.3)
pyro-ppl 1.6.0
My transition probability matrix `probs_x` has shape `[n_sequences, hidden_dim, hidden_dim]`, instead of `[hidden_dim, hidden_dim]` as in model_1 from the hmm.py example.
Therefore I need some slicing. I looked into Vindex and tried to use it as follows:
```python
def model():
    ...
    n_sequences = 6
    hidden_dim = 2
    batch_size = 3
    probs_x = torch.randn(n_sequences, hidden_dim, hidden_dim)
    # lengths is defined earlier in the model
    with pyro.plate("sequences", size=n_sequences, subsample_size=batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                prob_x = Vindex(probs_x)[batch, x]
                print("probs_x shape: {}".format(prob_x.shape))
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(prob_x),
                                infer={"enumerate": "parallel"})
                print("hidden state shape: {}".format(x.shape))
                print("...................")
```
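For reference, here is a plain-NumPy sanity check of the shapes I expect from the slice above, assuming Vindex follows NumPy-style advanced indexing (which is my reading of the enumeration tutorial); the index values in `batch` are made up for illustration:

```python
import numpy as np

# Small sizes matching the model above
n_sequences, hidden_dim, batch_size = 6, 2, 3
probs_x = np.random.rand(n_sequences, hidden_dim, hidden_dim)  # batched transition matrices

batch = np.array([0, 2, 4])  # hypothetical subsampled sequence indices, shape [3]
x = 0                        # initial hidden state, a plain Python scalar

# Vindex(probs_x)[batch, x] should reduce to ordinary advanced indexing here,
# selecting one transition row per sequence in the batch:
prob_x = probs_x[batch, x]
assert prob_x.shape == (batch_size, hidden_dim)  # (3, 2), matching the first print
```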
```python
guide = AutoDelta(poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("OU_")))
elbo = TraceEnum_ELBO(max_plate_nesting=2, strict_enumeration_warning=True)
```
I do not think I need an ellipsis here, so I discarded that and other slicing attempts.
My simple mind cannot follow what is going on with the enumeration here. I understand that each discrete variable is enumerated in parallel, and that we then infer over the remaining continuous variables, but I cannot seem to use it properly.
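To show my (possibly wrong) understanding of the shapes: with `max_plate_nesting=2`, I believe a parallel-enumerated Categorical sample is placed at dim -3, i.e. it gets shape `[hidden_dim, 1, 1]`, and the index tensors then broadcast against each other. A NumPy sketch of that broadcasting, under those assumptions:

```python
import numpy as np

n_sequences, hidden_dim, batch_size = 6, 2, 3
probs_x = np.random.rand(n_sequences, hidden_dim, hidden_dim)
batch = np.array([0, 2, 4])  # hypothetical subsample, shape [3]

# Assumed enumerated value: all hidden states, at dim -3 -> shape [2, 1, 1]
x_enum = np.arange(hidden_dim).reshape(hidden_dim, 1, 1)

# The two index arrays broadcast: [3] with [2, 1, 1] -> [2, 1, 3];
# the sliced trailing axis (hidden_dim) is then appended:
prob_x = probs_x[batch, x_enum]
assert prob_x.shape == (hidden_dim, 1, batch_size, hidden_dim)  # (2, 1, 3, 2)
```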
The output and error for the code above are the following:
```
probs_x shape: torch.Size([3, 2])
hidden state shape
...................
probs_x shape: torch.Size([3, 3])
hidden state shape
...................
probs_x shape: torch.Size([3, 3])
Traceback (most recent call last):
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 165, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/…/Draupnir_models.py", line 2282, in model
    infer={"enumerate": "parallel"})  # [batch_size,1] --> [16,1,1] // [16, 1, 1, 1]
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/primitives.py", line 156, in sample
    apply_stack(msg)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 201, in apply_stack
    default_process_message(msg)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 162, in default_process_message
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/distributions/torch_distribution.py", line 46, in __call__
    return self.rsample(sample_shape) if self.has_rsample else self.sample(sample_shape)
  File "/home/…/anaconda3/lib/python3.7/site-packages/torch/distributions/categorical.py", line 111, in sample
    probs_2d = self.probs.reshape(-1, self._num_events)
RuntimeError: CUDA error: device-side assert triggered

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 165, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/nn/module.py", line 413, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 379, in forward
    self._setup_prototype(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 349, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 164, in _setup_prototype
    self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 187, in get_trace
    self(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 171, in __call__
    raise exc from e
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 165, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/…/Draupnir_models.py", line 2282, in model
    infer={"enumerate": "parallel"})  # [batch_size,1] --> [16,1,1] // [16, 1, 1, 1]
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/primitives.py", line 156, in sample
    apply_stack(msg)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 201, in apply_stack
    default_process_message(msg)
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 162, in default_process_message
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/distributions/torch_distribution.py", line 46, in __call__
    return self.rsample(sample_shape) if self.has_rsample else self.sample(sample_shape)
  File "/home/…/anaconda3/lib/python3.7/site-packages/torch/distributions/categorical.py", line 111, in sample
    probs_2d = self.probs.reshape(-1, self._num_events)
RuntimeError: CUDA error: device-side assert triggered
```

These CUDA-side assertion messages were printed interleaved with the traceback:

```
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:197: sampleMultinomialOnce: block: [0,0,0], thread: [1,0,0] Assertion `val >= zero` failed.
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:197: sampleMultinomialOnce: block: [6,0,0], thread: [1,0,0] Assertion `val >= zero` failed.
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:197: sampleMultinomialOnce: block: [3,0,0], thread: [1,0,0] Assertion `val >= zero` failed.
```
Thanks in advance for any insights!
(It could be a problem with CUDA as well, but I first want to be sure that the indexing is correct.)