AttributeError: __enter__ in forecasting_iii

I am trying to adapt my data to Forecasting III: hierarchical models β€” Pyro Tutorials 1.8.4 documentation .

My data is salecount of various products in many stores .
I have reshape the szie to torch.Size([44, 103, 671, 1]) , means:
44 stores, 103 products , 671 days salecount .

This shape fit the input in tutorials . I change some plates name in tutorials:

class Model2(ForecastingModel):
    def model(self, zero_data, covariates):
        num_stores, num_products, duration, one = zero_data.shape

        # We construct plates once so we can reuse them later. We ensure they don't collide by
        # specifying different dim args for each: -3, -2, -1. Note the time_plate is dim=-1.
        stores_plate = pyro.plate("stores", num_stores, dim=-3)
        products_plate = pyro.plate("products", num_products, dim=-2)
        day_of_week_plate = pyro.plate("day_of_week", 7, dim=-1)

        # Let's model the time-dependent part with only O(num_stations * duration) 倍杂度 many
        # parameters, rather than the full possible O(num_stations ** 2 * duration) data size.
        drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))
        with stores_plate:
            with day_of_week_plate:
                stores_seasonal = pyro.sample("stores_seasonal", dist.Normal(0, 5))
        with products_plate:
            with day_of_week_plate:
                product_seasonal = pyro.sample("product_seasonal", dist.Normal(0, 5))
            with self.time_plate:
                with poutine.reparam(config={"drift": LocScaleReparam()}):
                    with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                        drift = pyro.sample("drift",
                                            dist.Stable(drift_stability, 0, drift_scale))
        # δΈεŒι—¨εΊ— ε―Ή δΈεŒε•†ε“ ηš„ 偏ε₯½
        with stores_plate, products_plate:
            pairwise = pyro.sample("pairwise", dist.Normal(0, 1))

        # Outside of the time plate we can now form the prediction.
        seasonal = stores_seasonal + product_seasonal  # Note this broadcasts.
        seasonal = periodic_repeat(seasonal, duration, dim=-1)  # η»„θ£…ζ—Άι—΄ε‘¨ζœŸ
        motion = drift.cumsum(dim=-1)  # A Levy stable motion to model shocks.
        prediction = motion + seasonal + pairwise

        # We will decompose the noise scale parameter into
        # an origin-local and a destination-local component.
        with stores_plate:
            stores_plate = pyro.sample("origin_scale", dist.LogNormal(-5, 5))
        with products_plate:
            products_plate = pyro.sample("destin_scale", dist.LogNormal(-5, 5))
        scale = stores_plate + products_plate

        # At this point our prediction and scale have shape (50, 50, duration) and (50, 50, 1)
        # respectively, but we want them to have shape (50, 50, duration, 1) to satisfy the
        # Forecaster requirements.
        scale = scale.unsqueeze(-1)
        prediction = prediction.unsqueeze(-1)

        # Finally we construct a noise distribution and call the .predict() method.
        # Note that predict must be called inside the origin and destination plates.
        noise_dist = dist.Normal(0, scale)
        with stores_plate, products_plate:
            self.predict(noise_dist, prediction)
covariates = torch.zeros(msc.size(-2), 0)  
forecaster = Forecaster(Model2(), msc, covariates,
                        learning_rate=0.1, learning_rate_decay=1, num_steps=501, log_every=50)
for name, value in
    if value.numel() == 1:
        print("{} = {:0.4g}".format(name, value.item()))

But I got error

AttributeError                            Traceback (most recent call last)
<timed exec> in <module>

pyro: 1.5.1

You’re getting this error because you are redefining the variables stores_plate and products_plate a few lines above to refer to scale samples rather than plates:

with stores_plate:
    # incorrect, change this to origin_scale = ...
    stores_plate = pyro.sample("origin_scale", dist.LogNormal(-5, 5))
OH, my fault , thank you !