week-05.ipynb
{"cells":[{"cell_type":"code","execution_count":1,"source":["# INFO: IPython extension for auto reloading modified custom packages\"\"\"\n","%load_ext autoreload\n","%autoreload 2\n","\n","\n","# INFO: IPython extension for package/system spec output\n","%load_ext watermark\n","\n","\n","# INFO: Core imports for practicaly any data science activity\n","import numpy as np\n","import pandas as pd\n","\n","\n","# INFO: Customize settings for Pandas\n","pd.options.display.max_columns = 500\n","pd.options.display.max_rows = 500\n","pd.options.display.max_colwidth = 500\n","\n","# INFO: to display dataframes as tables on call\n","from IPython.display import display\n","\n","\n","# INFO: Plotting setup (matplotlib is only for compatibility with legacy code)\n","# import matplotlib.pyplot as plt\n","%matplotlib inline\n","import plotly.io as pio\n","import plotly.express as px\n","import plotly.graph_objects as go\n","\n","\n","# INFO: Customize plotting backend for Pandas (matplotlib for compat)\n","pd.options.plotting.backend = \"plotly\"\n","# pd.options.plotting.backend = \"matplotlib\"\n","\n","\n","# INFO: Customize Plotly theme\n","pio.templates.default = \"plotly_dark\"\n","# pio.templates.default = \"plotly\"\n","\n","\n","# INFO: Logging setup (replaces 'print' in development & seamlessly transitions to production code)\n","import logging\n","import sys\n","\n","root = logging.getLogger()\n","root.setLevel(logging.INFO)\n","\n","handler = logging.StreamHandler(sys.stdout)\n","handler.setLevel(logging.INFO)\n","formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n","handler.setFormatter(formatter)\n","root.addHandler(handler)\n","\n","# INFO: Use logging like this:\n","logging.info(\"Logging is set!\")\n","\n","\n","# INFO: Call for package/system spec output\n","%watermark --iversions"],"outputs":[{"output_type":"stream","name":"stdout","text":["The autoreload extension is already loaded. To reload it, use:\n"," %reload_ext autoreload\n","2022-04-25 23:58:24,713 - root - INFO - Logging is set!\n","matplotlib: 3.5.1\n","logging : 0.5.1.2\n","numpy : 1.22.3\n","plotly : 5.6.0\n","pandas : 1.4.1\n","sys : 3.10.4 (main, Mar 25 2022, 00:00:00) [GCC 11.2.1 20220127 (Red Hat 11.2.1-9)]\n","\n"]}],"metadata":{}},{"cell_type":"code","execution_count":2,"source":["# INFO: PPL specific imports\n","\n","import jax.numpy as jnp\n","from jax import random\n","\n","import numpyro\n","import numpyro.distributions as dist\n","import numpyro.optim as optim\n","from numpyro.infer import SVI, Trace_ELBO, Predictive\n","from numpyro.infer import MCMC, NUTS\n","from numpyro.infer.autoguide import AutoLaplaceApproximation, AutoNormal\n","\n","from jax import lax, random\n","from jax.scipy.special import expit\n","\n","import arviz as az"],"outputs":[],"metadata":{}},{"cell_type":"code","execution_count":3,"source":["data_uri = \"https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/NWOGrants.csv\"\n","df_dev = pd.read_csv(data_uri, sep=\";\")\n","df_dev.head()\n","df_dev[\"gender\"] = df_dev[\"gender\"] == \"m\"\n","df_dev[\"gender\"] = df_dev[\"gender\"].astype(int)"],"outputs":[],"metadata":{}},{"cell_type":"code","execution_count":4,"source":["# INFO: total effect for starters; DAG synced with answers (coz my personal DAG is different)\n","\n","\n","def model(data: pd.DataFrame, observed=True):\n"," applications = data[\"applications\"].values\n"," awards = data[\"awards\"].values\n","\n"," discipline = data[\"discipline\"].values\n"," discipline_card = np.unique(discipline).shape[0]\n"," gender = data[\"gender\"].values\n"," gender_card = np.unique(gender).shape[0]\n","\n"," alpha_gender = numpyro.sample(\"alpha_gender\", dist.Normal(-1, 1).expand([gender_card]))\n"," logit_p = numpyro.deterministic(\"logit_p\", alpha_gender[gender_card])\n","\n"," # alpha_gender = numpyro.sample(\"alpha_gender\", dist.Normal(-1, 1).expand([gender_card, discipline_card]))\n"," # logit_p = numpyro.deterministic(\"logit_p\", alpha_gender[gender_card, discipline_card])\n","\n"," numpyro.sample(\"awards\", dist.Binomial(total_count=applications, logits=logit_p), obs=awards if observed else None)\n"," # numpyro.sample(\"awards\", dist.Binomial(total_count=applications, probs=logit_p), obs=awards if observed else None)\n"," # numpyro.sample(\n"," # \"awards\", dist.Binomial(total_count=applications, logits=expit(logit_p)), obs=awards if observed else None\n"," # )\n","\n"," # dist.Binomial(applications, logits=logits)\n","\n"," # with numpyro.plate(\"applications\", applications):\n"," # alpha_gender = numpyro.sample(\"alpha_gender\", dist.Normal(-1, 1).expand([gender_card]))\n"," # logit_p = numpyro.deterministic(\"logit_p\", alpha_gender[gender_card])\n","\n"," # numpyro.sample(\"obs\", dist.Bernoulli(logit_p), obs=data)\n","\n","\n","kernel = NUTS(model)\n","mcmc = MCMC(\n"," kernel,\n"," # num_warmup=500,\n"," num_warmup=1000,\n"," # num_warmup=2000,\n"," # num_samples=2000,\n"," num_samples=5000,\n"," # num_samples=10_000,\n"," num_chains=1,\n"," # num_chains=4,\n"," progress_bar=True,\n",")\n","mcmc.run(random.PRNGKey(0), df_dev)\n","samples = mcmc.get_samples()\n","\n","numpyro.diagnostics.print_summary(samples, prob=0.89, group_by_chain=False)"],"outputs":[{"output_type":"stream","name":"stdout","text":["2022-04-25 23:58:57,453 - absl - INFO - Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n","2022-04-25 23:58:57,454 - absl - INFO - Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: Host Interpreter\n","2022-04-25 23:58:57,455 - absl - INFO - Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n"]},{"output_type":"stream","name":"stderr","text":["sample: 100%|██████████| 6000/6000 [00:10<00:00, 584.37it/s, 1023 steps of size 1.63e-03. acc. prob=0.84]\n"]},{"output_type":"stream","name":"stdout","text":["\n"," mean std median 5.5% 94.5% n_eff r_hat\n","alpha_gender[0] -0.98 0.99 -0.99 -2.50 0.62 952.90 1.00\n","alpha_gender[1] -1.62 0.05 -1.62 -1.71 -1.54 586.28 1.00\n"," logit_p -1.62 0.05 -1.62 -1.71 -1.54 586.28 1.00\n","\n"]}],"metadata":{}},{"cell_type":"code","execution_count":5,"source":["az.plot_trace(az.from_numpyro(mcmc))"],"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([[<AxesSubplot:title={'center':'alpha_gender'}>,\n"," <AxesSubplot:title={'center':'alpha_gender'}>],\n"," [<AxesSubplot:title={'center':'logit_p'}>,\n"," <AxesSubplot:title={'center':'logit_p'}>]], dtype=object)"]},"metadata":{},"execution_count":5},{"output_type":"display_data","data":{"image/png":"","text/plain":["<Figure size 864x288 with 4 Axes>"]},"metadata":{"needs_background":"dark"}}],"metadata":{}}],"nbformat":4,"nbformat_minor":2,"metadata":{"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":3},"orig_nbformat":4}}