Getting an error during inference on a fairly simple model: 'function' object has no attribute 'items'

I’ve built a simple generative model that produces an output matrix Y of shape N x T. I have a single datum, a matrix D of shape N x T, and I want to fit my generative model’s parameters so that it produces samples that closely match my one data matrix D. I’m hoping the trained generative model parameters will tell me useful, interpretable information about the structure of my experimental data matrix D.

The “forward” pass of my model indeed generates samples that look the way I want; now I just want to run SVI or MCMC to infer the optimal model parameters to fit my data, but alas I get an error.

Here’s my model:

pyro.clear_param_store()
class scNMF:
    def __init__(self, data, M=5, tau=15): 
        #N = num neurons, T = time of recording, M = num motifs, tau = length (time) of motif
        #General parameters/input
        self.data = data
        N, T = data.shape
        self.N = N
        self.T = T
        self.M = M
        self.tau = tau
        self.tot = M * N * tau
        #priors
        self.noise_probs = th.Tensor([0.001]).repeat(N*T).reshape(N,T)
        self.motif_probs = th.Tensor([1/10]).repeat(self.tot).reshape(M, N, tau)
        self.act_probs = th.Tensor([1/T]).repeat(N*T).reshape(N,T)
        #pyro params
        self.noise_probs_guide = pyro.param('noise_probs', self.noise_probs, constraint=constraints.positive)
        self.motif_probs_guide = pyro.param('motif_probs',self.motif_probs, constraint=constraints.positive)
        self.act_probs_guide = pyro.param('act_probs', self.act_probs, constraint=constraints.positive)
        #output
        self.Y = None
        
    def model_(self, motifs, s, noise):
        Y_ = th.zeros(self.N,self.T)
        for i in range(self.M):
            a = motifs[i]
            Y_ += thconv(a,s[i])
        Y = Y_ + noise
        Y /= Y.max(dim=1,keepdim=True)[0].repeat(1,self.T)
        Ydist = pyro.sample('Y', dist.Bernoulli(Y))#, obs=self.data)
        '''sigma = th.zeros(Y.shape)
        sigma[:] = 0.001
        Ydist = pyro.sample('Y', dist.Normal(Y,sigma))
        self.Y = Ydist'''
        return Ydist
        
    def model(self):
        #model distributions
        noise_ = dist.Bernoulli(self.noise_probs)
        motifs_ = dist.Bernoulli(self.motif_probs)
        s_ = dist.Bernoulli(self.act_probs)
        
        #sampling
        noise = pyro.sample('noise',noise_)
        motifs = pyro.sample('motifs',motifs_)
        s = pyro.sample('s',s_)
        print(noise.shape,motifs.shape,s.shape)
        #generate samples
        return self.model_(motifs,s,noise)
    
    def guide(self):
        #guide distributions
        noise_ = dist.Bernoulli(self.noise_probs_guide)
        motifs_ = dist.Bernoulli(self.motif_probs_guide)
        s_ = dist.Bernoulli(self.act_probs_guide)
        #sampling
        noise = pyro.sample('noise',noise_)
        motifs = pyro.sample('motifs',motifs_)
        s = pyro.sample('s',s_)
        #generate samples
        return self.model_(motifs,s,noise)
    
    def conditioned_model(self):
        return pyro.condition(self.model, data={'Y': self.data})

As I said, if I instantiate this object and call the .model() method, it successfully produces samples. But when I try inference (I tried both MCMC and SVI; MCMC shown here), I get this error:

nuts_kernel = NUTS(scnmf.conditioned_model, adapt_step_size=True)
hmc_posterior = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200).run()

Error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-543-2f95e5fde050> in <module>()
      1 nuts_kernel = NUTS(scnmf.conditioned_model, adapt_step_size=True)
----> 2 hmc_posterior = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200).run()

~/anaconda3/envs/deeprl/lib/python3.6/site-packages/pyro/infer/abstract_infer.py in run(self, *args, **kwargs)
     81         """
     82         self._init()
---> 83         for tr, logit in poutine.block(self._traces)(*args, **kwargs):
     84             self.exec_traces.append(tr)
     85             self.log_weights.append(logit)

~/anaconda3/envs/deeprl/lib/python3.6/site-packages/pyro/infer/mcmc/mcmc.py in _traces(self, *args, **kwargs)
     30 
     31     def _traces(self, *args, **kwargs):
---> 32         self.kernel.setup(*args, **kwargs)
     33         trace = self.kernel.initial_trace()
     34         self.logger.info("Starting MCMC using kernel - {} ...".format(self.kernel.__class__.__name__))

~/anaconda3/envs/deeprl/lib/python3.6/site-packages/pyro/infer/mcmc/hmc.py in setup(self, *args, **kwargs)
    186         if self._automatic_transform_enabled:
    187             self.transforms = {}
--> 188         for name, node in sorted(trace.iter_stochastic_nodes(), key=lambda x: x[0]):
    189             site_value = node["value"]
    190             if node["fn"].support is not constraints.real and self._automatic_transform_enabled:

~/anaconda3/envs/deeprl/lib/python3.6/site-packages/pyro/poutine/trace_struct.py in iter_stochastic_nodes(self)
    322         :return: an iterator over stochastic nodes in the trace.
    323         """
--> 324         for name, node in self.nodes.items():
    325             if node["type"] == "sample" and not node["is_observed"]:
    326                 yield name, node

AttributeError: 'function' object has no attribute 'items'

I’m probably doing something stupid; it’s my first time using Pyro. Any help would be much appreciated!

Hi, you’re probably getting this error because pyro.condition returns a function, not a sample. You need to call that function in conditioned_model:

def conditioned_model(self):
    return pyro.condition(self.model, data={'Y': self.data})()  # call the result
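
For reference, an equivalent approach (which the commented-out obs=self.data in your model_ already hints at) is to skip pyro.condition entirely and observe Y directly at the sample site via Pyro’s standard obs= keyword:

Ydist = pyro.sample('Y', dist.Bernoulli(Y), obs=self.data)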

@eb8680_2 I just tried it and I still get the same error unfortunately. I appreciate your help!
I’ll add that I get a similar error when trying inference with SVI:

AttributeError                            Traceback (most recent call last)
<ipython-input-15-79375a4390fb> in <module>()
      7 num_steps = 2500
      8 for t in range(num_steps):
----> 9     losses.append(svi.step())
     10     #a.append(pyro.param("a").item())
     11     #b.append(pyro.param("b").item())

~/anaconda3/envs/deeprl/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
     73         # get loss and compute gradients
     74         with poutine.trace(param_only=True) as param_capture:
---> 75             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
     76 
     77         params = set(site["value"].unconstrained()

~/anaconda3/envs/deeprl/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in __exit__(self, *args, **kwargs)
     68         """
     69         if self.param_only:
---> 70             for node in list(self.trace.nodes.values()):
     71                 if node["type"] != "param":
     72                     self.trace.remove_node(node["name"])

AttributeError: 'function' object has no attribute 'values'

What version of PyTorch and Pyro are you using? If not the latest (PyTorch v1.0.0 and Pyro v0.3.0) can you please upgrade and try again, and post a runnable snippet if you’re getting the same error? I can’t seem to reproduce that particular error.

Also, you might find it helpful to review some of our introductory tutorials if you haven’t already, as there are several other errors in your model:

  1. The variable Y is sampled in both model and guide via model_, but since it’s observed it should only appear in the model.
  2. MCMC doesn’t update parameters (param sites), only random variables (sample sites), and every random variable in your model and guide is discrete, so there’s nothing for NUTS (a gradient-based MCMC sampler for continuous random variables) to do. What posterior distributions do you want to infer?
  3. If you want to use SVI instead, you’ll need to call pyro.param inside the model or guide for parameters to be visible, rather than just in the constructor (a minimal sketch follows below).
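
To illustrate points 1 and 3, here is a minimal sketch (keeping your Bernoulli priors; the parameter names are hypothetical, and I’ve used constraints.unit_interval since Bernoulli probabilities must lie in [0, 1]):

def model(self):
    noise = pyro.sample('noise', dist.Bernoulli(self.noise_probs))
    motifs = pyro.sample('motifs', dist.Bernoulli(self.motif_probs))
    s = pyro.sample('s', dist.Bernoulli(self.act_probs))
    Y = ...  # deterministic construction, as in model_
    # the observed site appears only in the model, never in the guide
    return pyro.sample('Y', dist.Bernoulli(Y), obs=self.data)

def guide(self):
    # pyro.param is called here so the params are visible on every SVI step
    noise_p = pyro.param('noise_p', self.noise_probs, constraint=constraints.unit_interval)
    motif_p = pyro.param('motif_p', self.motif_probs, constraint=constraints.unit_interval)
    act_p = pyro.param('act_p', self.act_probs, constraint=constraints.unit_interval)
    pyro.sample('noise', dist.Bernoulli(noise_p))
    pyro.sample('motifs', dist.Bernoulli(motif_p))
    pyro.sample('s', dist.Bernoulli(act_p))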

@eb8680_2

I upgraded Pyro from 0.2 to 0.3.0 and fixed the things you noted and now it’s working. Not sure which of those things was the problem. Thank you so much, I really appreciate it! Happy New Year!

@eb8680_2 Rather than creating a new topic, I thought I’d reply with a follow-up question here.

I have a working model and I can run inference with MCMC without errors. I also created my own guide function and I can run SVI, except that it never converges; the loss just bounces around. However, when I use the AutoDiagonalNormal guide, I get a very smoothly decreasing loss. Any idea what’s going on? My guide uses Beta distributions over the 3 parameters, just as my model does.

Can you please post your code? Otherwise I can’t be of much help.

Here is a complete set of working code, @eb8680_2 (minus imports).

To recap, I’m building a model to infer latent structure in a single datum, which is a binary matrix of size N x T. Essentially, I’ve built a generative model that constructs the N x T matrix by convolving a set of M short motifs of size N x tau (tau << T) with a sparse binary activation map of size 1 x T per motif, plus some noise.
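
Concretely, in that notation (with zero-padding so that terms with t - l < 0 drop out; see the convolution code below), the generative process is roughly

Y[n, t] = sum over m, l of motifs[m, n, l] * s[m, t - l] + noise[n, t]

i.e. each motif is convolved along the time axis with its activation map, and the results are summed over motifs.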

Here is my main model and guide:

class scNMF:
    def __init__(self, data, M=5, tau=15): 
        #N = num neurons, T = time of recording, M = num motifs, tau = length (time) of motif
        #General parameters/input
        self.data = data
        N, T = data.shape
        self.N = N
        self.T = T
        self.M = M
        self.tau = tau
        self.tot = M * N * tau
        
        self.alpha_prior = th.Tensor([1.])
        self.beta_prior = th.Tensor([50.])
        
        self.NT_prior_alpha = self.alpha_prior.repeat(N*T).reshape(N,T)
        self.NT_prior_beta = self.beta_prior.repeat(N*T).reshape(N,T)
        self.noise_probs_ = dist.Beta(self.NT_prior_alpha, self.NT_prior_beta).independent(2)
        
        self.MNtau_prior_alpha = self.alpha_prior.repeat(M*N*tau).reshape(M,N,tau)
        self.MNtau_prior_beta = self.beta_prior.repeat(M*N*tau).reshape(M,N,tau)
        self.motif_probs_ = dist.Beta(self.MNtau_prior_alpha, self.MNtau_prior_beta).independent(3)
        
        self.MT_prior_alpha = self.alpha_prior.repeat(M*T).reshape(M,T)
        self.MT_prior_beta = self.beta_prior.repeat(M*T).reshape(M,T)
        self.act_probs_ = dist.Beta(self.MT_prior_alpha, self.MT_prior_beta).independent(2)
        #output
        self.Y = None
        
    def model_(self, motifs, s, noise):
        Y_ = th.zeros(self.N,self.T)
        for i in range(self.M):
            a = motifs[i]
            Y_ += thconv(a,s[i])
        Y = Y_ + noise
        ymax = Y.max(dim=1,keepdim=True)[0]
        ymax = ymax.repeat(1,self.T)
        Y /= ymax
        #Y = th.nn.functional.normalize(Y,dim=1)
        #Y is now an NxT matrix representing probabilities for a bernoulli distribution
        return Y
        
    def model(self,data=None):
        #model distributions
        noise_probs = pyro.sample('noise', self.noise_probs_)
        motif_probs = pyro.sample('motifs', self.motif_probs_)
        act_probs = pyro.sample('s', self.act_probs_)
        noise_ = dist.Bernoulli(noise_probs).independent(2)
        motifs_ = dist.Bernoulli(motif_probs).independent(3)
        s_ = dist.Bernoulli(act_probs).independent(2)
        #sampling
        noise = noise_.sample()
        motifs = motifs_.sample()
        s = s_.sample()
        #generate samples
        Y = self.model_(motifs,s,noise)
        Ydist = pyro.sample('Y', dist.Bernoulli(Y).independent(2), obs=data)
        return Ydist
    
    def guide(self,data):
        #setup params
        n_alpha = pyro.param('n_alpha',self.NT_prior_alpha, constraint=constraints.positive)
        n_beta = pyro.param('n_beta',self.NT_prior_beta, constraint=constraints.positive)
        m_alpha = pyro.param('m_alpha',self.MNtau_prior_alpha, constraint=constraints.positive)
        m_beta = pyro.param('m_beta',self.MNtau_prior_beta, constraint=constraints.positive)
        s_alpha = pyro.param('s_alpha',self.MT_prior_alpha, constraint=constraints.positive)
        s_beta = pyro.param('s_beta',self.MT_prior_beta, constraint=constraints.positive)
        
        noise_probs = pyro.sample('noise', dist.Beta(n_alpha, n_beta).independent(2))
        motif_probs = pyro.sample('motifs', dist.Beta(m_alpha, m_beta).independent(3))
        act_probs = pyro.sample('s', dist.Beta(s_alpha, s_beta).independent(2))

    def conditioned_model(self,data):
        return pyro.condition(self.model, data={'Y': data})()

Here is my convolution function:

def thshift(x,shift=0):
    #shift a 1-D tensor along its axis, zero-padding the vacated entries
    xnew = th.zeros(x.shape)
    if shift == 0:
        return x
    elif shift > 0:
        xnew[shift:] = x[0:-shift]
    else:
        xnew[:shift] = x[-shift:]
    return xnew

def thconv(a,s):
    #convolve a motif a (N x tau) with an activation map s (length T) along time:
    #res[n,t] = sum_i a[n,i] * s[t-i]
    res = th.stack([th.ger(a[:,i],thshift(s,i)) for i in range(a.shape[1])]).sum(dim=0)
    return res

And here is the code I’m running for SVI inference:


pyro.clear_param_store()  # clear first, so the params registered below are not wiped
guide = AutoDiagonalNormal(scnmf.model)
latent_dim = 453375
pyro.param("auto_loc", th.abs(th.randn(latent_dim)), constraint=constraints.positive)
pyro.param("auto_scale", th.ones(latent_dim),
           constraint=constraints.positive)
svi = pyro.infer.SVI(model=scnmf.conditioned_model,
                     guide=guide,
                     optim=pyro.optim.Adam({"lr": 0.01}),
                     loss=pyro.infer.Trace_ELBO())

losses = []
num_steps = 3000
for t in range(num_steps):
    loss = svi.step(data)
    losses.append(loss)
    if t % 100 == 0:
        print("Epoch {}, Loss {}".format(t,loss))

So I’m having two issues now. One is that my custom guide doesn’t converge at all, whereas the AutoDiagonalNormal guide with positivity constraints gives a nice, essentially monotonically decreasing loss plot. The second is that even with the apparently well-optimized variational model, the simulated data it generates is still a very poor approximation to my training data (again, a single N x T binary matrix).

Unfortunately, any inference results you’re getting right now are essentially meaningless, because some of the randomness in your model is unaccounted for. You need to wrap all randomness in your model with pyro.sample so that Pyro can see it:

noise = pyro.sample("noise_", noise_)  # not noise_.sample()
motifs = pyro.sample("motifs_", motifs_)  # not motifs_.sample()
s = pyro.sample("s_", s_)  # not s._sample()

You also need to include corresponding sample statements in the guide if you want to use SVI.
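
For example, here is a minimal sketch (with hypothetical parameter names and an arbitrary 0.01 initialization) of a matching guide site for the noise draw; constraints.unit_interval is used because Bernoulli probabilities must lie in [0, 1]:

def guide(self, data):
    # ... existing Beta sites as before ...
    noise_p = pyro.param('noise_p', th.full((self.N, self.T), 0.01),
                         constraint=constraints.unit_interval)
    pyro.sample('noise_', dist.Bernoulli(noise_p).independent(2))
    # and similarly for 'motifs_' and 's_'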

Also, some modeling advice: the way you’ve written your model, inference is going to be very difficult, because you’re sampling very high-dimensional discrete variables. You may be getting confused by the unfortunately named .independent, which was renamed to .to_event in Pyro 0.3; you probably want to wrap those sample statements in pyro.plate instead. I recommend reading the latest tensor shape tutorial and enumeration tutorial and then seeing if you can rewrite your model to take advantage of Pyro’s machinery for efficiently enumerating discrete variables.
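
For instance, a minimal sketch (the plate names here are arbitrary) of declaring the batch dimensions of the noise draw with pyro.plate instead of .independent:

with pyro.plate('neurons', self.N, dim=-2), pyro.plate('time', self.T, dim=-1):
    noise = pyro.sample('noise_', dist.Bernoulli(noise_probs))

This declares the draws conditionally independent across those dimensions, which is what lets Pyro enumerate discrete variables in parallel.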

You might also try using continuous approximations in your model and guide in place of the discrete variables.
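
As a sketch of that last suggestion (using torch’s RelaxedBernoulli, a continuous Concrete relaxation of the Bernoulli; the temperature value here is arbitrary):

temperature = th.tensor(0.5)
noise = pyro.sample('noise_', dist.RelaxedBernoulli(temperature, probs=noise_probs).to_event(2))

As the temperature approaches zero, samples approach exact 0/1 values, while remaining differentiable so that reparameterized gradients can flow through them.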

Thanks again for the tips.

Just as a start, I tried replacing all the .sample() calls with pyro.sample(...), and now I get a NotImplementedError: Cannot transform _Boolean constraints error with both MCMC and SVI…

In the meantime I’ll read through the tutorials more carefully.

That is because neither AutoDiagonalNormal nor HMC can handle Bernoulli latent variables directly. Instead you can try to either

  1. enumerate out those latent variables via @config_enumerate (see the enumeration tutorial, as @eb8680_2 suggests; a minimal sketch follows below); or
  2. write a custom guide for the discrete latent variables.
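
Here is a toy illustration of option 1 (not your scNMF model; the names and numbers are purely illustrative, assuming Pyro >= 0.3):

import torch as th
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

@config_enumerate(default='parallel')
def model(data):
    p = pyro.sample('p', dist.Beta(1., 1.))
    with pyro.plate('data', len(data)):
        z = pyro.sample('z', dist.Bernoulli(p))  # discrete latent, enumerated out
        pyro.sample('obs', dist.Normal(z, 0.5), obs=data)

def guide(data):
    a = pyro.param('a', th.tensor(1.), constraint=constraints.positive)
    b = pyro.param('b', th.tensor(1.), constraint=constraints.positive)
    pyro.sample('p', dist.Beta(a, b))

svi = SVI(model, guide, Adam({'lr': 0.01}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))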