Reinforcement Learning as Variational Inference in Pyro

Dear Pyro Developers,

I have implemented two classic RL algorithms (REINFORCE and Actor-Critic) in Pyro, where the learning task is ultimately solved by Pyro’s built-in SVI inference engine. I was wondering whether you would be interested in turning this into a Pyro tutorial.

Abstract of my Technical Report

The main focus of this project is to empirically evaluate Merlin (Levine, 2018), Virel (Fellows et al., 2019), and the Pyro (Bingham et al., 2019) probabilistic programming language (PPL). Merlin and Virel are two frameworks aimed at solving the problem of maximum entropy reinforcement learning (MERL) via variational inference (VI). Levine (2018) and Fellows et al. (2019) have already shown that reinforcement learning (RL) tasks can in general be reformulated as probabilistic inference tasks, which in principle allows us to bring to bear a wide array of approximate inference tools. Deep universal probabilistic programming languages, whose inference engines have stochastic variational inference (SVI) algorithms (Wingate and Weber, 2013) built in, could solve such inference problems; however, they had not yet been applied to this setting.
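For reference, the MERL objective that both frameworks target augments the expected return with a policy-entropy term; as I understand it from Levine (2018), with a temperature α it can be written as

J(\pi) \;=\; \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}\!\left[\, r(s_t, a_t) \;+\; \alpha\, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \,\right]

and maximizing this objective is equivalent to a variational inference problem, which is what makes an SVI engine applicable.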

I developed rl.Pyro as a proof of concept. It implements two of the most classic policy-based reinforcement learning algorithms, namely REINFORCE (Thomas and Brunskill, 2017; Shi et al., 2019) and Actor-Critic (Mnih et al., 2016; Haarnoja et al., 2018), in Pyro, a deep universal PPL implemented in Python with PyTorch (Paszke et al., 2019) as its backend. I empirically evaluated these algorithms on the CartPole environment from the OpenAI Gym benchmark (Brockman et al., 2016).

The experimental results validated the theoretical connection between MERL and VI. More importantly, they confirmed that (1) Pyro is expressive enough to implement policy-based RL algorithms, (2) the performance of the Pyro versions of the algorithms is satisfactory, and (3) modeling and training are better decoupled in Pyro.

Code Fragment (REINFORCE in Pyro)

def guide(env=None, trajectory=None):
    # Register policy_net's parameters with Pyro so that SVI can optimize them.
    pyro.module("policy_net", policy_net)
    S, A, R, D, step = [], [], [], [], 0

    obs = env.reset()
    done = False
    while not done:
        S.append(obs)
        D.append(done)
        action = pyro.sample(
            f"action_{step}",
            pyro.distributions.Categorical(
                # policy_net expects a tensor, not the raw NumPy observation
                policy_net(torch.as_tensor(obs, dtype=torch.float32))
            )
        ).item()
        obs, reward, done, _ = env.step(action)
        A.append(action)
        R.append(reward)
        step += 1
    S.append(obs)
    D.append(done)

    # send the trajectory to the model program
    trajectory["S"] = S
    trajectory["A"] = A
    trajectory["R"] = R
    trajectory["D"] = D
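For context, the guide above assumes a policy_net that maps an observation to a vector of action probabilities. Its architecture is not part of the fragment; a minimal hypothetical version (the sizes OBS_N, HIDDEN, and ACT_N are placeholders, with CartPole's dimensions as defaults) could look like:

```python
import torch
import torch.nn as nn

OBS_N, HIDDEN, ACT_N = 4, 64, 2  # CartPole: 4-dim observation, 2 actions

# A small MLP whose softmax output parameterizes the Categorical
# distribution sampled at each "action_{step}" site in the guide.
policy_net = nn.Sequential(
    nn.Linear(OBS_N, HIDDEN),
    nn.ReLU(),
    nn.Linear(HIDDEN, ACT_N),
    nn.Softmax(dim=-1),
)
```

Any network with a probability-vector output would do; this one is only a sketch to make the fragment self-contained.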

def model(env=None, trajectory=None):
    S, R = trajectory["S"], trajectory["R"]
    for step in pyro.plate("trajectory", len(R)):
        # Uniform prior over actions; the factor statements below then
        # up-weight trajectories in proportion to their scaled rewards.
        action = pyro.sample(
            f"action_{step}",
            pyro.distributions.Categorical(
                torch.ones(ACT_N) / ACT_N
            )
        )
        pyro.factor(
            f"discount_{step}",
            # torch.log requires a tensor, so convert the Python float first
            torch.as_tensor(GAMMA).log()
        )
        pyro.factor(
            f"reward_{step}",
            torch.as_tensor(R[step] / TEMPERATURE)
        )
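To make the model's scoring explicit: each step contributes two pyro.factor terms to the model's log-density, log γ and r_t/T, so a length-T trajectory is up-weighted by γ^T · exp(Σ_t r_t / T) relative to the uniform prior. A small sketch of that accumulation (pure Python, independent of Pyro):

```python
import math

def model_log_factor(rewards, gamma=0.99, temperature=1.0):
    """Sum of the per-step factor terms in the model:
    log(gamma) from `discount_{step}` plus r_t / temperature
    from `reward_{step}`, for each step t."""
    return sum(math.log(gamma) + r / temperature for r in rewards)
```

Maximizing the ELBO therefore pushes the guide's policy toward trajectories with a large value of this quantity, while the entropy of the guide's Categorical samples supplies the MERL entropy bonus.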

def train():
    adam = pyro.optim.Adam({"lr": LEARNING_RATE})
    svi = pyro.infer.SVI(
        model, guide, adam,
        loss=pyro.infer.Trace_ELBO()
    )
    pyro.clear_param_store()
    for episode in range(EPISODES):
        # SVI runs the guide and then the model with the same arguments,
        # so the trajectory dict filled in by the guide is read by the model.
        svi.step(env, trajectory={})
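The fragment references several hyperparameters whose definitions are not shown; the names below mirror the code, but the values are illustrative placeholders only, not the ones used in my experiments:

```python
GAMMA = 0.99          # discount factor used by the `discount_{step}` factors
TEMPERATURE = 1.0     # reward scaling in the `reward_{step}` factors
LEARNING_RATE = 1e-3  # Adam step size
EPISODES = 1000       # number of svi.step() calls, one trajectory each
ACT_N = 2             # action-space size (2 for CartPole)
```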