
The stage package implements the STAGE model (State Transition Analysis via Generative Estimator), a Bayesian generative classifier for estimating the transition point (m50) between two groups: the value of x at which P(y = 1 | x) reaches 0.5. An example is the length-at-maturity problem, where y is binary maturity status and x is a continuous measurement such as length.

Existing methods for these sorts of applications have important limitations (see below), which the STAGE model overcomes.

Features of the STAGE model are:

  • Each class is fit with a mixture distribution that combines a uniform plateau with a Gaussian tail
  • Estimation of the transition point (m50) is based on the relative densities (Bayes’ rule) of the classes given x, not discriminative regression
  • A shared transition region centred at m50 with width d
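
To make the generative idea concrete, the sketch below builds two plateau-plus-Gaussian-tail class densities and applies Bayes' rule to them. It is a minimal base R illustration, not the package's Stan implementation. The function names (class_density, posterior_mature), the placement of the plateau edges at centre ± d/2, the tail spread sigma, and the equal class priors are all assumptions made for this example.

```r
# Illustrative only: hypothetical helpers, not part of the stage package.

# Density that is flat on a plateau and decays as a half-Gaussian past `edge`.
#   side = "lower": plateau on [L, edge], tail decays above the edge
#   side = "upper": plateau on [edge, U], tail decays below the edge
class_density <- function(x, edge, sigma, L, U, side = c("lower", "upper")) {
  side <- match.arg(side)
  past_edge <- if (side == "lower") pmax(x - edge, 0) else pmax(edge - x, 0)
  kernel <- exp(-0.5 * (past_edge / sigma)^2)  # equals 1 on the plateau
  plateau_width <- if (side == "lower") edge - L else U - edge
  tail_length <- (U - L) - plateau_width
  # Area under the half-Gaussian tail, truncated at the far bound
  tail_area <- sigma * sqrt(2 * pi) * (pnorm(tail_length / sigma) - 0.5)
  ifelse(x < L | x > U, 0, kernel / (plateau_width + tail_area))
}

# Bayes' rule: P(y = 1 | x) from the two class densities and a prior on y = 1.
# Assumed layout: class 0 plateau ends at centre - d/2,
#                 class 1 plateau starts at centre + d/2.
posterior_mature <- function(x, centre, d, sigma, L, U, prior1 = 0.5) {
  f0 <- class_density(x, centre - d / 2, sigma, L, U, side = "lower")
  f1 <- class_density(x, centre + d / 2, sigma, L, U, side = "upper")
  prior1 * f1 / (prior1 * f1 + (1 - prior1) * f0)
}

# Locate the x at which the posterior crosses 0.5
crossing <- uniroot(
  function(x) posterior_mature(x, centre = 1250, d = 100, sigma = 50,
                               L = 1000, U = 1500) - 0.5,
  interval = c(1000, 1500), tol = 1e-8
)$root
crossing
```

In this symmetric configuration (equal priors, equal tail spreads, centre midway between L and U) the crossing coincides with the centre of the transition region. With unequal priors or an off-centre region the crossing in this toy version would shift, which is one reason it should not be read as the package's parameterisation.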

The goal is a model for transition-point estimation that focuses inference on data in and around the transition point while remaining robust to unbalanced sample sizes.

The process of testing the STAGE model against other methods is underway.


Package features

Planned additions:

  • Reparameterised logistic/HOF model
  • Bayesian LDA
  • Posterior predictive classification
  • Diagnostic + visualisation tools
  • Full vignette

Installation

You need cmdstanr + CmdStan installed:

install.packages("cmdstanr",
  repos = c("https://mc-stan.org/r-packages/", getOption("repos"))
)
cmdstanr::install_cmdstan()   # run once

Then install from GitHub:

devtools::install_github("anhsmith/stage")

Example

library(stage)

set.seed(123)

N <- 200
L <- 1000
U <- 1500
true_m50 <- 1250
true_d   <- 100

# simulate data
x <- runif(N, L, U)
p <- plogis((x - true_m50) / (true_d / 4))
y <- rbinom(N, 1, p)

# fit STAGE model
fit <- fit_stage(x, y, L = L, U = U, chains = 2, iter = 1000)
fit
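
In the simulation above the logistic scale is true_d / 4, so at the edges of the transition region (true_m50 ± true_d / 2) the linear predictor is ±2. The short base R check below shows what that implies for P(y = 1) at those edges. It describes the simulated data only and is not a statement about how the fitted model defines d.

```r
true_m50 <- 1250
true_d   <- 100

# Lower edge, centre, and upper edge of the simulated transition region
x_edges <- true_m50 + c(-1, 0, 1) * true_d / 2

# Same data-generating curve as in the simulation
p_edges <- plogis((x_edges - true_m50) / (true_d / 4))
round(p_edges, 3)
```

The probabilities at the edges are about 0.119 and 0.881, so roughly three quarters of the rise from 0 to 1 happens within a window of width true_d.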

Estimated transition point:

transition_point(fit)

Prediction:

x_grid <- seq(L, U, length.out = 100)
p_hat <- predict(fit, x_grid, type = "prob")

plot(x_grid, p_hat, type = "l",
     xlab = "x", ylab = "P(mature)", las = 1)
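
A predicted probability curve on a grid can also be used to read off the 0.5 crossing by interpolation, as a rough cross-check against the m50 estimate. The sketch below uses a stand-in logistic curve for p_hat so that it runs without a fitted model; with a real fit, p_hat would come from the predict() call above. It assumes p_hat increases strictly with x.

```r
x_grid <- seq(1000, 1500, length.out = 100)

# Stand-in for predict(fit, x_grid, type = "prob"); not real model output
p_hat <- plogis((x_grid - 1250) / 25)

# Invert the curve: interpolate x as a function of p, evaluated at p = 0.5
x_cross <- approx(x = p_hat, y = x_grid, xout = 0.5)$y
x_cross
```

The accuracy of this reading is limited by the grid spacing, so a finer x_grid may be worthwhile when the transition is steep.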

Hierarchical model

Multiple populations:

group <- rep(c("A","B"), each = N/2)

fit2 <- fit_stage(x, y, group = group, L = L, U = U)

transition_point(fit2)

This returns:

  • global transition point m50
  • population-specific transition points m50_pop[j]

Model summary

Each class is a uniform plateau with a Gaussian edge. Key parameters:

  • m50 — transition point where P(y=1|x) = 0.5 (direct sampling parameter)
  • d — width of transition region
  • sigma_x — Gaussian spread

Hierarchical structure:

m50_pop[j] = m50 + z[j] * sigma_alpha

(non-centred parameterisation for efficient sampling)
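
The non-centred form can be illustrated by simulation: drawing standardised offsets z[j] and scaling them by sigma_alpha gives the same values as mapping those draws onto a normal distribution centred on m50. The values of m50, sigma_alpha, and the number of populations J below are arbitrary, and the standard normal distribution for z is an assumption about the usual non-centred setup rather than something stated here.

```r
set.seed(1)

J <- 5             # number of populations (arbitrary)
m50 <- 1250        # global transition point (arbitrary)
sigma_alpha <- 30  # scale of population-level deviations (arbitrary)

# Non-centred: sample standardised offsets, then shift and scale
z <- rnorm(J)
m50_pop <- m50 + z * sigma_alpha

# Centred equivalent using the same draws:
# push the quantiles of z through Normal(m50, sigma_alpha)
m50_pop_centred <- qnorm(pnorm(z), mean = m50, sd = sigma_alpha)

all.equal(m50_pop, m50_pop_centred)
```

The sampler works with z, whose scale does not depend on sigma_alpha, which is what tends to help when sigma_alpha is poorly identified by the data.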


Development

devtools::test()
devtools::document()
devtools::build_readme()

License

MIT © Adam N. H. Smith