Hierarchical models

Data analysis and machine learning

In this class you will learn about hierarchical models. Hierarchical models are defined by having one or more parameters that influence many individual measurements. Imagine that you had many individual observations from different systems (e.g., galaxies, or people, or countries) and each of those measurements was very noisy. Now let us imagine that there is some underlying parameter (or rather, a population parameter) that influences each of the observations that you have made, but in slightly different ways. For example, each galaxy you observe might share a property with all other galaxies (e.g., a rotation curve profile) but each individual galaxy might also have its own offset or set of parameters.

What can we say of the population parameter from just looking at one observation? Not much. What if we can fit (or better yet, draw samples from the posterior of) the population parameter for one observation, and then repeat that for another observation, and then another? Can we build up knowledge about the population parameter then? In a sense we can, but in practice The Right Thing™ to do is to sample all of the observations at once, and infer both the population parameters and the individual parameters at the same time.

That's hierarchical modelling. Hierarchical modelling is great if

  1. you have individual observations of things;
  2. each thing might have its own unknown parameters, but there is also some common information or property about all those things (that you can describe with a set of population parameters);
  3. and it's particularly powerful if the individual measurements you have are very noisy!
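To build some intuition for why combining many noisy systems helps, here is a minimal simulation sketch. It is not a full hierarchical inference: it only averages one noisy observation per system. The values of `mu_true`, `tau`, and `sigma`, and the helper `population_mean_error`, are hypothetical and chosen purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical values chosen for illustration only.
mu_true = 5.0   # population parameter shared by all systems
tau = 1.0       # intrinsic scatter of each system about the population value
sigma = 3.0     # measurement noise on each individual observation


def population_mean_error(n_systems, n_trials=2000):
    """Scatter of the estimated population mean about the truth when
    n_systems noisy systems are combined (one observation per system)."""
    # Each system has its own parameter drawn from the population.
    theta = rng.normal(mu_true, tau, size=(n_trials, n_systems))
    # Each system is observed once, with noise.
    y = theta + rng.normal(0.0, sigma, size=theta.shape)
    # A simple estimate of the population parameter: the sample mean.
    estimates = y.mean(axis=1)
    return np.std(estimates - mu_true)


errors = {n: population_mean_error(n) for n in (1, 10, 100)}
for n, err in errors.items():
    print(f"{n:4d} systems: error on population mean ~ {err:.2f}")
```

Under these assumptions the error on the population parameter shrinks roughly as the square root of the number of systems, even though every individual observation stays just as noisy.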

Probabilistic graphical models

As stated in the syllabus: this sub-unit is designed such that the lectures will provide you breadth on a large number of topics. For this reason, in each class I will introduce you to new tools, techniques, methods, and language, which will help you at some point in the future. (Or otherwise there's no point!) The model we are trying to describe, a hierarchical model, can be described much more simply using a probabilistic graphical model. Graphical models are a concise way of representing a model (and data) that is interpretable by experts from many different fields! In other words, if you can draw a probabilistic graphical model then you can instantly show it to a statistician, a physicist, a machine learning expert, or a geologist with data analysis experience, et cetera, and they ought to understand how your model is structured.

Let's draw a probabilistic graphical model. First, the basics.

So our simplest straight line model from lesson one looks something like this:

Beautiful! (Thanks to daft.) And stunningly clear. From this figure we can see that we have model parameters \(m\) and \(b\), that we have \(N\) data points where the \(x\) values are fixed (to the true values \(\hat{x}_i\)) and the \(y\)-direction uncertainties \(\sigma_{y,i}\) are fixed, and the observed values \(\hat{y}_i\) depend only on \(\{\hat{x}_i,\sigma_{y,i},m,b\}\), indicating that any residuals away from whatever combination of \(m\) and \(b\) are used to produce \(y\) must be due to the uncertainty \(\sigma_{y,i}\). There is a remarkable amount of information presented here! Note that we haven't explicitly described how \(m\) and \(b\) produce \(y\), and we don't really care that much. In practice when presenting graphical models we are using them to demonstrate how everything fits together. We don't necessarily care whether you are multiplying \(m\) and \(b\) together, or taking \(m^b\), or something else equally crazy: all we care about is that you have some function that produces \(y\) given \(m\) and \(b\). We care more about what things are parameters, what things are fixed, what things are data, and what depends on what!

If you're interested, we produced this in Python using the following code:

    import daft
    from matplotlib import rc

    rc("font", family="serif", size=12)
    rc("text", usetex=True)

    # Straight line model.
    pgm = daft.PGM()
    pgm.add_node("obs", r"$\hat{y}_i$", 0, 0, observed=True)
    pgm.add_node("m", r"$m$", -0.5, 1)
    pgm.add_node("b", r"$b$", +0.5, 1)
    pgm.add_node("x", r"$\hat{x}_i$", -1, 0, fixed=True)
    pgm.add_node("y_sigma", r"$\sigma_{y,i}$", 1, 0, fixed=True)
    pgm.add_plate([-1.5, -0.5, 3, 1], label=r"$i=1,\ldots,N$")

    pgm.add_edge("m", "obs")
    pgm.add_edge("b", "obs")
    pgm.add_edge("x", "obs")
    pgm.add_edge("y_sigma", "obs")

    pgm.render()

You should consider drawing probabilistic graphical models for some of the models we consider in later classes! Drawing some will give you good intuition for how to view more complex models.

Hierarchical models

A hierarchical model could be defined by having just two levels, or multiple levels. Consider this example from The Supernova Cosmology Project:

or this example for inferring cosmic shear, where there are multiple overlapping sets and multiple hierarchies:

Hierarchical models of this kind can be extremely challenging to optimise and sample. Optimisation becomes difficult because there are some model parameters that will have zero derivatives (or zero impact) with respect to some data, and you can easily end up with a biased result. Sampling becomes difficult for similar reasons: some parameters have no influence on some parts of the data, but a significant effect on other parts. For relatively simple hierarchical models it is usually alright to still use a Hamiltonian Monte Carlo method, but for more complex hierarchical models you may want to look into Gibbs sampling.
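As a rough illustration of what Gibbs sampling looks like, here is a minimal sketch for a two-level normal model: each observation \(y_j\) is drawn about its own parameter \(\theta_j\) with known noise \(\sigma_j\), and each \(\theta_j\) is drawn about a population mean \(\mu\) with scatter \(\tau\). To keep the conditionals simple this sketch assumes \(\tau\) is known and puts a flat prior on \(\mu\); the function name `gibbs_two_level` and the toy data are hypothetical.

```python
import numpy as np


def gibbs_two_level(y, sigma, tau, n_steps=5000, seed=0):
    """Gibbs sampler for y_j ~ N(theta_j, sigma_j), theta_j ~ N(mu, tau),
    with a flat prior on mu and tau assumed known (a simplification)."""
    rng = np.random.default_rng(seed)
    y = np.asarray(y, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    J = y.size

    mu = y.mean()  # initial guess for the population mean
    mu_chain = np.empty(n_steps)
    theta_chain = np.empty((n_steps, J))

    # The conditional precision of each theta_j does not change between steps.
    precision = 1.0 / sigma**2 + 1.0 / tau**2

    for step in range(n_steps):
        # Draw every individual parameter given mu and the data.
        mean = (y / sigma**2 + mu / tau**2) / precision
        theta = rng.normal(mean, 1.0 / np.sqrt(precision))

        # Draw the population mean given the individual parameters.
        mu = rng.normal(theta.mean(), tau / np.sqrt(J))

        mu_chain[step] = mu
        theta_chain[step] = theta

    return mu_chain, theta_chain


# Toy data (hypothetical): 50 systems, each observed once.
rng = np.random.default_rng(123)
tau_known, noise = 1.0, 0.5
theta_true = rng.normal(2.0, tau_known, size=50)
y_obs = theta_true + rng.normal(0.0, noise, size=50)

mu_chain, theta_chain = gibbs_two_level(y_obs, np.full(50, noise), tau_known)
print("posterior mean of mu:", mu_chain[500:].mean())
```

Each step only ever draws from a simple conditional distribution, which is why Gibbs sampling can remain tractable when the joint posterior of a deep hierarchy is awkward to explore all at once.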

Example: radioactive decay

A situation has emerged where a company has been producing widgets and they have inadvertently produced radioactive material as part of their manufacturing process, and some of that radioactive material is contained in the widgets. Here is the timeline:

The company would like you to use the data from the investigation team and determine the decay rate of the (unknown) radioactive material, \(\alpha\).

Consider this as a hierarchical Bayesian problem, and imagine that the decay rate of the radioactive material is described by the parameter \(\alpha\) such that the amount of material in one widget at any time \(t\) is \[ N_{t} = N_0\exp\left(-\alpha\left[t - t_0\right]\right) \] where \(N_0\) is the amount of material in the widget when it was manufactured at time \(t_0\). The time \(t_1\) at which the investigation team estimated the amount of material in a widget is very well known, but there are many unknown parameters here. The decay rate \(\alpha\) is what we are most interested in inferring, but for an individual widget we don't know how much material was in the widget when it was manufactured (\(N_0\)), or the exact time that it was manufactured (\(t_0\)). And while there were 100 widgets manufactured, we only have a single estimate for each widget of the amount of radioactive material present at some later time.
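For concreteness, the decay law can be written as a small function. This is only a sketch: the names `amount_remaining` and `half_life` and the value of `alpha_example` are hypothetical, and the half-life relation \(\ln 2/\alpha\) is the standard consequence of exponential decay rather than something stated in the problem.

```python
import numpy as np


def amount_remaining(N_0, alpha, t_0, t):
    """Amount of material at time t, given N_0 at manufacture time t_0."""
    return N_0 * np.exp(-alpha * (t - t_0))


def half_life(alpha):
    """Time taken for the material to decay to half its initial amount."""
    return np.log(2) / alpha


seconds_per_day = 24 * 60 * 60
alpha_example = 5e-7  # hypothetical decay rate, per second

print("half-life in days:", half_life(alpha_example) / seconds_per_day)
print("left after one half-life:",
      amount_remaining(10.0, alpha_example, 0.0, half_life(alpha_example)))
```

Note that a single measurement of one widget cannot separate \(\alpha\) from \(N_0\) and \(t_0\): a larger initial amount with faster decay gives the same reading. It is only by combining many widgets that share \(\alpha\) that the decay rate becomes constrained.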

In a hierarchical framework you would consider \(\alpha\) to be a Level 0 parameter in that it ultimately affects all your observations. You would then have the set of parameters \(\{\vec{t_{0}}, \vec{N_0}\}\) for all 100 widgets, where we only have a noisy estimate of each parameter. Indeed, for the initial amount of material for the \(i\)-th widget, \[ N_{0,i} \sim \mathcal{U}\left(0, N_{0,max}\right) \] where \(N_{0,max}\) is reported by the company, and \[ t_{0,i} \sim \mathcal{U}\left(t_{0,min}, t_{0,max}\right) \] where \(t_{0,min}\) is the start of the day on which the \(i\)-th widget was manufactured, and \(t_{0,max}\) is the end of that day. And naturally, the estimates of the material at time \(t_1\) also have some uncertainty.

In the end our inference on \(\alpha\) must take all of this (model and observational) uncertainty into account!

Let's start by drawing a probabilistic graphical model for this problem.

Let's generate some faux data for this problem so we can better understand how well we can model it.

    import numpy as np
    import matplotlib.pyplot as plt

    np.random.seed(0)

    days_to_seconds = 24 * 60 * 60

    # The true decay rate.
    alpha = np.random.uniform() * 1e-6

    N_widgets = 100 # number of widgets
    N_initial_max = 10 # could be in masses, moles, etc.

    # The company made widgets for 35 days.
    manufacturing_time_span = 35 * days_to_seconds

    # Generate true values of N_initial and t_initial.
    N_initial_true = np.random.uniform(0, N_initial_max, size=(N_widgets, 1))
    t_initial_true = np.random.uniform(0, manufacturing_time_span, size=(N_widgets, 1))

    # The time delay between the company stopping to make widgets
    # and starting to measure them was 14 days, and the investigation took 90 days.
    t_delay = 14 * days_to_seconds
    measurement_time_span = 90 * days_to_seconds

    # The inspectors realise they can only measure the widgets once each.
    N_observations = 1

    # The time they measured is very precise though.
    t_measured = manufacturing_time_span \
               + t_delay \
               + np.random.uniform(0, measurement_time_span, size=(N_widgets, N_observations))

    # Sort in order.
    t_measured = np.sort(t_measured, axis=1)

    def N_now(N_initial, alpha, t_initial, t_now):
        return N_initial * np.exp(-alpha * (t_now - t_initial))

    N_now_true = N_now(N_initial_true, alpha, t_initial_true, t_measured)

    # There is some error in the measurements of material in each widget.
    v = 0.10
    sigma_N_measured = v * np.random.normal(0, 1, size=N_now_true.shape)
    N_measured = N_now_true + np.random.normal(0, 1, size=N_now_true.shape) * sigma_N_measured
    sigma_N_measured = np.abs(sigma_N_measured)

Now that we have generated some data, let's plot it to make sure things look sensible.

    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    # Plot the data first, make sure things look sensible!
    fig, ax = plt.subplots(figsize=(4, 4))
    for i in range(N_widgets):
        x = (t_measured[i] - t_initial_true[i])/days_to_seconds
        ax.errorbar(
            x,
            N_measured[i],
            xerr=0.5 * np.ones_like(x),
            # Use the uncertainty of the i-th widget (not always the first one).
            yerr=sigma_N_measured[i],
            fmt="o",
            c="k",
            lw=1,
        )

    ax.set_xlabel(r"$\Delta{}t$ / days")
    ax.set_ylabel(r"$N_1$ for each widget")
    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.yaxis.set_major_locator(MaxNLocator(6))
    fig.tight_layout()

OK, now let's build a model for the faux data. To do this we are going to use something called stan (and pystan, the Python interface to stan), which is a probabilistic programming language that allows us to easily perform inference on complex models. (Stan is not a solution for all data analysis problems. In fact the truth is quite the opposite: Stan is very well engineered for some data analysis problems, but it is probably not the right tool for most data analysis problems.)

Just like the straight line model from week one, it is a good idea to start with a simple model and then build up complexity. Here we will assume that the time of manufacture is well-known. Once we are happy that is working, we can add complexity. For convenience we will first create a file called stan_utils.py which contains the following code:

    import os
    import logging
    import pickle
    import pystan as stan

    def load_stan_model(path, cached_path=None, recompile=False, overwrite=True,
                        verbose=True):
        r"""
        Load a Stan model from a file. If a cached file exists, use it by default.

        :param path:
            The path of the Stan model.

        :param cached_path: [optional]
            The path of the cached Stan model. By default this will be the same as
            :path:, with a `.cached` extension appended.

        :param recompile: [optional]
            Recompile the model instead of using a cached version. If the cached
            version is different from the version in path, the model will be
            recompiled automatically.
        """

        cached_path = cached_path or "{}.cached".format(path)

        with open(path, "r") as fp:
            model_code = fp.read()

        while os.path.exists(cached_path) and not recompile:
            with open(cached_path, "rb") as fp:
                model = pickle.load(fp)

            if model.model_code != model_code:
                if verbose:
                    logging.warning("Cached model at {} differs from the code in {}; "\
                                    "recompiling model".format(cached_path, path))
                recompile = True
                continue

            else:
                if verbose:
                    logging.info("Using pre-compiled model from {}".format(cached_path))
                break

        else:
            # Only reached when no usable cached model was found.
            model = stan.StanModel(model_code=model_code)

            # Save the compiled model.
            if not os.path.exists(cached_path) or overwrite:
                with open(cached_path, "wb") as fp:
                    pickle.dump(model, fp)

        return model


    def sampling_kwds(**kwargs):
        r"""
        Prepare a dictionary that can be passed to Stan at the sampling stage.
        Basically this just prepares the initial positions so that they match the
        number of chains.
        """

        kwds = dict(chains=4)
        kwds.update(kwargs)

        if "init" in kwds:
            kwds["init"] = [kwds["init"]] * kwds["chains"]

        return kwds

And now let's write the stan model. Create a file called decay.stan and enter the following code:

    data {
      int<lower=1> N_widgets; // number of widgets

      // Time of manufacture.
      vector[N_widgets] t_initial;

      // Time of measurement.
      vector[N_widgets] t_measured;

      // Amount of material measured is uncertain.
      vector[N_widgets] N_measured;
      vector[N_widgets] sigma_N_measured;

      // Maximum amount of initial material.
      real N_initial_max;
    }

    parameters {
      // The decay rate parameter.
      real<lower=0> alpha;

      // The amount of initial material is not known.
      vector<lower=0, upper=N_initial_max>[N_widgets] N_initial;
    }

    model {
      for (i in 1:N_widgets) {
        N_measured[i] ~ normal(
          N_initial[i] * exp(-alpha * (t_measured[i] - t_initial[i])),
          sigma_N_measured[i]
        );
      }
    }

The code here looks very simple, where things are separated into data, parameters, and model blocks. Each block is self-explanatory, and the main parameterisation occurs where we write N_measured ~ normal(..., ...);. That is to say N_measured is drawn from a normal distribution where we provide the expected mean and the observed uncertainty sigma_N_measured. You can also see that we place a constraint on alpha that it must be positive, and that N_initial (for all widgets) must be between 0 and N_initial_max. That's it.
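If it helps to see what the model block encodes, here is a rough NumPy sketch of the same log-posterior, up to an additive constant. It assumes flat priors within the declared bounds, which is what Stan uses when no explicit prior is given for a bounded parameter. The function name `log_posterior` and the toy values are hypothetical, and this is not how Stan evaluates the model internally.

```python
import numpy as np


def log_posterior(alpha, N_initial, t_initial, t_measured, N_measured,
                  sigma_N_measured, N_initial_max):
    """Log-posterior (up to a constant) of the decay model, assuming flat
    priors inside the parameter bounds."""
    N_initial = np.asarray(N_initial, dtype=float)

    # Parameter bounds: alpha >= 0 and 0 <= N_initial <= N_initial_max.
    if alpha < 0 or np.any(N_initial < 0) or np.any(N_initial > N_initial_max):
        return -np.inf

    # Expected amount of material at the time of measurement.
    expected = N_initial * np.exp(-alpha * (t_measured - t_initial))

    # Normal likelihood for each widget.
    z = (N_measured - expected) / sigma_N_measured
    return -0.5 * np.sum(z**2) - np.sum(np.log(sigma_N_measured))


# Toy data for two widgets (hypothetical values).
alpha_toy = 5e-7
t0 = np.array([0.0, 1e5])
t1 = np.array([2e6, 3e6])
N0 = np.array([4.0, 8.0])
sigma_toy = np.array([0.1, 0.1])
N1 = N0 * np.exp(-alpha_toy * (t1 - t0))  # noiseless measurements

print(log_posterior(alpha_toy, N0, t0, t1, N1, sigma_toy, 10.0))
print(log_posterior(2 * alpha_toy, N0, t0, t1, N1, sigma_toy, 10.0))
```

The first value should be larger than the second, since the data were generated with `alpha_toy`.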

How do we perform inference with this model? Like this:

    import stan_utils as stan

    model = stan.load_stan_model("decay.stan")

    # Data. The Stan model declares one-dimensional vectors, so flatten the
    # (N_widgets, 1) arrays before passing them in.
    data_dict = dict(
        N_widgets=N_widgets,
        t_initial=t_initial_true.flatten(),
        t_measured=t_measured.flatten(),
        N_measured=N_measured.flatten(),
        sigma_N_measured=sigma_N_measured.flatten(),
        N_initial_max=N_initial_max,
    )

    # Here we will cheat by giving the true value of alpha as the initialisation.
    init_dict = dict(alpha=alpha)

    # Run optimisation.
    opt_stan = model.optimizing(
        data=data_dict,
        init=init_dict
    )

    # Run sampling.
    samples = model.sampling(**stan.sampling_kwds(
        chains=2,
        iter=2000,
        data=data_dict,
        init=opt_stan
    ))

By default stan uses the L-BFGS optimisation algorithm (with derivatives computed for you by automatic differentiation!) and a dynamic Hamiltonian Monte Carlo sampler where the scale length of the momentum vector is automatically tuned. You should now have samples for all of your parameters (do print(samples) to get a summary), and you can plot the trace and marginalised posterior of individual parameters:

    fig = samples.traceplot(("alpha", ))

    # Draw true value for comparison.
    fig.axes[0].axvline(alpha, c="#666666", ls=":", zorder=-1, lw=1)

That looks pretty good! But we also have posteriors on the initial mass of radioactive material for every single widget!

    fig = samples.traceplot(("N_initial", ))

Here you can see that the initial mass is well constrained for a few widgets, but for many of them the posteriors are totally uninformative. Note that the posteriors you are seeing here are marginalised over \(\alpha\)! So even though we have 100 widgets, with very uncertain information about all of them, we can make inferences on the hyper-parameters (or population parameters) that they all share. That's where hierarchical modelling is most powerful: where you have individual noisy estimates of things that share common information!
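If you want to quote a number for a population parameter rather than look at a trace plot, one common choice is the median and the 16th and 84th percentiles of the draws. The sketch below uses stand-in draws so that it runs on its own; the helper `summarise_draws` and the stand-in values are hypothetical.

```python
import numpy as np


def summarise_draws(draws):
    """Return the median, and the offsets down to the 16th percentile
    and up to the 84th percentile, of a set of posterior draws."""
    draws = np.asarray(draws, dtype=float).ravel()
    lower, median, upper = np.percentile(draws, [16, 50, 84])
    return median, median - lower, upper - median


# Stand-in draws (hypothetical). With PyStan 2 the real draws would probably
# come from something like samples.extract("alpha")["alpha"].
rng = np.random.default_rng(1)
fake_alpha_draws = rng.normal(5.5e-7, 2e-8, size=4000)

median, minus, plus = summarise_draws(fake_alpha_draws)
print(f"alpha = {median:.3e} (+{plus:.1e}, -{minus:.1e}) per second")
```

The same helper can be applied to each widget's initial mass to see which widgets are well constrained and which are not.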

Example: school effects

The parallel experiments in schools example is a common one used to introduce hierarchical models. (I think this example first appeared in Bayesian Data Analysis but it may have originated elsewhere.)

From Bayesian Data Analysis:

A study was performed for the Educational Testing Service to analyse the effects of special coaching programs on test scores. Separate randomised experiments were performed to estimate the effects of coaching programs for the SAT-V (Scholastic Aptitude Test-Verbal) in each of eight high schools. The outcome variable in each study was the score on a special administration of the SAT-V, a standardized multiple choice test administered by the Educational Testing Service and used to help colleges make admissions decisions; the scores can vary between 200 and 800, with mean about 500 and standard deviation about 100. The SAT examinations are designed to be resistant to short-term efforts directed specifically toward improving performance on the test; instead they are designed to reflect knowledge acquired and abilities developed over many years of education. Nevertheless, each of the eight schools in this study considered its short-term coaching program to be successful at increasing SAT scores. Also, there was no prior reason to believe that any of the eight programs was more effective than any other or that some were more similar in effect to each other than to any other.

The results of the experiments are summarised in the table below. All students in the experiments had already taken the PSAT (Preliminary SAT), and allowance was made for differences in the PSAT-M (Mathematics) and PSAT-V test scores between coached and uncoached students. In particular, in each school the estimated coaching effect and its standard error were estimated by an analysis of covariance adjustment (that is, a linear regression was performed of SAT-V on treatment group, using PSAT-M and PSAT-V as control variables) appropriate for a completely randomised experiment. A separate regression was estimated for each school. Although not simple sample means (because of the covariance adjustments), the estimated coaching effects are labelled \(y_j\), and their sampling variances are \(\sigma_j^2\). The estimates \(y_j\) are obtained by independent experiments and have approximately normal sampling distributions with sampling variances that are known, for all practical purposes, because the sample sizes in all eight experiments were relatively large, over thirty students in each school. Incidentally, an increase of eight points on the SAT-V corresponds to about one more test item correct.

School    Estimated treatment    Standard error of
          effect, \(y_j\)        effect estimate, \(\sigma_j\)
A          28                    15
B           8                    10
C          -3                    16
D           7                    11
E          -1                     9
F           1                    11
G          18                    10
H          12                    18

How would you evaluate the efficacy of the treatment? How would you model it? What parameters would you introduce? Can you draw your graphical model?
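As one possible starting point (not an answer to the questions above), it can help to compute the two extremes that a hierarchical model sits between: treating every school separately (no pooling, which is just the table itself) and assuming all schools share one common effect (complete pooling). The sketch below computes the completely pooled estimate as an inverse-variance weighted mean; the variable names are hypothetical.

```python
import numpy as np

# Estimated effects and standard errors for schools A to H, from the table.
y = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=float)
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=float)

# Complete pooling: assume every school has the same true effect, so the
# best estimate is the inverse-variance weighted mean of the y_j.
weights = 1.0 / sigma**2
pooled_mean = np.sum(weights * y) / np.sum(weights)
pooled_se = 1.0 / np.sqrt(np.sum(weights))

print(f"completely pooled effect: {pooled_mean:.1f} +/- {pooled_se:.1f}")
```

Neither extreme is obviously right: the separate estimates are very noisy, while complete pooling insists that no school differs from any other. A hierarchical model lets the data decide how much the schools should share.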

 


Contributions

The visualisations shown here come from the Distill article A Visual Exploration of Gaussian Processes by Jochen Görtler, Rebecca Kehlbeck, and Oliver Deussen.

The code for the CO2 example comes from the george documentation written by Dan Foreman-Mackey (Flatiron).