How to implement a left-to-right HMM starting from the Pyro HMM example code

I’m trying to modify the example code of the Pyro HMM tutorial (Example: Hidden Markov Models — Pyro Tutorials 1.7.0 documentation) to implement a left-to-right HMM, in which only transitions from left to right are allowed. I only need to restrict the direction of the transitions, but I can’t come up with an appropriate solution.
This is the example HMM model (an ergodic HMM) from Example: Hidden Markov Models — Pyro Tutorials 1.7.0 documentation:

def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    """Ergodic HMM from the Pyro HMM example: any state may move to any state.

    Each time step emits a vector of independent Bernoulli "tones"; the hidden
    state ``x`` at each step is enumerated in parallel rather than sampled.

    Args:
        sequences: observation tensor unpacked as
            ``(num_sequences, max_length, data_dim)``.
        lengths: valid length of each sequence; only
            ``sequences[i, :lengths[i]]`` is scored.
        args: namespace providing ``hidden_dim`` (number of hidden states).
        batch_size: optional subsample size for the "sequences" plate.
        include_prior: if False, the log-density of the ``probs_x`` and
            ``probs_y`` priors is masked out.
    """
    # Fail fast if the model is called under torch.jit tracing.
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        # Prior on the (hidden_dim, hidden_dim) transition matrix: each row is
        # an independent Dirichlet with concentration 1.0 for staying in the
        # same state and 0.1 for each jump to another state. The implied prior
        # mean probability of staying is 1 / (1 + 0.1 * (hidden_dim - 1)),
        # e.g. 0.4 when hidden_dim=16.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        # The chain starts in state 0 before t = 0, so the first transition
        # uses row 0 of probs_x.
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample(
                "x_{}_{}".format(i, t),
                # Row x of probs_x is the distribution over the next state.
                dist.Categorical(probs_x[x]),
                infer={"enumerate": "parallel"},
            )
            with tones_plate:
                pyro.sample(
                    "y_{}_{}".format(i, t),
                    # NOTE(review): squeeze(-1) presumably drops a trailing
                    # size-1 dim that x carries under enumeration, so the
                    # emission probs line up with the tones plate (dim=-1);
                    # confirm with Trace.format_shapes().
                    dist.Bernoulli(probs_y[x.squeeze(-1)]),
                    obs=sequence[t],
                )

I tried to forbid right-to-left transitions by modifying probs_x, but it failed.
I changed probs_x into a matrix of probabilities of moving from the current state to the same state or to the next state, i.e. something like tensor([[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], …]):

    # NOTE(review): the concentration has shape (hidden_dim, 2), so each row of
    # probs_x is a 2-vector [stay, advance], not a distribution over all
    # hidden_dim next states. Its prior mean is [100/101, 1/101] ~ [0.99, 0.01],
    # not [0.9, 0.1]. Before it can be indexed as probs_x[x] to get the next
    # state, it must be expanded into a full (hidden_dim, hidden_dim) matrix
    # with "stay" on the diagonal and "advance" just above it.
    probs_x = pyro.sample(
                "probs_x",
                dist.Dirichlet(torch.tensor([100.,1.]).repeat(args.hidden_dim,1)).to_event(1),
            )

I also modified the for-loop so that each sampled value is a step relative to the current state, and tried to prevent the final state from moving past the last index:
for t in pyro.markov(range(length)):

    # NOTE(review): this loop is what raises the IndexError in the traceback.
    # - With infer={"enumerate": "parallel"}, enumeration evaluates every value
    #   in the Categorical's support (here 0 and 1), including values whose
    #   probability was set to 0. Zeroing probs_x[args.hidden_dim-1, 1]
    #   therefore does not stop the enumerated step from being 1.
    # - `x = x + step` then broadcasts into every reachable running sum. That
    #   sum grows to args.hidden_dim and indexes one past the last row of
    #   probs_x. The running sum also makes x depend on all previous steps,
    #   not just the last draw.
    # - The assignment below modifies the sampled probs_x in place on every
    #   time step.
    # - "x_{}".format(t) omits the sequence index. If this loop runs once per
    #   sequence (as in model_0), the site names repeat across sequences.
    # Sample the absolute next state from a full transition matrix instead.
    probs_x[args.hidden_dim-1,1] = 0
    x = x + pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )

But these changes raise an IndexError later, inside svi.step:

<ipython-input-76-6cf1bcce7519> in model_1(sequences, lengths, args, batch_size, include_prior)
     30                 x = x + pyro.sample(
     31                     "x_{}".format(t),
---> 32                     dist.Categorical(probs_x[x]),
     33                     infer={"enumerate": "parallel"},
     34                 )

IndexError: index 11 is out of bounds for dimension 0 with size 11

I couldn’t solve this.
If you know how to implement a left-to-right HMM, please tell me.
I’m sorry if this is an elementary question; please let me know if I’m missing any information.

This is the whole code of the example HMM (and my failed left-to-right version):

import argparse
import logging
import sys
import torch
import torch.nn as nn
from torch.distributions import constraints

import pyro
import pyro.contrib.examples.polyphonic_data_loader as poly
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings

logging.basicConfig(
    level=logging.DEBUG,
    format="%(relativeCreated) 9d %(message)s",
)

# A second root handler echoes only DEBUG-level (and lower) records to stdout,
# so debugging/profiling events can be captured as their own stream.
debug_handler = logging.StreamHandler(stream=sys.stdout)
debug_handler.addFilter(lambda record: record.levelno <= logging.DEBUG)
debug_handler.setLevel(logging.DEBUG)
log = logging.getLogger()
log.addHandler(debug_handler)

def model_ergodic(sequences, lengths, args, batch_size=None, include_prior=True):
    """Ergodic HMM: any hidden state may move to any hidden state.

    Each time step emits a vector of independent Bernoulli "tones"; the hidden
    state ``x`` at each step is enumerated in parallel rather than sampled.

    Args:
        sequences: observation tensor unpacked as
            ``(num_sequences, max_length, data_dim)``.
        lengths: valid length of each sequence; only
            ``sequences[i, :lengths[i]]`` is scored.
        args: namespace providing ``hidden_dim`` (number of hidden states).
        batch_size: optional subsample size for the "sequences" plate.
        include_prior: if False, the log-density of the ``probs_x`` and
            ``probs_y`` priors is masked out.
    """
    # Fail fast if the model is called under torch.jit tracing.
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        # Prior on the (hidden_dim, hidden_dim) transition matrix: each row is
        # an independent Dirichlet with concentration 1.0 for staying in the
        # same state and 0.1 for each jump to another state. The implied prior
        # mean probability of staying is 1 / (1 + 0.1 * (hidden_dim - 1)),
        # e.g. 0.4 with the default hidden_dim=16.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )

        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )

    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        # The chain starts in state 0 before t = 0, so the first transition
        # uses row 0 of probs_x.
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample(
                "x_{}_{}".format(i, t),
                # Row x of probs_x is the distribution over the next state.
                dist.Categorical(probs_x[x]),
                infer={"enumerate": "parallel"},
            )
            with tones_plate:
                pyro.sample(
                    "y_{}_{}".format(i, t),
                    # NOTE(review): squeeze(-1) presumably drops a trailing
                    # size-1 dim that x carries under enumeration, so the
                    # emission probs line up with the tones plate (dim=-1);
                    # confirm with model_trace.format_shapes().
                    dist.Bernoulli(probs_y[x.squeeze(-1)]),
                    obs=sequence[t],
                )

def model_left_to_right(sequences, lengths, args, batch_size=None, include_prior=True):
    """Left-to-right HMM: each hidden state may only stay or advance by one.

    Uses the same observation model and site names as ``model_ergodic``, but
    constrains the hidden chain. From state ``k`` the next state is either
    ``k`` or ``k + 1``, and the last state ``hidden_dim - 1`` is absorbing.

    Each site ``x_{i}_{t}`` is the *absolute* next state, drawn from a full
    ``(hidden_dim, hidden_dim)`` transition matrix; it is never a relative
    step. Accumulating relative steps (``x = x + step``) breaks under parallel
    enumeration. Every step value is enumerated for every previous state, even
    zero-probability ones, so the running sum reaches ``hidden_dim`` and
    indexes past the last row. It also makes ``x`` depend on the whole history
    instead of only the previous state.

    Args:
        sequences: observation tensor unpacked as
            ``(num_sequences, max_length, data_dim)``.
        lengths: valid length of each sequence; only
            ``sequences[i, :lengths[i]]`` is scored.
        args: namespace providing ``hidden_dim`` (number of hidden states).
        batch_size: optional subsample size for the "sequences" plate.
        include_prior: if False, the log-density of the ``probs_x`` and
            ``probs_y`` priors is masked out.
    """
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = args.hidden_dim
    with poutine.mask(mask=include_prior):
        # Per-state [stay, advance] probabilities, shape (hidden_dim, 2). The
        # Dirichlet(100, 1) prior strongly favours staying (mean ~[0.99, 0.01]).
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(torch.tensor([100.0, 1.0]).repeat(hidden_dim, 1)).to_event(1),
        )
        # Weak prior: roughly 10% of the notes active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2),
        )

    # Scatter [stay, advance] into a banded transition matrix without
    # modifying the sampled tensor in place: "stay" goes on the diagonal and
    # "advance" on the first superdiagonal.
    eye = torch.eye(hidden_dim, dtype=probs_x.dtype, device=probs_x.device)
    superdiag = torch.ones_like(eye).triu(1).tril(1)  # ones at (k, k + 1)
    trans = eye * probs_x[..., :1] + superdiag * probs_x[..., 1:]
    # The last row has no (k, k + 1) slot, so renormalising turns it into an
    # absorbing state. Every other row already sums to 1 (stay + advance).
    trans = trans / trans.sum(-1, keepdim=True)

    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        # The chain starts in state 0 before t = 0, so x_0 is state 0 or 1.
        x = 0
        for t in pyro.markov(range(length)):
            # Overwrite x with the absolute next state drawn from row x.
            x = pyro.sample(
                "x_{}_{}".format(i, t),
                dist.Categorical(trans[x]),
                infer={"enumerate": "parallel"},
            )
            with tones_plate:
                pyro.sample(
                    "y_{}_{}".format(i, t),
                    dist.Bernoulli(probs_y[x.squeeze(-1)]),
                    obs=sequence[t],
                )

# Registry of the model functions defined above, keyed by name without the
# "model_" prefix (e.g. "ergodic", "left_to_right"). The ``callable`` check
# skips non-function globals that share the prefix. In particular, the string
# ``model_name`` assigned further down would otherwise be registered as a
# model called "name" whenever this cell is re-run in a notebook. globals()
# is snapshotted with list() so the scan never iterates a changing dict.
models = {
    name[len("model_") :]: model
    for name, model in list(globals().items())
    if name.startswith("model_") and callable(model)
}

# Pick which registered model to train: a key of `models`, i.e. 'ergodic' or
# 'left_to_right'.
model_name = 'ergodic' # 'left_to_right'
model = models[model_name]
logging.info("Loading data")

data = poly.load_data(poly.JSB_CHORALES)

logging.info("-" * 40)
logging.info(
    "Training {} on {} sequences".format(
        model.__name__, len(data["train"]["sequences"])
    )
)

# Only the training split is used below.
sequences = data["train"]["sequences"]
lengths = data["train"]["sequence_lengths"]
# find all the notes that are present at least once in the training set
present_notes = (sequences == 1).sum(0).sum(0) > 0
# remove notes that are never played (we remove 37/88 notes)
sequences = sequences[..., present_notes]
# Total number of time steps; used below to report the loss per observation.
num_observations = float(lengths.sum())
# NOTE(review): the seed is hard-coded; the --seed option parsed below is
# never read.
pyro.set_rng_seed(1)
pyro.clear_param_store()
# MAP guide: AutoDelta point estimates for the sites that poutine.block
# exposes, i.e. only sites whose names start with "probs_". The x_* sites are
# enumerated in the model and the y_* sites are observed.
guide = AutoDelta(
    poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_"))
)

# parse_args(args=[]) ignores the real command line, so the defaults below are
# always used (convenient inside a notebook). Of these options, this script
# reads num_steps, batch_size, hidden_dim (inside the models), learning_rate,
# jit and time_compilation; the others are currently unused.
parser = argparse.ArgumentParser(
    description="MAP Baum-Welch learning Bach Chorales"
)
parser.add_argument("-n", "--num-steps", default=50, type=int)
parser.add_argument("-b", "--batch-size", default=8, type=int)
parser.add_argument("-d", "--hidden-dim", default=16, type=int)
parser.add_argument("-nn", "--nn-dim", default=48, type=int)
parser.add_argument("-nc", "--nn-channels", default=2, type=int)
parser.add_argument("-lr", "--learning-rate", default=0.05, type=float)
parser.add_argument("-t", "--truncate", type=int)
parser.add_argument("-p", "--print-shapes", action="store_true")
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--jit", action="store_true")
parser.add_argument("--time-compilation", action="store_true")
parser.add_argument("-rp", "--raftery-parameterization", action="store_true")
parser.add_argument(
    "--tmc",
    action="store_true",
    help="Use Tensor Monte Carlo instead of exact enumeration "
    "to estimate the marginal likelihood. You probably don't want to do this, "
    "except to see that TMC makes Monte Carlo gradient estimation feasible "
    "even with very large numbers of non-reparametrized variables.",
)
args = parser.parse_args(args=[])
# The only vectorized plate is "tones" at dim=-1 (the "sequences" plate is
# iterated in Python), so enumeration dims for x start at -2.
first_available_dim = -2
# Debugging aid: trace the model once with enumeration, replaying the guide's
# values, and log every site's tensor shapes. This runs regardless of
# --print-shapes.
guide_trace = poutine.trace(guide).get_trace(
    sequences, lengths, args=args, batch_size=args.batch_size
)
model_trace = poutine.trace(
    poutine.replay(poutine.enum(model, first_available_dim), guide_trace)
).get_trace(sequences, lengths, args=args, batch_size=args.batch_size)
logging.info(model_trace.format_shapes())
# Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
# The models use two plates, "sequences" and "tones". "sequences" is iterated
# sequentially, so only the vectorized "tones" plate counts: nesting is 1.
optim = Adam({"lr": args.learning_rate})
# NOTE(review): with --jit, model and guide receive the argparse Namespace as
# a keyword argument; confirm JitTraceEnum_ELBO accepts that before enabling.
Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
elbo = Elbo(
    max_plate_nesting=1,
    strict_enumeration_warning=True,
    jit_options={"time_compilation": args.time_compilation},
)
svi = SVI(model, guide, optim, elbo)
# We'll train on small minibatches.
# The loss is logged per observed time step (divided by num_observations).
logging.info("Step\tLoss")
for step in range(args.num_steps):
    loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size)
    logging.info("{: >5d}\t{}".format(step, loss / num_observations))

The most straightforward way to do this is to first “demote” probs_x to a parameter:

# Treat the transition matrix as a learnable parameter instead of a sampled
# latent. Each row is constrained to the probability simplex, and the
# starting value is 0.9 * I + 0.1, the same values the ergodic model uses as
# its Dirichlet concentration.
initial_probs_x = 0.9 * torch.eye(args.hidden_dim) + 0.1
probs_x = pyro.param(
    "probs_x", initial_probs_x, constraint=dist.constraints.simplex
)

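Then mask probs_x before you use it in any downstream computation. Note that rows of probs_x index the current state and columns the next state, so the left-to-right transitions are the diagonal and the entries above it: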

# Rows of probs_x index the current state and columns the next state (the
# model draws the next state from dist.Categorical(probs_x[x])). A
# left-to-right HMM therefore keeps the diagonal and everything above it: an
# upper-triangular mask (triu), not a lower-triangular one (tril), which would
# allow only right-to-left moves.
probs_x = pyro.param(
    "probs_x",
    0.9 * torch.eye(args.hidden_dim) + 0.1,
    constraint=dist.constraints.simplex,
)
my_left_right_mask = torch.ones_like(probs_x).triu(0)
# To allow only "stay" or "advance by exactly one state", use a band instead:
#     my_left_right_mask = torch.ones_like(probs_x).triu(0).tril(1)
probs_x = my_left_right_mask * probs_x
# Masking removes probability mass, so renormalise each row to sum to 1 again.
# The last row keeps only its diagonal entry, which makes it absorbing.
probs_x = probs_x / probs_x.sum(-1, keepdim=True)

Thank you very much!! It works!!
I’m really grateful to you!!