I am trying to adapt my data to Forecasting III: hierarchical models β Pyro Tutorials 1.8.4 documentation .
My data is salecount of various products in many stores .
I have reshape the szie to torch.Size([44, 103, 671, 1])
, means:
44 stores, 103 products , 671 days salecount .
This shape fit the input in tutorials . I change some plates name in tutorials:
class Model2(ForecastingModel):
def model(self, zero_data, covariates):
num_stores, num_products, duration, one = zero_data.shape
# We construct plates once so we can reuse them later. We ensure they don't collide by
# specifying different dim args for each: -3, -2, -1. Note the time_plate is dim=-1.
stores_plate = pyro.plate("stores", num_stores, dim=-3)
products_plate = pyro.plate("products", num_products, dim=-2)
day_of_week_plate = pyro.plate("day_of_week", 7, dim=-1)
# Let's model the time-dependent part with only O(num_stations * duration) ε€ζεΊ¦ many
# parameters, rather than the full possible O(num_stations ** 2 * duration) data size.
drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))
with stores_plate:
with day_of_week_plate:
stores_seasonal = pyro.sample("stores_seasonal", dist.Normal(0, 5))
with products_plate:
with day_of_week_plate:
product_seasonal = pyro.sample("product_seasonal", dist.Normal(0, 5))
with self.time_plate:
with poutine.reparam(config={"drift": LocScaleReparam()}):
with poutine.reparam(config={"drift": SymmetricStableReparam()}):
drift = pyro.sample("drift",
dist.Stable(drift_stability, 0, drift_scale))
# δΈει¨εΊ ε―Ή δΈεεε η εε₯½
with stores_plate, products_plate:
pairwise = pyro.sample("pairwise", dist.Normal(0, 1))
# Outside of the time plate we can now form the prediction.
seasonal = stores_seasonal + product_seasonal # Note this broadcasts.
seasonal = periodic_repeat(seasonal, duration, dim=-1) # η»θ£
ζΆι΄ε¨ζ
motion = drift.cumsum(dim=-1) # A Levy stable motion to model shocks.
prediction = motion + seasonal + pairwise
# We will decompose the noise scale parameter into
# an origin-local and a destination-local component.
with stores_plate:
stores_plate = pyro.sample("origin_scale", dist.LogNormal(-5, 5))
with products_plate:
products_plate = pyro.sample("destin_scale", dist.LogNormal(-5, 5))
scale = stores_plate + products_plate
# At this point our prediction and scale have shape (50, 50, duration) and (50, 50, 1)
# respectively, but we want them to have shape (50, 50, duration, 1) to satisfy the
# Forecaster requirements.
scale = scale.unsqueeze(-1)
prediction = prediction.unsqueeze(-1)
# Finally we construct a noise distribution and call the .predict() method.
# Note that predict must be called inside the origin and destination plates.
noise_dist = dist.Normal(0, scale)
with stores_plate, products_plate:
self.predict(noise_dist, prediction)
%%time
pyro.set_rng_seed(1)
pyro.clear_param_store()
covariates = torch.zeros(msc.size(-2), 0)
forecaster = Forecaster(Model2(), msc, covariates,
learning_rate=0.1, learning_rate_decay=1, num_steps=501, log_every=50)
for name, value in forecaster.guide.median().items():
if value.numel() == 1:
print("{} = {:0.4g}".format(name, value.item()))
But I got error
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<timed exec> in <module>
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py in __init__(self, model, data, covariates, guide, init_loc_fn, init_scale, create_plates, optim, learning_rate, betas, learning_rate_decay, clip_norm, time_reparam, dct_gradients, subsample_aware, num_steps, num_particles, vectorize_particles, warm_start, log_every)
287 elbo = Trace_ELBO(num_particles=num_particles,
288 vectorize_particles=vectorize_particles)
--> 289 elbo._guess_max_plate_nesting(model, guide, (data, covariates), {})
290 elbo.max_plate_nesting = max(elbo.max_plate_nesting, 1) # force a time plate
291
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/infer/elbo.py in _guess_max_plate_nesting(self, model, guide, args, kwargs)
91 # Ignore validation to allow model-enumerated sites absent from the guide.
92 with poutine.block():
---> 93 guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
94 model_trace = poutine.trace(
95 poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
185 Calls this poutine and returns its trace instead of the function's return value.
186 """
--> 187 self(*args, **kwargs)
188 return self.msngr.get_trace()
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
163 args=args, kwargs=kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError) as e:
167 exc_type, exc_value, traceback = sys.exc_info()
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
411 def __call__(self, *args, **kwargs):
412 with self._pyro_context:
--> 413 return super().__call__(*args, **kwargs)
414
415 def __getattr__(self, name):
~/anaconda3/envs/dl/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py in forward(self, *args, **kwargs)
483 # if we've never run the model before, do so now so we can inspect the model structure
484 if self.prototype_trace is None:
--> 485 self._setup_prototype(*args, **kwargs)
486
487 plates = self._create_plates(*args, **kwargs)
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py in _setup_prototype(self, *args, **kwargs)
438
439 def _setup_prototype(self, *args, **kwargs):
--> 440 super()._setup_prototype(*args, **kwargs)
441
442 self._event_dims = {}
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py in _setup_prototype(self, *args, **kwargs)
156 # run the model so we can inspect its structure
157 model = poutine.block(self.model, prototype_hide_fn)
--> 158 self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
159 if self.master is not None:
160 self.master()._check_prototype(self.prototype_trace)
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
185 Calls this poutine and returns its trace instead of the function's return value.
186 """
--> 187 self(*args, **kwargs)
188 return self.msngr.get_trace()
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
163 args=args, kwargs=kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError) as e:
167 exc_type, exc_value, traceback = sys.exc_info()
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
411 def __call__(self, *args, **kwargs):
412 with self._pyro_context:
--> 413 return super().__call__(*args, **kwargs)
414
415 def __getattr__(self, name):
~/anaconda3/envs/dl/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py in forward(self, data, covariates)
172 self._forecast = None
173
--> 174 self.model(zero_data, covariates)
175
176 assert self._forecast is not None, ".predict() was not called by .model()"
<ipython-input-202-35b921453dff> in model(self, zero_data, covariates)
52 # Note that predict must be called inside the origin and destination plates.
53 noise_dist = dist.Normal(0, scale)
---> 54 with stores_plate, products_plate:
55 self.predict(noise_dist, prediction)
AttributeError: __enter__
pyro: 1.5.1