AttributeError: __enter__ in forecasting_iii

I am trying to adapt my data to Forecasting III: hierarchical models β€” Pyro Tutorials 1.8.4 documentation .

My data is salecount of various products in many stores .
I have reshape the szie to torch.Size([44, 103, 671, 1]) , means:
44 stores, 103 products , 671 days salecount .

This shape fit the input in tutorials . I change some plates name in tutorials:

class Model2(ForecastingModel):
    def model(self, zero_data, covariates):
        num_stores, num_products, duration, one = zero_data.shape

        # We construct plates once so we can reuse them later. We ensure they don't collide by
        # specifying different dim args for each: -3, -2, -1. Note the time_plate is dim=-1.
        stores_plate = pyro.plate("stores", num_stores, dim=-3)
        products_plate = pyro.plate("products", num_products, dim=-2)
        day_of_week_plate = pyro.plate("day_of_week", 7, dim=-1)

        # Let's model the time-dependent part with only O(num_stations * duration) 倍杂度 many
        # parameters, rather than the full possible O(num_stations ** 2 * duration) data size.
        drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))
        with stores_plate:
            with day_of_week_plate:
                stores_seasonal = pyro.sample("stores_seasonal", dist.Normal(0, 5))
        with products_plate:
            with day_of_week_plate:
                product_seasonal = pyro.sample("product_seasonal", dist.Normal(0, 5))
            with self.time_plate:
                with poutine.reparam(config={"drift": LocScaleReparam()}):
                    with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                        drift = pyro.sample("drift",
                                            dist.Stable(drift_stability, 0, drift_scale))
                        
        # δΈεŒι—¨εΊ— ε―Ή δΈεŒε•†ε“ ηš„ 偏ε₯½
        with stores_plate, products_plate:
            pairwise = pyro.sample("pairwise", dist.Normal(0, 1))

        # Outside of the time plate we can now form the prediction.
        seasonal = stores_seasonal + product_seasonal  # Note this broadcasts.
        seasonal = periodic_repeat(seasonal, duration, dim=-1)  # η»„θ£…ζ—Άι—΄ε‘¨ζœŸ
        motion = drift.cumsum(dim=-1)  # A Levy stable motion to model shocks.
        prediction = motion + seasonal + pairwise

        # We will decompose the noise scale parameter into
        # an origin-local and a destination-local component.
        with stores_plate:
            stores_plate = pyro.sample("origin_scale", dist.LogNormal(-5, 5))
        with products_plate:
            products_plate = pyro.sample("destin_scale", dist.LogNormal(-5, 5))
        scale = stores_plate + products_plate

        # At this point our prediction and scale have shape (50, 50, duration) and (50, 50, 1)
        # respectively, but we want them to have shape (50, 50, duration, 1) to satisfy the
        # Forecaster requirements.
        scale = scale.unsqueeze(-1)
        prediction = prediction.unsqueeze(-1)

        # Finally we construct a noise distribution and call the .predict() method.
        # Note that predict must be called inside the origin and destination plates.
        noise_dist = dist.Normal(0, scale)
        with stores_plate, products_plate:
            self.predict(noise_dist, prediction)
            
            
%%time
pyro.set_rng_seed(1)
pyro.clear_param_store()
covariates = torch.zeros(msc.size(-2), 0)  
forecaster = Forecaster(Model2(), msc, covariates,
                        learning_rate=0.1, learning_rate_decay=1, num_steps=501, log_every=50)
for name, value in forecaster.guide.median().items():
    if value.numel() == 1:
        print("{} = {:0.4g}".format(name, value.item()))

But I got error

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<timed exec> in <module>

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py in __init__(self, model, data, covariates, guide, init_loc_fn, init_scale, create_plates, optim, learning_rate, betas, learning_rate_decay, clip_norm, time_reparam, dct_gradients, subsample_aware, num_steps, num_particles, vectorize_particles, warm_start, log_every)
    287         elbo = Trace_ELBO(num_particles=num_particles,
    288                           vectorize_particles=vectorize_particles)
--> 289         elbo._guess_max_plate_nesting(model, guide, (data, covariates), {})
    290         elbo.max_plate_nesting = max(elbo.max_plate_nesting, 1)  # force a time plate
    291 

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/infer/elbo.py in _guess_max_plate_nesting(self, model, guide, args, kwargs)
     91         # Ignore validation to allow model-enumerated sites absent from the guide.
     92         with poutine.block():
---> 93             guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
     94             model_trace = poutine.trace(
     95                 poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    185         Calls this poutine and returns its trace instead of the function's return value.
    186         """
--> 187         self(*args, **kwargs)
    188         return self.msngr.get_trace()

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    163                                       args=args, kwargs=kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError) as e:
    167                 exc_type, exc_value, traceback = sys.exc_info()

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    411     def __call__(self, *args, **kwargs):
    412         with self._pyro_context:
--> 413             return super().__call__(*args, **kwargs)
    414 
    415     def __getattr__(self, name):

~/anaconda3/envs/dl/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py in forward(self, *args, **kwargs)
    483         # if we've never run the model before, do so now so we can inspect the model structure
    484         if self.prototype_trace is None:
--> 485             self._setup_prototype(*args, **kwargs)
    486 
    487         plates = self._create_plates(*args, **kwargs)

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py in _setup_prototype(self, *args, **kwargs)
    438 
    439     def _setup_prototype(self, *args, **kwargs):
--> 440         super()._setup_prototype(*args, **kwargs)
    441 
    442         self._event_dims = {}

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py in _setup_prototype(self, *args, **kwargs)
    156         # run the model so we can inspect its structure
    157         model = poutine.block(self.model, prototype_hide_fn)
--> 158         self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
    159         if self.master is not None:
    160             self.master()._check_prototype(self.prototype_trace)

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    185         Calls this poutine and returns its trace instead of the function's return value.
    186         """
--> 187         self(*args, **kwargs)
    188         return self.msngr.get_trace()

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    163                                       args=args, kwargs=kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError) as e:
    167                 exc_type, exc_value, traceback = sys.exc_info()

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    411     def __call__(self, *args, **kwargs):
    412         with self._pyro_context:
--> 413             return super().__call__(*args, **kwargs)
    414 
    415     def __getattr__(self, name):

~/anaconda3/envs/dl/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py in forward(self, data, covariates)
    172             self._forecast = None
    173 
--> 174             self.model(zero_data, covariates)
    175 
    176             assert self._forecast is not None, ".predict() was not called by .model()"

<ipython-input-202-35b921453dff> in model(self, zero_data, covariates)
     52         # Note that predict must be called inside the origin and destination plates.
     53         noise_dist = dist.Normal(0, scale)
---> 54         with stores_plate, products_plate:
     55             self.predict(noise_dist, prediction)

AttributeError: __enter__

pyro: 1.5.1

You’re getting this error because you are redefining the variables stores_plate and products_plate a few lines above to refer to scale samples rather than plates:

with stores_plate:
    # incorrect, change this to origin_scale = ...
    stores_plate = pyro.sample("origin_scale", dist.LogNormal(-5, 5))
1 Like

OH, my fault , thank you !