Hi folks,
I’m new to pyro, but it seems great - I’m trying to fit a very simple toy model of a bayesian pcfg (probabilistic context free grammar), but I’m getting a warning re: unable to find start values.
The model code is here:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
from pyro import poutine
from torch.distributions.utils import lazy_property
import math
# Define grammar rules and their probabilities
grammar = {
"S": [(0.8, ("NP", "VP")), (0.2, ("VP",))],
"NP": [(0.6, ("Det", "N")), (0.4, ("N",))],
"VP": [(0.7, ("V", "NP")), (0.3, ("V",))],
"Det": [(0.6, ("the",)), (0.4, ("a",))],
"N": [(0.5, ("dog",)), (0.5, ("cat",))],
"V": [(1.0, ("chased",))]
}
# Functions to generate rules from the grammar
def generate_rules(grammar):
rules = []
for lhs, rhs_list in grammar.items():
for probability, rhs in rhs_list:
rules.append((lhs, rhs, probability))
return rules
rules = generate_rules(grammar)
def logsumexp(x, y):
m = max(x, y)
return m + math.log(math.exp(x - m) + math.exp(y - m))
def cky_parse(sentence, rules):
n = len(sentence)
chart = [[{} for _ in range(n)] for _ in range(n)]
# Fill the chart with terminal rules
for i, word in enumerate(sentence):
for rule in rules:
lhs, rhs, log_probability = rule
if rhs == (word,):
chart[i][i][lhs] = log_probability
# Apply nonterminal rules
for span in range(2, n + 1):
for start in range(n - span + 1):
end = start + span - 1
for split in range(start, end):
for rule in rules:
lhs, rhs, log_probability = rule
if len(rhs) == 2:
lhs1, lhs2 = rhs
log_prob = chart[start][split].get(lhs1, float('-inf')) + chart[split + 1][end].get(lhs2, float('-inf')) + log_probability
chart[start][end][lhs] = logsumexp(chart[start][end].get(lhs, float('-inf')), log_prob)
return chart
def pcfg_pyro_model(sentence, rules, updated_grammar):
# Parse the sentence with the updated grammar
chart = cky_parse(sentence, generate_rules(updated_grammar))
# Target variable: the probability of generating the sentence
target = pyro.sample("target", dist.Exponential(1.0))
target_obs = chart[0][-1].get("S", 0)
pyro.factor("obs", math.log(target_obs) - math.log(target))
def run_pcfg_inference(sentence, rules):
# Initialize the rule probabilities using a ParamStoreDict
param_store = pyro.get_param_store()
for i, rule in enumerate(rules):
param_name = f"rule_prob_{i}"
if param_name not in param_store:
param_store[param_name] = pyro.param(param_name, torch.tensor(rule[2]), constraint=dist.constraints.positive)
# Update the grammar with the sampled rule probabilities
updated_grammar = {nt: [(pyro.param(f"rule_prob_{i}"), r[1]) for i, r in enumerate(rules) if r[0] == nt] for nt in
grammar.keys()}
# Run inference using NUTS and MCMC
nuts_kernel = NUTS(lambda: pcfg_pyro_model(sentence, rules, updated_grammar))
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run()
# Inspect the results
mcmc.summary()
# Example sentence to parse
sentence = ["the", "dog", "chased"]
# Perform inference
run_pcfg_inference(sentence, rules)
And the error I get is this:
Warmup: 0%| | 0/1500 [00:00, ?it/s]Traceback (most recent call last):
File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\api.py", line 563, in run
for x, chain_id in self.sampler.run(*args, **kwargs):
File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\api.py", line 223, in run
for sample in _gen_samples(
File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\api.py", line 144, in _gen_samples
kernel.setup(warmup_steps, *args, **kwargs)
File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\hmc.py", line 345, in setup
self._initialize_model_properties(args, kwargs)
File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\hmc.py", line 279, in _initialize_model_properties
init_params, potential_fn, transforms, trace = initialize_model(
File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\util.py", line 468, in initialize_model
initial_params = _find_valid_initial_params(
File "G:\My Drive\MIT\Research\venv\lib\site-packages\pyro\infer\mcmc\util.py", line 365, in _find_valid_initial_params
raise ValueError(
ValueError: Model specification seems incorrect - cannot find valid initial params.
Process finished with exit code 1
It seems like the model compiles okay, but there’s something about start values that it’s not liking. Any help would be very much appreciated!