Model-Based Machine Learning book, Chapter 2 skills example in Pyro: tensor dimension issue

I am trying to implement the example from chapter 2 of the MBML book. The setup is that 22 test takers answer a 48-question test. The goal is to take each person’s answers and infer which of 7 skills they likely have, given their responses (we know which skills apply to each question). At the same time, you also need to model the probability of guessing the correct answer to each of the 48 questions.

My code samples from distributions parameterized by 1-dimensional tensors, and some of the tensor lengths differ. For example,

    guess_probs = pyro.sample('guess_prob', dist.Beta(torch.ones(48)*2.5, torch.ones(48)*7.5))

and

    skills = [pyro.sample('skill'+str(x), dist.Bernoulli(torch.ones(22) * 0.5)).int() for x in range(0, 7)]

Each skill tensor has length 22, one entry per person for that particular skill, but the tensor of guess probabilities has length 48, one entry per question. I am getting an error from inside the step function of the pyro.infer.SVI object:

    /usr/local/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in _compute_log_r(model_trace, guide_trace)
         20             if not model_site["is_observed"]:
         21                 log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
    ---> 22             log_r.add((stacks[name], log_r_term.detach()))
         23     return log_r
         24 

    /usr/local/lib/python3.6/site-packages/pyro/infer/util.py in add(self, *items)
        113             assert all(f.dim < 0 and -len(value.shape) <= f.dim for f in frames)
        114             if frames in self:
    --> 115                 self[frames] = self[frames] + value
        116             else:
        117                 self[frames] = value

    RuntimeError: The size of tensor a (48) must match the size of tensor b (22) at non-singleton dimension 1

Is there a restriction that all pyro.param or pyro.sample statements must sample tensors of the same length?
I am also interested in whether SVI is a workable approach for this type of problem; from some of the other threads I gather that working with so many discrete variables could be an issue.


Is there a restriction that all pyro.param or pyro.sample statements must sample tensors of the same length?

No. Can you paste your entire model? It’s hard to debug from the snippet above, since the error isn’t coming from there; it’s most likely coming from your likelihood term.

Entire model below. I am not confident that the way I have parameterized the Beta in the guide is appropriate.

Edit: I have fixed a bug in the guide that was distorting tensor dimensions.

    import numpy as np
    import matplotlib.pyplot as plt
    import torch
    import pyro
    import pyro.infer
    import pyro.optim
    import pyro.distributions as dist
    import pandas

    from pyro.optim import Adam
    from pyro.infer import SVI, Trace_ELBO
    import torch.distributions.constraints as constraints
    from torch.distributions.beta import Beta
    from pyro.infer.mcmc import HMC, NUTS, MCMC
    from functools import reduce

    # Get data from Book's Website
    self_assessed = pandas.read_csv('http://www.mbmlbook.com/Downloads/LearningSkills_Real_Data_Experiments-Original-Inputs-RawResponsesAsDictionary.csv')
    self_assessed = self_assessed.iloc[1:,1:8]
    skills_key = pandas.read_csv('http://www.mbmlbook.com/Downloads/LearningSkills_Real_Data_Experiments-Original-Inputs-Quiz-SkillsQuestionsMask.csv', header=None)
    skills_needed = []
    for index, row in skills_key.iterrows():
       skills_needed.append([i for i,x in enumerate(row) if x])
    # col = person, row = question
    responses = pandas.read_csv('http://www.mbmlbook.com/Downloads/LearningSkills_Real_Data_Experiments-Original-Inputs-IsCorrect.csv', header=None)

    def reset_torch_seed(seed=99):
        torch.manual_seed(seed);

    def addNoiseTensor(skill, name, prob_mistake=0.1, prob_guess=0.2):
      # isCorrect1 is now a list of 22 values, 1 for each person that took the test
      # skill is a list of tensors. Each tensor is length 22, and represents whether the person 
      # had a skill required to answer this particular question

      has_all_req_skills = reduce(lambda x, y: x&y, skill)
      prob_correct = has_all_req_skills.float()
      prob_correct.apply_(lambda x: 1-prob_mistake if x==1.0 else prob_guess).float()
      return pyro.sample(name, dist.Bernoulli(prob_correct))  


    def SVI(vals, model, guide, lr = 0.0001, steps=1000):
        reset_torch_seed()
        pyro.clear_param_store()
        svi = pyro.infer.SVI(model = model,
                            guide = guide,
                            optim = pyro.optim.SGD({"lr":lr}),
                            #optim = pyro.optim.Adam({"lr":lr}), 
                            loss = pyro.infer.Trace_ELBO())
        
        losses = []
        for t in range(steps):
            losses.append(svi.step(*vals)) # vals will be passed to model and guide
        
        plt.plot(losses)
        plt.title("ELBO")
        plt.xlabel("step")
        plt.ylabel("loss")

    def complete_model_tensor():

      alpha0 = torch.ones(48) * 2.5
      beta0 = torch.ones(48) * 7.5
      skill_p = torch.ones(22) * 0.5 # Each of the seven skills needs 22 prior probs, 1 for each person
      guess_probs = pyro.sample('guess_prob', dist.Beta(alpha0, beta0))
      # converting to int here to make it easier to & tensors for checking if they have all skills required
      skills = [pyro.sample('skill'+str(x), dist.Bernoulli(skill_p)).int() for x in range(0, 7)]
      question_responses = []

      for question_no in range(0, 48):
        skills_for_question = skills_needed[question_no]
        isCorrect = addNoiseTensor([skills[i] for i in skills_for_question], 'isCorrect'+str(question_no), prob_guess=guess_probs[question_no].item())
        question_responses.append(isCorrect) # list of arrays of length 22
      return 

    def complete_model_tensor_guide():
      
      guess_prob_a = pyro.param('guess_prob_a', dist.Uniform(torch.ones(48), torch.ones(48)*20))
      guess_prob_b = pyro.param('guess_prob_b', dist.Uniform(torch.ones(48), torch.ones(48)*20))
     
      guess_probs = pyro.sample('guess_prob', dist.Beta(guess_prob_a , guess_prob_b))

      skill_p = [pyro.param('skill_p' + str(x), 
                               torch.abs(torch.rand(22)),
                               constraint=constraints.unit_interval) for x in range(0,7)]
      

      skills = [pyro.sample('skill'+str(x), dist.Bernoulli(skill_p[x])) for x in range(0,7)]
       
      return 

    # SVI Tensor version
    data = {'isCorrect'+str(index):torch.tensor(row.values.astype(float)).float() for index, row in responses.iterrows() }

    def conditioned_complete_model_tensor(*args, **kwargs):
        # Condition the model once (as opposed to creating 22 different models in model2)
        return pyro.condition(complete_model_tensor, data=data)(*args, **kwargs)

    SVI([], pyro.condition(complete_model_tensor, data=data), complete_model_tensor_guide, lr = 0.001, steps=1000)

In case it helps: if I turn on validation, I get this error message instead:

    /usr/local/lib/python3.6/site-packages/pyro/util.py in check_site_shape(site, max_plate_nesting)
        260                 '- enclose the batched tensor in a with plate(...): context',
        261                 '- .independent(...) the distribution being sampled',
    --> 262                 '- .permute() data dimensions']))
        263 
        264     # TODO Check parallel dimensions on the left of max_plate_nesting.

    ValueError: at site "guess_prob", invalid log_prob shape
      Expected [], actual [48]
      Try one of the following fixes:
      - enclose the batched tensor in a with plate(...): context
      - .independent(...) the distribution being sampled
      - .permute() data dimensions

a few suggestions:

  • try adding .independent(1) after your non-observed sample statements. this will resolve your current errors.
  • since the students and questions are all independent, replace your for loops with either a vectorized sample statement or an iarange, which will take care of some of the dim allocation issues for you. you’ll also get a performance speedup of a few orders of magnitude.
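
A pure-PyTorch sketch of the second suggestion, for the deterministic part of the likelihood: computing, for every participant and question at once, whether all required skills are present, in place of the reduce over skills and the loop over the 48 questions. The toy sizes and the names skills, needs and guess_probs are hypothetical; with the book's data, needs would be the 48 x 7 skills_key mask and skills a 22 x 7 tensor.

```python
import torch

prob_mistake = 0.1

# Hypothetical toy sizes: 3 participants, 2 skills, 3 questions.
# skills[p, s] is True if participant p has skill s
skills = torch.tensor([[1, 0],
                       [1, 1],
                       [0, 0]], dtype=torch.bool)
# needs[q, s] is True if question q requires skill s
needs = torch.tensor([[True, False],
                      [False, True],
                      [True, True]])
guess_probs = torch.tensor([0.2, 0.3, 0.25])  # one guess probability per question

# A participant has all required skills unless some required skill is missing.
missing = needs.unsqueeze(0) & ~skills.unsqueeze(1)  # (participants, questions, skills)
has_all = ~missing.any(dim=-1)                        # (participants, questions)

# Same rule as addNoiseTensor, written out of place so gradients can reach guess_probs.
prob_correct = torch.where(has_all, torch.tensor(1 - prob_mistake), guess_probs)
print(prob_correct)
```

This form fits a guide that samples the skills, as in the original post. If the skills are enumerated instead, each skill lives in its own enumeration dimension and probably cannot be stacked like this, which matches fritzo's later remark that enumeration rules out vectorizing over skills or questions.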

Thank you; part of my issue was this line:

    skill_p = [pyro.param('skill_p' + str(x),
                          torch.abs(torch.rand(1, 22)),
                          constraint=constraints.unit_interval) for x in range(0,7)]

I only wanted a 1-dimensional vector of length 22, not a 1 x 22 matrix. I have read the tensor shape tutorial very carefully, but I am still a bit confused: if I use .independent(1), isn’t that actually declaring that the variables are dependent? I can get rid of the warnings if I use a unique iarange for each vector of random variables in both the model and the guide, but the answers still do not come back close to what I expect.

Stepping back a bit, are these types of discrete latent models something you would expect Pyro to be able to do? Do I need to be using the enumeration feature?

The code appears to have another error.

Consider the lines:

    data = {'isCorrect'+str(index):torch.tensor(row.values.astype(float)).float() for index, row in responses.iterrows() }
    losses = SVI([], pyro.condition(model, data=data), guide, lr = 0.001, steps=1000)

It seems to me that every sample site named 'isCorrect'+str(index) will be set to the corresponding value in data. However, that implies that in the function AddNoise, the line:

    samples = pyro.sample(name, dist.Bernoulli(prob_correct).independent(1))

will not have any effect. That would explain why samples printed the same value at every step, which I could not understand, and perhaps also why the loss was barely decreasing. Any comments are appreciated. How should this be fixed? I am a total beginner. Thanks.

 Gordon

Hi Gordon,

I am still trying to get this to work, and I suspect I will need to use enumeration. Are you trying to work through the book as well? Regarding your question: jpchen’s advice was to add the independent call to all the non-observed sample statements, so I have not added it at that spot in the code.

Jeff

Hi Jeff,

I am indeed working through the book. I have not yet wrapped my head around posteriors in Pyro, although I do roughly understand the tutorials. Your code does converge. I want to create synthetic data and try to reproduce the version from the book that is not working. I am also trying to understand traces.


Here are my model and guide (slightly modified from yours):

    import time

    def model():
      # I WOULD LIKE MORE EFFICIENT matrix-based routine

      start = time.perf_counter()
      alpha0 = torch.ones(48) * 2.5   # Why a Beta? Why not uniform? (book, p96, fig. 2.26)
      beta0 = torch.ones(48) * 7.5    # set to 0.2 in simplistic model. 
      skill_p = torch.ones(22) * 0.5  # Each of the seven skills needs 22 prior probs, 1 for each person. Prior. 
      #if alpha0 and beta0 are constant, won't guess_probs be all over the map? 
      guess_probs = pyro.sample('guess_prob', dist.Beta(alpha0, beta0))
      #print("model, guess_probs= ", guess_probs)   # guide and model have the same guess_probs. How is this possible? 
      prob_mistake = 0.1

      skills = []
      for x in pyro.irange(0, 7):
          #  p(skill) [0,1] prior
          skill = pyro.sample('skill'+str(x), dist.Bernoulli(skill_p).independent(1)).int()
          skills.append(skill)   # can change each time

      # find a way to make this loop more efficient (q is question number)
      for q in range(0, 48):
        skill = [skills[i] for i in skills_needed[q]]
        prob_guess = guess_probs[q].item()
        has_all_req_skills = reduce(lambda x, y: x&y, skill)
        prob_correct = has_all_req_skills.float()   # convert to float
        prob_correct.apply_(lambda x: 1-prob_mistake if x==1. else prob_guess).float()

        isCorrect = pyro.sample("isCorrect"+str(q), dist.Bernoulli(prob_correct).independent(1))
      print("model: %f sec" % (time.perf_counter()-start))
      return

    def guide():

      start = time.perf_counter()
      guess_prob_a = pyro.param('guess_prob_a', dist.Uniform(torch.ones(48), torch.ones(48)*20))
      guess_prob_b = pyro.param('guess_prob_b', dist.Uniform(torch.ones(48), torch.ones(48)*20))

      # why are the arrays guess_prob_a random variables instead of constant? Shouldn't the guide have simpler structure?
      guess_probs = pyro.sample('guess_prob', dist.Beta(guess_prob_a , guess_prob_b))

      skill_p = []
      for x in range(7):
         sp = pyro.param('skill_p'+str(x), torch.abs(torch.rand(22)), constraint=constraints.unit_interval)
         skill_p.append(sp)

      # code works
      for x in range(0,7):
          ss = pyro.sample('skill'+str(x), dist.Bernoulli(skill_p[x]).independent(1))

      print("guide: %f sec" % (time.perf_counter()-start))

      return

This is a neat model! I think the best inference algorithm would be to use enumeration in the model, so that the skill variables don’t even appear in the guide. You’ll need to use the Pyro dev branch at the moment (or Pyro 0.3 once that’s released). Something like this:

    def model(data):
        ...
        # vectorize over participants
        with pyro.plate("participants", len(data)):
            skills = []
            for i in pyro.plate("skills", 7):
                skills.append(pyro.sample("skill_{}".format(i),
                                          dist.Bernoulli(skill_p),
                                          infer={"enumerate": "parallel"}))
            for q in pyro.plate("questions", 48):
                has_skills = reduce(operator.mul,
                                    [skills[i] for i in skills_needed[q]])
                prob_correct = ...
                pyro.sample("is_correct_{}".format(q),
                            dist.Bernoulli(prob_correct),
                            obs=data[:, q])  # <--- I may be indexing incorrectly here

Then the guide only needs to sample the Betas, and you can use SVI(model, guide, optim, TraceEnum_ELBO(max_plate_nesting=1)). Also I think vectorizing over participants will be best. If you enumerate, you won’t be able to vectorize over skills or questions. (Note I’ve used our newer pyro.plate rather than the older irange and iarange).

Thanks @fritzo, I have tried to follow your advice and made the changes below to my model and guide. I am getting the following error and I am not sure what is incorrect:

“ValueError: Error while computing log_prob at site ‘skill_0’: The value argument must be within the support”

    prob_mistake=0.1

    def SVI(vals, model, guide, lr = 0.0001, steps=1000):
        reset_torch_seed()
        pyro.clear_param_store()
        svi = pyro.infer.SVI(model = model,
                             guide = guide,
                             #optim = pyro.optim.SGD({"lr":lr}),
                            optim = pyro.optim.Adam({"lr":lr}), 
                            loss = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1))
        
        losses = []
        for t in range(steps):
            losses.append(svi.step(*vals)) # vals will be passed to model and guide
        
        plt.plot(losses)
        plt.title("ELBO")
        plt.xlabel("step")
        plt.ylabel("loss")

    def complete_model_tensor():

      alpha0 = torch.ones(48) * 2.5
      beta0 = torch.ones(48) * 7.5

      # Not sure if I need the skill_p vector anymore
      skill_p = torch.ones(22) * 0.5 # Each of the seven skills needs 22 prior probs, 1 for each person
      guess_probs = pyro.sample('guess_prob', dist.Beta(alpha0, beta0))

      with pyro.plate("participants", 22):
            skills = []
            for i in pyro.plate("skills", 7):
               skills.append(pyro.sample("skill_{}".format(i),
                                          # I think I can replace skill_p in the call below
                                          # with a scalar because of the 22 in the 
                                          # participant plate?
                                          dist.Bernoulli(0.5), 
                                          infer={"enumerate": "parallel"}))
            
            for q in pyro.plate("questions", 48):
                has_skills = reduce(operator.mul,
                                    [skills[i] for i in skills_needed[q]])
                  
                prob_correct = has_skills.float()    
                has_skills.apply_(lambda x: 1-prob_mistake if x==1.0 else guess_probs[q])
                   
                pyro.sample("isCorrect{}".format(q),
                            dist.Bernoulli(prob_correct))

    def complete_model_tensor_guide():
      
      guess_prob_a = pyro.param('guess_prob_a', dist.Uniform(torch.ones(48), torch.ones(48)*20))
      guess_prob_b = pyro.param('guess_prob_b', dist.Uniform(torch.ones(48), torch.ones(48)*20))
      guess_probs = pyro.sample('guess_prob', dist.Beta(guess_prob_a , guess_prob_b))
      return 

    # SVI Tensor version
    data = {'isCorrect'+str(index):torch.tensor(row.values.astype(float)).float() for index, row in responses.iterrows() }

    SVI([], pyro.condition(complete_model_tensor, data=data), complete_model_tensor_guide, lr = 0.001, steps=10000)

Hi @jeffmax, I was able to get your snippet running by:

  • applying @jpchen’s suggestion of adding .independent(1)
  • replacing the .apply_() with an affine sum (I’m not sure how the .apply_() worked)
  • constraining the beta parameters to be positive and initializing with smaller values

    import operator
    from functools import reduce

    import torch
    import pyro
    import pyro.distributions as dist
    from torch.distributions import constraints

    prob_mistake = 0.1
    # skills_needed (required skill indices per question) is built from skills_key as in the earlier post

    def complete_model_tensor():
        alpha0 = torch.ones(48) * 2.5
        beta0 = torch.ones(48) * 7.5

        # .independent(1) is the older spelling of .to_event(1)
        guess_probs = pyro.sample('guess_prob',
                                  dist.Beta(alpha0, beta0).independent(1))

        with pyro.plate("participants", 22):
            skills = []
            for i in pyro.plate("skills", 7):
                skills.append(pyro.sample("skill_{}".format(i),
                                          dist.Bernoulli(0.5),
                                          infer={"enumerate": "parallel"}))

            for q in pyro.plate("questions", 48):
                has_skills = reduce(operator.mul,
                                    [skills[i] for i in skills_needed[q]]).float()
                # affine sum instead of .apply_(): builds a new tensor that stays differentiable in guess_probs
                prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * guess_probs[q]
                pyro.sample("isCorrect{}".format(q),
                            dist.Bernoulli(prob_correct))

    def complete_model_tensor_guide():
        guess_prob_a = pyro.param('guess_prob_a', torch.ones(48) * 4,
                                  constraint=constraints.positive)
        guess_prob_b = pyro.param('guess_prob_b', torch.ones(48) * 4,
                                  constraint=constraints.positive)
        guess_probs = pyro.sample('guess_prob',
                                  dist.Beta(guess_prob_a , guess_prob_b).independent(1))

Does that look correct to you?
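
On why the .apply_() version failed with the skill_0 support error: this is a guess, not confirmed in this thread. functools.reduce over a one-element list returns that element itself, so for a question that needs a single skill, has_skills is the very tensor returned by pyro.sample("skill_i", ...) (and .float() on a float tensor returns the same object too). Calling .apply_() then overwrites the skill values in place with numbers such as 0.9, which lie outside the Bernoulli support {0, 1}. The affine sum builds a new tensor instead and, unlike a Python-level apply_ or .item(), keeps a gradient path to the guess probability. A pure-PyTorch sketch with a hypothetical stand-in skill tensor:

```python
import operator
from functools import reduce

import torch

prob_mistake = 0.1

# Stand-in for a sampled or enumerated skill tensor (values must stay in {0, 1}).
skill_0 = torch.tensor([0., 1., 1., 0.])

# With a single required skill, reduce returns the same tensor object, not a copy.
has_skills = reduce(operator.mul, [skill_0])
print(has_skills is skill_0)  # True

# Out-of-place affine rule: a new tensor, differentiable in the guess probability.
guess_prob = torch.tensor(0.2, requires_grad=True)
prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * guess_prob
prob_correct.sum().backward()
print(guess_prob.grad)  # tensor(2.): one unit per participant lacking the skill

# In-place apply_ on the alias silently rewrites the "skill" values themselves.
has_skills.apply_(lambda x: 1 - prob_mistake if x == 1.0 else 0.2)
print(skill_0)  # now holds 0.2 and 0.9, no longer valid Bernoulli values
```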


Thanks @fritzo, this does make sense. This model is running and appears to be converging, and if I run it for enough steps, I get estimates very close to the answers provided. A couple of comments/questions:

  1. Can you provide any intuition for why we would want to call .independent(1) on the guess_probs? My current understanding is that this reshapes the distribution into something like a 48-dimensional Beta distribution (I am not sure what that means, other than that I think the draws would not be iid). If the point is to declare to Pyro’s inference engine that each guess probability could depend on the ones before it, wouldn’t it suffice to just not put this sample in a plate (in which case Pyro would assume dependence)? I actually believe that in this particular model the guess probabilities are supposed to be independent, and I can get the model to run and converge if I put the guess probs within a plate.
  2. It seems like the higher-numbered guess probabilities (e.g. the guess probability for question 46, as opposed to question 1) converge to the answers more slowly. Does that make any sense?

Thanks for all your help!

I just wanted to mention that I took the same model code and ran it using MCMC and it worked great (getting answers very close to the suggested solutions).

I got the final version of the code to run. However, if I remove the independent(1) statements from the code, i.e., replace

    guess_probs = pyro.sample('guess_prob',
                              dist.Beta(alpha0, beta0).independent(1))

by

    guess_probs = pyro.sample('guess_prob',
                              dist.Beta(alpha0, beta0))

and do the same in the other location where to_event(1) or independent(1) (which are equivalent) is used, the code crashes with some kind of shape mismatch. I would like to understand why. For that matter, I do not understand the need for to_event(1), since it weakens the independence assumptions. One would use to_event(1) if one wanted correlation within the corresponding slice. Thanks.

@jeffmax, how did you know where to download Bishop’s datasets from? I have them because there are links from his various tables, but you download them directly via Python. Also, how do you get the data for Chapters 3 and 4, in particular the email data? Thanks.

@erlebach I am confused about the need for .to_event (I’ve described my confusion above as well). I don’t really have much intuition around what .to_event would even mean when applied to draws from a Bernoulli distribution (I can understand better for a normal distribution, because then I think it means that the covariance matrix is something other than the identity matrix). Perhaps it means that underlying algorithms should not rule out some kind of autoregressive dependence? I’ve asked a few related questions in this thread. @fritzo if you have a few minutes, could you chime in?

With respect to the datasets, I just used the CSV links next to the datasets in the book (the links you are referring to) and used Pandas to parse them into a dataframe. I am still working through Chapter 3 at the moment.

My issue with to_event(1) is not its use (as you say, it is about possible dependence) but rather why the code crashes with shape errors when I do not use it. That is frustrating because I cannot move on; I am stuck.

So, I can remove the .to_event if I also put the tensors inside a plate, which I guess annotates them as actually independent (rather than leaving them unannotated).

    with pyro.plate("probs", 48):
        guess_probs = pyro.sample('guess_prob', dist.Beta(2.5, 7.5))

@erlebach as you’ve noticed, Pyro is strict about requiring either .to_event() or with pyro.plate.

Historically in Pyro’s development, we initially tried to be less strict with tensor shapes. However we found that this laxness led to many silent errors in mis-specified models. These often involved tensors that were mis-shaped but which could be combined via broadcasting (in an unintended way). More recently we’ve tried to make Pyro as strict as possible so as to surface as many model errors as we can possibly detect by syntactic means.