Unable to find good initial values in toy Pyro bayesian pcfg model

Hi folks,

I’m new to pyro, but it seems great - I’m trying to fit a very simple toy model of a bayesian pcfg (probabilistic context free grammar), but I’m getting a warning re: unable to find start values.

The model code is here:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
from pyro import poutine
from torch.distributions.utils import lazy_property
import math

# Define grammar rules and their probabilities
grammar = {
    "S": [(0.8, ("NP", "VP")), (0.2, ("VP",))],
    "NP": [(0.6, ("Det", "N")), (0.4, ("N",))],
    "VP": [(0.7, ("V", "NP")), (0.3, ("V",))],
    "Det": [(0.6, ("the",)), (0.4, ("a",))],
    "N": [(0.5, ("dog",)), (0.5, ("cat",))],
    "V": [(1.0, ("chased",))]
}


# Functions to generate rules from the grammar
def generate_rules(grammar):
    rules = []
    for lhs, rhs_list in grammar.items():
        for probability, rhs in rhs_list:
            rules.append((lhs, rhs, probability))
    return rules
rules = generate_rules(grammar)

def logsumexp(x, y):
    m = max(x, y)
    return m + math.log(math.exp(x - m) + math.exp(y - m))

def cky_parse(sentence, rules):
    n = len(sentence)
    chart = [[{} for _ in range(n)] for _ in range(n)]

    # Fill the chart with terminal rules
    for i, word in enumerate(sentence):
        for rule in rules:
            lhs, rhs, log_probability = rule
            if rhs == (word,):
                chart[i][i][lhs] = log_probability

    # Apply nonterminal rules
    for span in range(2, n + 1):
        for start in range(n - span + 1):
            end = start + span - 1
            for split in range(start, end):
                for rule in rules:
                    lhs, rhs, log_probability = rule
                    if len(rhs) == 2:
                        lhs1, lhs2 = rhs
                        log_prob = chart[start][split].get(lhs1, float('-inf')) + chart[split + 1][end].get(lhs2, float('-inf')) + log_probability
                        chart[start][end][lhs] = logsumexp(chart[start][end].get(lhs, float('-inf')), log_prob)

    return chart

def pcfg_pyro_model(sentence, rules, updated_grammar):
    # Parse the sentence with the updated grammar
    chart = cky_parse(sentence, generate_rules(updated_grammar))

    # Target variable: the probability of generating the sentence
    target = pyro.sample("target", dist.Exponential(1.0))
    target_obs = chart[0][-1].get("S", 0)
    pyro.factor("obs", math.log(target_obs) - math.log(target))
def run_pcfg_inference(sentence, rules):
    # Initialize the rule probabilities using a ParamStoreDict
    param_store = pyro.get_param_store()
    for i, rule in enumerate(rules):
        param_name = f"rule_prob_{i}"
        if param_name not in param_store:
            param_store[param_name] = pyro.param(param_name, torch.tensor(rule[2]), constraint=dist.constraints.positive)

    # Update the grammar with the sampled rule probabilities
    updated_grammar = {nt: [(pyro.param(f"rule_prob_{i}"), r[1]) for i, r in enumerate(rules) if r[0] == nt] for nt in
                       grammar.keys()}

    # Run inference using NUTS and MCMC
    nuts_kernel = NUTS(lambda: pcfg_pyro_model(sentence, rules, updated_grammar))
    mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
    mcmc.run()

    # Inspect the results
    mcmc.summary()

# Example sentence to parse
sentence = ["the", "dog", "chased"]

# Perform inference
run_pcfg_inference(sentence, rules)



And the error I get is this:

Warmup:   0%|          | 0/1500 [00:00, ?it/s]Traceback (most recent call last):
  File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\api.py", line 563, in run
    for x, chain_id in self.sampler.run(*args, **kwargs):
  File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\api.py", line 223, in run
    for sample in _gen_samples(
  File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\api.py", line 144, in _gen_samples
    kernel.setup(warmup_steps, *args, **kwargs)
  File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\hmc.py", line 345, in setup
    self._initialize_model_properties(args, kwargs)
  File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\hmc.py", line 279, in _initialize_model_properties
    init_params, potential_fn, transforms, trace = initialize_model(
  File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\util.py", line 468, in initialize_model
    initial_params = _find_valid_initial_params(
  File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\util.py", line 365, in _find_valid_initial_params
    raise ValueError(
ValueError: Model specification seems incorrect - cannot find valid initial params.

Process finished with exit code 1

It seems like the model compiles okay, but there’s something about start values that it’s not liking. Any help would be very much appreciated!

Hi @cbreiss, you can use init_to_value to specify initial ones; - this can help you debug the issue. I suspect that there are numerical issues happen that lead to invalid joint density (you can add many print statements for debugging).

Hi @fehiepsi, thanks so much for your response. Looking at the documentation for that function, it seems to be an argument to a guide class, which I’m not using, since I’m doing mcmc. Does that mean that you think I should redo the model with variational inference, or am I misunderstanding what you meant?

Thanks again,
Canan

i think there may be an issue with your model. what observation likelihood is this supposed to encode? it would seem that this pushes target to infinity, which could explain your issue

yes, you can use it as init_strategy in NUTS