Kemungkinan Marginal dari Output Gibbs

13

Saya mereproduksi dari awal hasil di Bagian 4.2.1 dari

Kemungkinan Marginal dari Output Gibbs

Siddhartha Chib

Jurnal Asosiasi Statistik Amerika, Vol. 90, No. 432. (Desember, 1995), hlm. 1313-1321.

Ini adalah campuran dari model normals dengan nomor yang diketahui komponen. f ( x | w , μ , σ 2 ) = n Π i = 1 k Σ j = 1 N ( x i | μk1

f(xw,μ,σ2)=i=1nj=1kN(xiμj,σj2).()

Sampler Gibbs untuk model ini diimplementasikan menggunakan teknik augmentasi data Tanner dan Wong. Seperangkat variabel alokasi dengan asumsi nilai diperkenalkan, dan kami menetapkan bahwa dan f (x_i \ mid z , \ mu, \ sigma ^ 2) = \ mathrm {N} (x_i \ mid \ mu_ {z_i}, \ sigma ^ 2_ {z_i}) . Oleh karena itu, integrasi atas z_i memberikan kemungkinan asli (*) .z=(z1,,zn)1,,kPr(zi=jw)=wjf(xiz,μ,σ2)=N(xiμzi,σzi2)zi()

Dataset dibentuk oleh kecepatan 82 galaksi dari konstelasi Corona Borealis.

set.seed(1701)

x <- c(  9.172,  9.350,  9.483,  9.558,  9.775, 10.227, 10.406, 16.084, 16.170, 18.419, 18.552, 18.600, 18.927,
        19.052, 19.070, 19.330, 19.343, 19.349, 19.440, 19.473, 19.529, 19.541, 19.547, 19.663, 19.846, 19.856,
        19.863, 19.914, 19.918, 19.973, 19.989, 20.166, 20.175, 20.179, 20.196, 20.215, 20.221, 20.415, 20.629,
        20.795, 20.821, 20.846, 20.875, 20.986, 21.137, 21.492, 21.701, 21.814, 21.921, 21.960, 22.185, 22.209,
        22.242, 22.249, 22.314, 22.374, 22.495, 22.746, 22.747, 22.888, 22.914, 23.206, 23.241, 23.263, 23.484,
        23.538, 23.542, 23.666, 23.706, 23.711, 24.129, 24.285, 24.289, 24.366, 24.717, 24.990, 25.633, 26.960,
        26.995, 32.065, 32.789, 34.279 )

nn <- length(x)

Kami menganggap bahwa , 's, dan adalah independen apriori dengan wμjσj2

(w1,,wk)Dir(a1,,ak),μjN(μ0,σ02),σj2IG(ν02,δ02).
k <- 3

mu0 <- 20
va0 <- 100

nu0 <- 6
de0 <- 40

a <- rep(1, k)

Menggunakan Teorema Bayes, syarat penuhnya adalah di mana dengan

wμ,σ2,z,xDir(a1+n1,,ak+nk)μjw,σ2,z,xN(njmjσ02+μ0σj2njσ02+σj2,σ02σj2njσ02+σj2)σj2w,μ,z,xIG(ν0+nj2,δ0+δj2)Pr(zi=jw,μ,σ2,x)wj×1σje(xiμj)2/2σj2
nj=|Lj|,mj={1njiLjxiifnj>00otherwise.,δj=iLj(xiμj)2,
Lj={i{1,,n}:zi=j} .

Tujuannya adalah untuk menghitung estimasi kemungkinan marginal dari model. Metode Chib dimulai dengan menjalankan sampler Gibbs pertama menggunakan syarat penuh.

burn_in <- 1000
run     <- 15000

cat("First Gibbs run (full):\n")

N <- burn_in + run

w  <- matrix(1, nrow = N, ncol = k)
mu <- matrix(0, nrow = N, ncol = k)
va <- matrix(1, nrow = N, ncol = k)
z  <- matrix(1, nrow = N, ncol = nn)

n <- integer(k)
m <- numeric(k)
de <- numeric(k)

rdirichlet <- function(a) { y <- rgamma(length(a), a, 1); y / sum(y) }

pb <- txtProgressBar(min = 2, max = N, style = 3)
z[1,] <- sample.int(k, size = nn, replace = TRUE)
for (t in 2:N) {
    n <- tabulate(z[t-1,], nbins = k)
    w[t,] <- rdirichlet(a + n)
    m <- sapply(1:k, function(j) sum(x[z[t-1,]==j]))
    m[n > 0] <- m[n > 0] / n[n > 0]
    mu[t,] <- rnorm(k, mean = (n*m*va0+mu0*va[t-1,])/(n*va0+va[t-1,]), sd = sqrt(va0*va[t-1,]/(n*va0+va[t-1,])))
    de <- sapply(1:k, function(j) sum((x[z[t-1,]==j] - mu[t,j])^2))
    va[t,] <- 1 / rgamma(k, shape = (nu0+n)/2, rate = (de0+de)/2)
    z[t,] <- sapply(1:nn, function(i) sample.int(k, size = 1, prob = exp(log(w[t,]) + dnorm(x[i], mean = mu[t,], sd = sqrt(va[t,]), log = TRUE))))
    setTxtProgressBar(pb, t)
}
close(pb)

Dari proses pertama ini, kami mendapatkan titik perkiraan dari kemungkinan maksimum. Karena kemungkinan sebenarnya tidak terikat, apa yang mungkin diberikan prosedur ini adalah perkiraan MAP lokal.(w,μ,σ2)

w  <- w[(burn_in+1):N,]
mu <- mu[(burn_in+1):N,]
va <- va[(burn_in+1):N,]
z  <- z[(burn_in+1):N,]
N  <- N - burn_in

log_L <- function(x, w, mu, va) sum(log(sapply(1:nn, function(i) sum(exp(log(w) + dnorm(x[i], mean = mu, sd = sqrt(va), log = TRUE))))))

ts <- which.max(sapply(1:N, function(t) log_L(x, w[t,], mu[t,], va[t,])))

ws <- w[ts,]
mus <- mu[ts,]
vas <- va[ts,]

Perkiraan log Chib tentang kemungkinan marginal adalah

logf(x)^=logLx(w,μ,σ2)+logπ(w,μ,σ2)logπ(μx)logπ(σ2μ,x)logπ(wμ,σ2,x).

Kami sudah memiliki dua istilah pertama.

log_prior <- function(w, mu, va) {
    lgamma(sum(a)) - sum(lgamma(a)) + sum((a-1)*log(w))
    + sum(dnorm(mu, mean = mu0, sd = sqrt(va0), log = TRUE))
    + sum((nu0/2)*log(de0/2) - lgamma(nu0/2) - (nu0/2+1)*log(va) - de0/(2*va))
}

chib <- log_L(x, ws, mus, vas) + log_prior(ws, mus, vas)

Estimasi Rao-Blackwellized dari adalah dan siap diperoleh dari jangka Gibbs pertama.π(μx)

π(μx)=j=1kN(μj|njmjσ02+μ0σj2njσ02+σj2,σ02σj2njσ02+σj2)p(σ2,zx)dσ2dz,
pi.mu_va.z.x <- function(mu, va, z) {
    n <- tabulate(z, nbins = k)
    m <- sapply(1:k, function(j) sum(x[z==j]))
    m[n > 0] <- m[n > 0] / n[n > 0]
    exp(sum(dnorm(mu, mean = (n*m*va0+mu0*va)/(n*va0+va), sd = sqrt(va0*va/(n*va0+va)), log = TRUE)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.mu_va.z.x(mus, va[t,], z[t,]))))

Estimasi Rao-Blackwellized dari adalah dan dihitung dari penurunan Gibbs kedua yang dijalankan di mana tidak diperbarui, tetapi dibuat sama dengan pada setiap langkah iterasi.π(σ2μ,x)

π(σ2μ,x)=j=1kIG(σj2|ν0+nj2,δ0+δj2)p(zμ,x)dz,
μjμj
cat("Second Gibbs run (reduced):\n")

N <- burn_in + run

w  <- matrix(1, nrow = N, ncol = k)
va <- matrix(1, nrow = N, ncol = k)
z  <- matrix(1, nrow = N, ncol = nn) 

pb <- txtProgressBar(min = 2, max = N, style = 3)
z[1,] <- sample.int(k, size = nn, replace = TRUE)
for (t in 2:N) {
    n <- tabulate(z[t-1,], nbins = k)
    w[t,] <- rdirichlet(a + n)
    de <- sapply(1:k, function(j) sum((x[z[t-1,]==j] - mus[j])^2))
    va[t,] <- 1 / rgamma(k, shape = (nu0+n)/2, rate = (de0+de)/2)
    z[t,] <- sapply(1:nn, function(i) sample.int(k, size = 1, prob = exp(log(w[t,]) + dnorm(x[i], mean = mus, sd = sqrt(va[t,]), log = TRUE))))
    setTxtProgressBar(pb, t)
}
close(pb)

w  <- w[(burn_in+1):N,]
va <- va[(burn_in+1):N,]
z  <- z[(burn_in+1):N,]
N  <- N - burn_in

pi.va_mu.z.x <- function(va, mu, z) {
    n <- tabulate(z, nbins = k)         
    de <- sapply(1:k, function(j) sum((x[z==j] - mu[j])^2))
    exp(sum(((nu0+n)/2)*log((de0+de)/2) - lgamma((nu0+n)/2) - ((nu0+n)/2+1)*log(va) - (de0+de)/(2*va)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.va_mu.z.x(vas, mus, z[t,]))))

Dengan cara yang sama, estimasi Rao-Blackwellized dari adalah dan dihitung dari dijalankannya Gibbs ketiga yang dikurangi di mana dan tidak diperbarui, tetapi dibuat sama dengan dan masing-masing pada setiap langkah iterasi.π(wμ,σ2,x)

π(wμ,σ2,x)=Dir(wa1+n1,,ak+nk)p(zμ,σ2,x)dz,
μjσj2μjσj2
cat("Third Gibbs run (reduced):\n")

N <- burn_in + run

w  <- matrix(1, nrow = N, ncol = k)
z  <- matrix(1, nrow = N, ncol = nn) 

pb <- txtProgressBar(min = 2, max = N, style = 3)
z[1,] <- sample.int(k, size = nn, replace = TRUE)
for (t in 2:N) {
    n <- tabulate(z[t-1,], nbins = k)
    w[t,] <- rdirichlet(a + n)
    z[t,] <- sapply(1:nn, function(i) sample.int(k, size = 1, prob = exp(log(w[t,]) + dnorm(x[i], mean = mus, sd = sqrt(vas), log = TRUE))))
    setTxtProgressBar(pb, t)
}
close(pb)

w  <- w[(burn_in+1):N,]
z  <- z[(burn_in+1):N,]
N  <- N - burn_in

pi.w_z.x <- function(w, z) {
    n <- tabulate(z, nbins = k)
    exp(lgamma(sum(a+n)) - sum(lgamma(a+n)) + sum((a+n-1)*log(w)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.w_z.x(ws, z[t,]))))

Setelah semua ini, kami mendapatkan estimasi log yang lebih besar dari yang dilaporkan oleh Chib: dengan kesalahan Monte Carlo .217.9199224.138.086

Untuk memeriksa apakah saya entah bagaimana mengacaukan Gibbs samplers, saya menerapkan semuanya menggunakan RJAGS. Kode berikut memberikan hasil yang sama.

x <- c( 9.172,  9.350,  9.483,  9.558,  9.775, 10.227, 10.406, 16.084, 16.170, 18.419, 18.552, 18.600, 18.927, 19.052, 19.070, 19.330,
       19.343, 19.349, 19.440, 19.473, 19.529, 19.541, 19.547, 19.663, 19.846, 19.856, 19.863, 19.914, 19.918, 19.973, 19.989, 20.166,
       20.175, 20.179, 20.196, 20.215, 20.221, 20.415, 20.629, 20.795, 20.821, 20.846, 20.875, 20.986, 21.137, 21.492, 21.701, 21.814,
       21.921, 21.960, 22.185, 22.209, 22.242, 22.249, 22.314, 22.374, 22.495, 22.746, 22.747, 22.888, 22.914, 23.206, 23.241, 23.263,
       23.484, 23.538, 23.542, 23.666, 23.706, 23.711, 24.129, 24.285, 24.289, 24.366, 24.717, 24.990, 25.633, 26.960, 26.995, 32.065,
       32.789, 34.279 )

library(rjags)

nn <- length(x)

k <- 3

mu0 <- 20
va0 <- 100

nu0 <- 6
de0 <- 40

a <- rep(1, k)

burn_in <- 10^3

N <- 10^4

full <- "
    model {
        for (i in 1:n) {
            x[i] ~ dnorm(mu[z[i]], tau[z[i]])
            z[i] ~ dcat(w[])
        }
        for (i in 1:k) {
            mu[i] ~ dnorm(mu0, 1/va0)
            tau[i] ~ dgamma(nu0/2, de0/2)
            va[i] <- 1/tau[i]
        }
        w ~ ddirich(a)
    }
"
data <- list(x = x, n = nn, k = k, mu0 = mu0, va0 = va0, nu0 = nu0, de0 = de0, a = a)
model <- jags.model(textConnection(full), data = data, n.chains = 1, n.adapt = 100)
update(model, n.iter = burn_in)
samples <- jags.samples(model, c("mu", "va", "w", "z"), n.iter = N)

mu <- matrix(samples$mu, nrow = N, byrow = TRUE)
    va <- matrix(samples$va, nrow = N, byrow = TRUE)
w <- matrix(samples$w, nrow = N, byrow = TRUE)
    z <- matrix(samples$z, nrow = N, byrow = TRUE)

log_L <- function(x, w, mu, va) sum(log(sapply(1:nn, function(i) sum(exp(log(w) + dnorm(x[i], mean = mu, sd = sqrt(va), log = TRUE))))))

ts <- which.max(sapply(1:N, function(t) log_L(x, w[t,], mu[t,], va[t,])))

ws <- w[ts,]
mus <- mu[ts,]
vas <- va[ts,]

log_prior <- function(w, mu, va) {
    lgamma(sum(a)) - sum(lgamma(a)) + sum((a-1)*log(w))
    + sum(dnorm(mu, mean = mu0, sd = sqrt(va0), log = TRUE))
    + sum((nu0/2)*log(de0/2) - lgamma(nu0/2) - (nu0/2+1)*log(va) - de0/(2*va))
}

chib <- log_L(x, ws, mus, vas) + log_prior(ws, mus, vas)

cat("log-likelihood + log-prior =", chib, "\n")

pi.mu_va.z.x <- function(mu, va, z, x) {
    n <- sapply(1:k, function(j) sum(z==j))
    m <- sapply(1:k, function(j) sum(x[z==j]))
    m[n > 0] <- m[n > 0] / n[n > 0]
    exp(sum(dnorm(mu, mean = (n*m*va0+mu0*va)/(n*va0+va), sd = sqrt(va0*va/(n*va0+va)), log = TRUE)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.mu_va.z.x(mus, va[t,], z[t,], x))))

cat("log-likelihood + log-prior - log-pi.mu_ =", chib, "\n")

fixed.mu <- "
    model {
        for (i in 1:n) {
            x[i] ~ dnorm(mus[z[i]], tau[z[i]])
            z[i] ~ dcat(w[])
        }
        for (i in 1:k) {
            tau[i] ~ dgamma(nu0/2, de0/2)
            va[i] <- 1/tau[i]
        }
        w ~ ddirich(a)
    }
"
data <- list(x = x, n = nn, k = k, nu0 = nu0, de0 = de0, a = a, mus = mus)
model <- jags.model(textConnection(fixed.mu), data = data, n.chains = 1, n.adapt = 100)
update(model, n.iter = burn_in)
samples <- jags.samples(model, c("va", "w", "z"), n.iter = N)

va <- matrix(samples$va, nrow = N, byrow = TRUE)
    w <- matrix(samples$w, nrow = N, byrow = TRUE)
z <- matrix(samples$z, nrow = N, byrow = TRUE)

pi.va_mu.z.x <- function(va, mu, z, x) {
    n <- sapply(1:k, function(j) sum(z==j))
    de <- sapply(1:k, function(j) sum((x[z==j] - mu[j])^2))
    exp(sum(((nu0+n)/2)*log((de0+de)/2) - lgamma((nu0+n)/2) - ((nu0+n)/2+1)*log(va) - (de0+de)/(2*va)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.va_mu.z.x(vas, mus, z[t,], x))))

cat("log-likelihood + log-prior - log-pi.mu_ - log-pi.va_ =", chib, "\n")

fixed.mu.and.va <- "
    model {
        for (i in 1:n) {
            x[i] ~ dnorm(mus[z[i]], 1/vas[z[i]])
            z[i] ~ dcat(w[])
        }
        w ~ ddirich(a)
    }
"
data <- list(x = x, n = nn, a = a, mus = mus, vas = vas)
model <- jags.model(textConnection(fixed.mu.and.va), data = data, n.chains = 1, n.adapt = 100)
update(model, n.iter = burn_in)
samples <- jags.samples(model, c("w", "z"), n.iter = N)

w <- matrix(samples$w, nrow = N, byrow = TRUE)
    z <- matrix(samples$z, nrow = N, byrow = TRUE)

pi.w_z.x <- function(w, z, x) {
    n <- sapply(1:k, function(j) sum(z==j))
    exp(lgamma(sum(a)+nn) - sum(lgamma(a+n)) + sum((a+n-1)*log(w)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.w_z.x(ws, z[t,], x))))

cat("log-likelihood + log-prior - log-pi.mu_ - log-pi.va_ - log-pi.w_ =", chib, "\n")

Pertanyaan saya adalah apakah dalam uraian di atas ada kesalahpahaman tentang metode Chib atau kesalahan dalam implementasinya.

Zen
sumber
1
Menjalankan simulasi 100 kali, hasilnya berada dalam kisaran . [218.7655;216.8824]
Zen

Jawaban:

6

Ada sedikit kesalahan pemrograman pada bagian sebelumnya

log_prior <- function(w, mu, va) {
    lgamma(sum(a)) - sum(lgamma(a)) + sum((a-1)*log(w))
    + sum(dnorm(mu, mean = mu0, sd = sqrt(va0), log = TRUE))
    + sum((nu0/2)*log(de0/2) - lgamma(nu0/2) - (nu0/2+1)*log(va) - de0/(2*va))
}

sebagaimana mestinya sebagai gantinya

log_prior <- function(w, mu, va) {
    lgamma(sum(a)) - sum(lgamma(a)) + sum((a-1)*log(w)) +
      sum(dnorm(mu, mean = mu0, sd = sqrt(va0), log = TRUE)) +
      sum((nu0/2)*log(de0/2) - lgamma(nu0/2) - (nu0/2+1)*log(va) - de0/(2*va))
}

Menjalankan kembali kode yang mengarah ke sini

> chib
[1] -228.194

yang bukan nilai yang dihasilkan dalam Chib (1995) untuk kasus itu! Namun, dalam Neal (1999) menganalisis kembali masalah, ia menyebutkan itu

Menurut salah satu wasit JASA anonim, angka -224.138 untuk log kemungkinan marginal untuk model tiga komponen dengan varian yang tidak sama yang diberikan dalam kertas Chib adalah "kesalahan ketik" dengan angka yang benar adalah -228.608.

Jadi ini memecahkan masalah perbedaan.

Xi'an
sumber
2
Christian Robert dan Kate Lee: apakah Anda tahu betapa hebatnya Anda?
Zen
2
Omong-omong, ini jelas merupakan contoh "sintaks jahat". Saya tidak akan melupakan yang ini.
Zen