Model-Based Machine Learning book, Chapter 2 skills example in Pyro: tensor dimension issue

Hi Jeff,

I am indeed working through the book. I have not yet wrapped my head around posteriors with Pyro, although I do kind of understand the tutorials. Your code does converge. I want to try to create synthetic data and reproduce the version from the book that is not working. I am also trying to understand traces.


Here is my Model and Guide (slightly modified from yours):

def model():
  # I WOULD LIKE MORE EFFICIENT matrix-based routine

  start = time.perf_counter()  # time.clock() was removed in Python 3.8
  alpha0 = torch.ones(48) * 2.5   # Why a Beta? Why not uniform? (book, p96, fig. 2.26)
  beta0 = torch.ones(48) * 7.5    # set to 0.2 in simplistic model. 
  skill_p = torch.ones(22) * 0.5  # Each of the seven skills needs 22 prior probs, 1 for each person. Prior. 
  #if alpha0 and beta0 are constant, won't guess_probs be all over the map? 
  guess_probs = pyro.sample('guess_prob', dist.Beta(alpha0, beta0))
  #print("model, guess_probs= ", guess_probs)   # guide and model have the same guess_probs. How is this possible?
  prob_mistake = 0.1

  skills = []
  for x in pyro.irange(0, 7):
      #  p(skill) [0,1] prior
      skill = pyro.sample('skill'+str(x), dist.Bernoulli(skill_p).independent(1)).int()
      skills.append(skill)   # can change each time

  # find a way to make this loop more efficient (q is question number)
  for q in range(0, 48):
    skill = [skills[i] for i in skills_needed[q]]
    prob_guess = guess_probs[q].item()
    has_all_req_skills = reduce(lambda x, y: x&y, skill)
    prob_correct = has_all_req_skills.float()   # convert to float
    prob_correct.apply_(lambda x: 1-prob_mistake if x==1. else prob_guess).float()

    isCorrect = pyro.sample("isCorrect"+str(q), dist.Bernoulli(prob_correct).independent(1))
  print("model: %f sec" % (time.perf_counter() - start))
  return

def guide():

  start = time.perf_counter()
  guess_prob_a = pyro.param('guess_prob_a', dist.Uniform(torch.ones(48), torch.ones(48)*20))
  guess_prob_b = pyro.param('guess_prob_b', dist.Uniform(torch.ones(48), torch.ones(48)*20))

  # why are the arrays guess_prob_a random variables instead of constants? Shouldn't the guide have a simpler structure?
  guess_probs = pyro.sample('guess_prob', dist.Beta(guess_prob_a , guess_prob_b))

  skill_p = []
  for x in range(7):
     sp = pyro.param('skill_p'+str(x), torch.abs(torch.rand(22)), constraint=constraints.unit_interval)
     skill_p.append(sp)

  # code works
  for x in range(0,7):
      ss = pyro.sample('skill'+str(x), dist.Bernoulli(skill_p[x]).independent(1))

  print("guide: %f sec" % (time.perf_counter() - start))

  return
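The comments in this model ask for a more efficient, matrix-based routine. One possibility, when the skills are sampled as they are here (it does not combine with enumerating each skill as its own site), is to stack the skills into a (people, skills) 0/1 tensor and encode skills_needed as a 0/1 requirement matrix: a person has every skill a question needs exactly when the number of required skills they possess equals the number of skills the question requires. A minimal PyTorch sketch; the small skills_needed list and the helper prob_correct_matrix are hypothetical stand-ins, not the book's data:

```python
import torch

# Hypothetical stand-in for the thread's skills_needed
# (question index -> indices of the skills that question requires).
skills_needed = [[0], [1], [0, 1], [2]]
num_skills = 3

# 0/1 requirement matrix R, shape (num_questions, num_skills).
R = torch.zeros(len(skills_needed), num_skills)
for q, needed in enumerate(skills_needed):
    R[q, needed] = 1.0


def prob_correct_matrix(skills, guess_probs, prob_mistake=0.1):
    """skills: (num_people, num_skills) tensor of 0/1 floats.
    guess_probs: (num_questions,) tensor.
    Returns a (num_people, num_questions) tensor of P(correct)."""
    # Count how many of each question's required skills each person has,
    # then compare with how many skills the question requires.
    has_all = (skills @ R.T == R.sum(dim=1)).float()
    # guess_probs broadcasts across the people dimension.
    return has_all * (1 - prob_mistake) + (1 - has_all) * guess_probs


skills = torch.tensor([[1., 1., 0.],
                       [0., 1., 1.]])
guess_probs = torch.tensor([0.2, 0.3, 0.4, 0.5])
out = prob_correct_matrix(skills, guess_probs)
print(out)
```

In the model above, torch.stack(skills, dim=-1).float() would give the (22, 7) skills tensor, and a single matrix product replaces the 48-iteration loop with reduce(lambda x, y: x&y, ...).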

This is a neat model! I think the best inference algorithm would be to use enumeration in the model, so that the skill variables don’t even appear in the guide. You’ll need to use the Pyro dev branch at the moment (or Pyro 0.3 once that’s released). Something like this:

def model(data):
    ...
    # vectorize over participants
    with pyro.plate("participants", len(data)):
        skills = []
        for i in pyro.plate("skills", 7):
            skills.append(pyro.sample("skill_{}".format(i),
                                      dist.Bernoulli(skill_p),
                                      infer={"enumerate": "parallel"}))
        for q in pyro.plate("questions", 48):
            has_skills = reduce(operator.mul,
                                [skills[i] for i in skills_needed[q]])
            prob_correct = ...
            pyro.sample("is_correct_{}".format(q),
                        dist.Bernoulli(prob_correct),
                        obs=data[:, q])  # <--- I may be indexing incorrectly here

Then the guide only needs to sample the Betas, and you can use SVI(model, guide, optim, TraceEnum_ELBO(max_plate_nesting=1)). Also I think vectorizing over participants will be best. If you enumerate, you won’t be able to vectorize over skills or questions. (Note I’ve used our newer pyro.plate rather than the older irange and iarange).

Thanks @fritzo, I have tried to follow your advice and made the changes below to my model and guide. I am getting the following error and I am not sure what is incorrect:

“ValueError: Error while computing log_prob at site ‘skill_0’: The value argument must be within the support”

    prob_mistake=0.1

    def SVI(vals, model, guide, lr = 0.0001, steps=1000):
        reset_torch_seed()
        pyro.clear_param_store()
        svi = pyro.infer.SVI(model = model,
                             guide = guide,
                             #optim = pyro.optim.SGD({"lr":lr}),
                            optim = pyro.optim.Adam({"lr":lr}), 
                            loss = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1))
        
        losses = []
        for t in range(steps):
            losses.append(svi.step(*vals)) # vals will be passed to model and guide
        
        plt.plot(losses)
        plt.title("ELBO")
        plt.xlabel("step")
        plt.ylabel("loss")

    def complete_model_tensor():

      alpha0 = torch.ones(48) * 2.5
      beta0 = torch.ones(48) * 7.5

      # Not sure if I need the skill_p vector anymore
      skill_p = torch.ones(22) * 0.5 # Each of the seven skills needs 22 prior probs, 1 for each person
      guess_probs = pyro.sample('guess_prob', dist.Beta(alpha0, beta0))

      with pyro.plate("participants", 22):
            skills = []
            for i in pyro.plate("skills", 7):
               skills.append(pyro.sample("skill_{}".format(i),
                                          # I think I can replace skill_p in the call below
                                          # with a scalar because of the 22 in the 
                                          # participant plate?
                                          dist.Bernoulli(0.5), 
                                          infer={"enumerate": "parallel"}))
            
            for q in pyro.plate("questions", 48):
                has_skills = reduce(operator.mul,
                                    [skills[i] for i in skills_needed[q]])
                  
                prob_correct = has_skills.float()    
                has_skills.apply_(lambda x: 1-prob_mistake if x==1.0 else guess_probs[q])
                   
                pyro.sample("isCorrect{}".format(q),
                            dist.Bernoulli(prob_correct))

    def complete_model_tensor_guide():
      
      guess_prob_a = pyro.param('guess_prob_a', dist.Uniform(torch.ones(48), torch.ones(48)*20))
      guess_prob_b = pyro.param('guess_prob_b', dist.Uniform(torch.ones(48), torch.ones(48)*20))
      guess_probs = pyro.sample('guess_prob', dist.Beta(guess_prob_a , guess_prob_b))
      return 

    # SVI Tensor version
    data = {'isCorrect'+str(index):torch.tensor(row.values.astype(float)).float() for index, row in responses.iterrows() }

    SVI([], pyro.condition(complete_model_tensor, data=data), complete_model_tensor_guide, lr = 0.001, steps=10000)

Hi @jeffmax, I was able to get your snippet running by:

  • applying @jpchen’s suggestion of adding .independent(1)
  • replacing the .apply_() with an affine sum (I’m not sure how the .apply_() worked)
  • constraining the beta parameters to be positive and initializing with smaller values
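
import operator
from functools import reduce

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

prob_mistake = 0.1  # as in the earlier post; skills_needed is defined elsewhere

def complete_model_tensor():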
    alpha0 = torch.ones(48) * 2.5
    beta0 = torch.ones(48) * 7.5

    guess_probs = pyro.sample('guess_prob',
                              dist.Beta(alpha0, beta0).independent(1))

    with pyro.plate("participants", 22):
        skills = []
        for i in pyro.plate("skills", 7):
           skills.append(pyro.sample("skill_{}".format(i),
                                     dist.Bernoulli(0.5),
                                     infer={"enumerate": "parallel"}))

        for q in pyro.plate("questions", 48):
            has_skills = reduce(operator.mul,
                                [skills[i] for i in skills_needed[q]]).float()
            prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * guess_probs[q]
            pyro.sample("isCorrect{}".format(q),
                        dist.Bernoulli(prob_correct))

def complete_model_tensor_guide():
    guess_prob_a = pyro.param('guess_prob_a', torch.ones(48) * 4,
                              constraint=constraints.positive)
    guess_prob_b = pyro.param('guess_prob_b', torch.ones(48) * 4,
                              constraint=constraints.positive)
    guess_probs = pyro.sample('guess_prob',
                              dist.Beta(guess_prob_a , guess_prob_b).independent(1))
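The affine sum can be checked on its own. Because it is ordinary tensor arithmetic, it broadcasts over whatever enumeration dimensions Pyro adds and keeps guess_probs in the autograd graph, whereas .apply_() runs a Python callback in place on a CPU tensor and records no gradients. A small sketch with made-up values; torch.where gives the same selection:

```python
import torch

prob_mistake = 0.1
# Hypothetical guess probability for one question, as a leaf tensor so we can
# check that gradients reach it.
guess_prob = torch.tensor(0.3, requires_grad=True)
has_skills = torch.tensor([1., 0., 1., 0.])  # 0/1 indicators for four participants

# Picks 1 - prob_mistake where has_skills == 1 and guess_prob where it is 0.
prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * guess_prob

# The same selection written with torch.where, for comparison.
prob_where = torch.where(has_skills == 1, torch.tensor(1 - prob_mistake), guess_prob)

# d(sum of prob_correct)/d(guess_prob) = number of participants lacking the skills.
prob_correct.sum().backward()
print(prob_correct.detach())  # tensor([0.9000, 0.3000, 0.9000, 0.3000])
print(guess_prob.grad)        # tensor(2.)
```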

Does that look correct to you?


Thanks @fritzo, this does make sense. This model is running and appears to be converging, and if I run it for enough steps, I get estimates very close to the answers provided. A couple of comments/questions:

  1. Can you provide any intuition for why we would want to call .independent(1) on the guess_probs? My current understanding is that this reshapes the distribution into something like a 48-dimensional Beta distribution (I am not sure what that means, other than that I think the draws would not be iid). If the point is to declare to Pyro’s inference engine that each guess probability could depend on the ones before it, wouldn’t it suffice to just not put this sample in a plate (in which case Pyro would assume dependence)? I actually believe that in this particular model the guess probabilities are supposed to be independent, and I can get the model to run and converge if I put the guess probs within a plate.
  2. It seems like the higher-numbered guess probabilities (the guess probability for question 46, as opposed to question 1, for example) converge to the answers more slowly. Does that make any sense?

Thanks for all your help!

I just wanted to mention that I took the same model code and ran it using MCMC and it worked great (getting answers very close to the suggested solutions).

I got the final version of the code to run. However, if I remove the independent(1) statements from the code, i.e., replace

    guess_probs = pyro.sample('guess_prob',
                                  dist.Beta(alpha0, beta0).independent(1))

by

    guess_probs = pyro.sample('guess_prob',
                                  dist.Beta(alpha0, beta0))

and the same in the other location where to_event(1) or independent(1) (which are equivalent) are used, the code crashes with some kind of shape mismatch. So I would like to understand why. For that matter, I do not understand the need for to_event(1), since that decreases the assumptions of independence. One would use to_event(1) if one would like correlation in the corresponding slice. Thanks.

@jeffmax, how did you know where to download Bishop’s datasets from? I have them because there are links from his various tables, but you download them via Python. Also, how do you get the data for Chapters 3 and 4, in particular the email data? Thanks.

@erlebach I am confused about the need for .to_event (I’ve described my confusion above as well). I don’t really have much intuition around what .to_event would even mean when applied to draws from a Bernoulli distribution (I can understand better for a normal distribution, because then I think it means that the covariance matrix is something other than the identity matrix). Perhaps it means that underlying algorithms should not rule out some kind of autoregressive dependence? I’ve asked a few related questions in this thread. @fritzo if you have a few minutes, could you chime in?

With respect to the datasets, I just used the CSV links next to the datasets in the book (the links you are referring to) and used Pandas to parse them into a dataframe. I am still working through Chapter 3 at the moment.
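A minimal sketch of that pandas route, turning a responses table into both the per-site dictionary used with pyro.condition earlier in the thread and a dense (participants, questions) tensor for obs=data[:, q]. The CSV layout here (one row per question, one column per participant) is an assumption inferred from the earlier responses.iterrows() snippet, and the inline data is made up:

```python
import io

import pandas as pd
import torch

# Hypothetical CSV layout (an assumption, not the book's actual file): one row
# per question, one column per participant, 1/0 for correct/incorrect.
csv_text = """question,P1,P2,P3
Q1,1,0,1
Q2,0,0,1
"""
responses = pd.read_csv(io.StringIO(csv_text), index_col="question")

# Dict keyed by sample-site name, for pyro.condition(model, data=data).
data = {"isCorrect{}".format(q): torch.tensor(row.values.astype(float)).float()
        for q, (_, row) in enumerate(responses.iterrows())}

# Dense (participants, questions) tensor, for obs=data_matrix[:, q].
data_matrix = torch.tensor(responses.values.T.astype(float)).float()
print(data_matrix.shape)  # torch.Size([3, 2])
```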

My issue with to_event(1) is not its use (as you say, it is about possible dependence), but rather why the code crashes with shape errors when I do not use it. That is frustrating because I cannot move on. I am stuck.

So, I can remove the .to_event if I also put the tensors inside a plate, which I guess annotates them as actually independent (rather than leaving them unannotated).

    with pyro.plate("probs", 48):
        guess_probs = pyro.sample('guess_prob', dist.Beta(2.5, 7.5))

@erlebach as you’ve noticed, Pyro is strict about requiring either .to_event() or with pyro.plate.

Historically in Pyro’s development, we initially tried to be less strict with tensor shapes. However we found that this laxness led to many silent errors in mis-specified models. These often involved tensors that were mis-shaped but which could be combined via broadcasting (in an unintended way). More recently we’ve tried to make Pyro as strict as possible so as to surface as many model errors as we can possibly detect by syntactic means.
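A small example of the kind of silent broadcasting error described here, using plain torch.distributions: if the observed answers accidentally arrive as a (48, 1) column instead of a (48,) vector, log_prob happily broadcasts to a (48, 48) matrix, and summing it counts every observation 48 times. The shapes and values are made up for illustration:

```python
import torch
from torch.distributions import Bernoulli

probs = torch.full((48,), 0.3)           # one probability per question
answers = torch.ones(48)                 # intended shape: (48,)
answers_column = answers.unsqueeze(-1)   # accidental shape: (48, 1)

lp_intended = Bernoulli(probs).log_prob(answers)
# No error: (48,) broadcasts against (48, 1) to give a (48, 48) matrix.
lp_mistake = Bernoulli(probs).log_prob(answers_column)

print(lp_intended.shape, lp_mistake.shape)  # torch.Size([48]) torch.Size([48, 48])
# Summing the mis-shaped result counts every observation 48 times.
print(torch.allclose(lp_mistake.sum(), 48 * lp_intended.sum()))  # True
```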

Thanks, @fritzo. Actually, I had not made that particular connection. Good to know though. So with to_event(1) the variables are considered dependent, while with plate, they are independent. Is it that simple?

So the only time neither plate nor to_event is required is when we are dealing with a scalar global variable?

So let me ask both @fritzo and Jeffmax:

If it is a choice between to_event(1) and plate, why not choose plate since the individual questions are independent of each other. Would the code not run faster? Could you please explain? Thanks.

Gordon

I am still confused on how to do inference in this model. I have inferred guess_probs using SVI and retrieved the skill sets using infer_discrete on the model. Given the model, the skill_set can only be 0 or 1, correct? However, in Chapter 2 of Winn & Bishop, Figure 2.7 suggests that the skills are real numbers between 0 and 1. I do not understand how to achieve that using the model they are using, or with the model and guide below. Does this sound confusing?

def complete_model_tensor():
    alpha0 = torch.ones(48) * 2.5
    beta0 = torch.ones(48) * 7.5

    guess_probs = pyro.sample('guess_prob',
                              dist.Beta(alpha0, beta0).independent(1))

    with pyro.plate("participants", 22):
        skills = []
        for i in pyro.plate("skills", 7):
           skills.append(pyro.sample("skill_{}".format(i),
                                     dist.Bernoulli(0.5),
                                     infer={"enumerate": "parallel"}))

        for q in pyro.plate("questions", 48):
            has_skills = reduce(operator.mul,
                                [skills[i] for i in skills_needed[q]]).float()
            prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * guess_probs[q]
            pyro.sample("isCorrect{}".format(q),
                        dist.Bernoulli(prob_correct))

def complete_model_tensor_guide():
    guess_prob_a = pyro.param('guess_prob_a', torch.ones(48) * 4,
                              constraint=constraints.positive)
    guess_prob_b = pyro.param('guess_prob_b', torch.ones(48) * 4,
                              constraint=constraints.positive)
    guess_probs = pyro.sample('guess_prob',
                              dist.Beta(guess_prob_a , guess_prob_b).independent(1))

My understanding of this (which could be wrong) is the following. I welcome any corrections!

Some algorithms that can take advantage of independence assumptions could run slower if .to_event is used where there is actual independence, but the answers shouldn’t be wrong. I suspect .to_event is sometimes used because it is simpler to type and may not be meaningfully different from plate in a particular situation. While this is not called out that often in the docs, Pyro wants every batched sample site to be annotated with plate or to_event (especially if you turn on validation), so you have to do one of them.

In my case, I use .to_event() for a variable x to specify that event_shape of the distribution which generates x is x.shape. I’ll use plate to denote the independence.

Sometimes, when they are equivalent and x.dim() >= 2, I’ll use .to_event() to simplify the code, as @jeffmax pointed out. But I don’t recommend doing it that way. Two plate statements are more explicit IMO.

In the following enumerated model, skills are enumerated. I count 2^(22*7) terms, which is huge, and yet it works. What is the real count? Thanks.

with pyro.plate("participants", 22):
        skills = []
        # Enumerate over skills (2 values for each skill)
        for i in pyro.plate("skills", 7):
           # skills: 0 or 1
           skills.append(pyro.sample("skill_{}".format(i),
                                     dist.Bernoulli(0.5),
                                     infer={"enumerate": "parallel"}))

        for q in pyro.plate("questions", 48):
            # skills_needed[q] is a list of the skills needed for the qth question
            has_skills = reduce(operator.mul,
                                [skills[i] for i in skills_needed[q]]).float()
            prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * guess_probs[q]
            # Conditioned on the data. Done outside the routine
            is_correct = pyro.sample("isCorrect{}".format(q),
                        dist.Bernoulli(prob_correct))

I am currently trying to wrap my head around the Pyro plate system and to get my Pyro notebook for Chapter 2 up and running.

Therefore, I wondered if you could recommend any resources besides the official tutorials.
Maybe you also have your latest notebooks online somewhere?
I wonder if you can use “a more vectorized” approach for the model in this chapter?

Any help would be greatly appreciated. 🙂

@MicPie did you try using the code in the post above? That mostly worked for me, however I ended up needing to use MCMC to get answers close to the book in a reasonable amount of time.