Menerapkan inferensi variatif stokastik untuk Bayesian Mixture of Gaussian

9

Saya mencoba untuk menerapkan model Campuran Gaussian dengan inferensi variasional stokastik, berikut ini kertas .

masukkan deskripsi gambar di sini

Ini adalah pgm dari Campuran Gaussian.

Menurut makalah itu, algoritma penuh inferensi variatif stokastik adalah: masukkan deskripsi gambar di sini

Dan saya masih sangat bingung dengan metode untuk menskalakannya menjadi GMM.

Pertama, saya pikir parameter variasional lokal hanya dan yang lainnya adalah semua parameter global. Harap perbaiki saya jika saya salah. Apa yang dimaksud dengan langkah 6 ? Apa yang harus saya lakukan untuk mencapai ini?qzas though Xi is replicated by N times

Bisakah Anda membantu saya dengan ini? Terima kasih sebelumnya!

pengguna5779223
sumber
Dikatakan bahwa alih-alih menggunakan seluruh dataset, sampel satu datapoint dan berpura-puralah Anda memiliki datapoint dengan ukuran yang sama. Dalam banyak kasus, ini akan setara dengan mengalikan harapan dengan satu datapoint oleh . NNN
Daeyoung Lim
@ DavideyLim Terima kasih atas balasan Anda! Saya mengerti maksud Anda sekarang, tetapi saya masih bingung bahwa statistik mana yang harus diperbarui secara lokal dan mana yang harus diperbarui secara global. Sebagai contoh, di sini adalah implementasi dari campuran Gaussian, bisakah Anda memberi tahu saya bagaimana skala untuk svi? Saya sedikit tersesat. Terima kasih banyak!
user5779223
Saya tidak membaca seluruh kode tetapi jika Anda berurusan dengan model campuran Gaussian, variabel indikator komponen campuran harus menjadi variabel lokal karena masing-masing terkait dengan hanya satu pengamatan. Jadi variabel laten komponen campuran yang mengikuti distribusi Multinoulli (juga dikenal sebagai distribusi Kategorikal dalam ML) adalah dalam uraian Anda di atas. zi,i=1,,N
Daeyoung Lim
@ DavideyLim Ya, saya mengerti apa yang Anda katakan sejauh ini. Jadi untuk distribusi variasional q (Z) q (\ pi, \ mu, \ lambda), q (Z) harus merupakan variabel lokal. Tetapi ada banyak parameter yang terkait dengan q (Z). Di sisi lain, ada juga banyak parameter yang terkait dengan q (\ pi, \ mu, \ lambda). Dan saya tidak tahu cara memperbaruinya dengan tepat.
user5779223
Anda harus menggunakan asumsi bidang-rata untuk mendapatkan distribusi variasi yang optimal untuk parameter variasi. Berikut rujukannya: maths.usyd.edu.au/u/jormerod/JTOpapers/Ormerod10.pdf
Daeyoung Lim

Jawaban:

1

Pertama, beberapa catatan yang membantu saya memahami makalah SVI:

  • Dalam menghitung nilai tengah untuk parameter variasional dari parameter global, kami mengambil sampel satu titik data dan berpura-pura seluruh kumpulan data kami ukuran adalah titik tunggal, kali.NNN
  • βηg adalah parameter alami untuk kondisi penuh variabel global . Notasi digunakan untuk menekankan bahwa ini adalah fungsi dari variabel terkondisi, termasuk data yang diamati. β

Dalam campuran Gaussians, parameter global kami adalah parameter rata-rata dan presisi (varian terbalik) params untuk masing-masing. Yaitu, adalah parameter alami untuk distribusi ini, sebuah Normal-Gamma dari formulirμ k , τ k η gkμk,τkηg

μ,τN(μ|γ,τ(2α1)Ga(τ|α,β)

dengan , dan . (Bernardo dan Smith, Bayesian Theory ; perhatikan ini sedikit berbeda dari empat-parameter Normal-Gamma yang biasa Anda lihat .) Kita akan menggunakan untuk merujuk pada parameter variasi untukη0=2α1η1=γ(2α1)η2=2β+γ2(2α1)a,b,mα,β,μ

penuh dari adalah Normal-Gamma dengan params , , , di mana adalah yang sebelumnya. (The di sana juga bisa membingungkan; masuk akal dimulai dengan trik diterapkan pada , dan diakhiri dengan jumlah aljabar yang tersisa untuk pembaca.)˙ n + Σ N zμk,τkη˙+Nzn,kNzn,kxNNzn,kxn2η˙zn,kexpln(p))Np(xn|zn,α,β,γ)=NK(p(xn|αk,βk,γk))zn,k

Dengan itu, kita dapat menyelesaikan langkah (5) dari pseudocode SVI dengan:

ϕn,kexp(ln(π)+Eqln(p(xn|αk,βk,γk))=exp(ln(π)+Eq[μkτk,τ2x,x2μ2τlnτ2)]

Memperbarui parameter global lebih mudah, karena setiap parameter terkait dengan jumlah data atau salah satu statistik yang memadai:

λ^=η˙+Nϕn1,x,x2

Inilah kemungkinan marginal dari data terlihat pada banyak iterasi, ketika dilatih tentang data yang sangat tiruan, mudah dipisahkan (kode di bawah). Plot pertama menunjukkan kemungkinan dengan inisialisasi, parameter variasional acak dan iterasi; masing-masing berikutnya adalah setelah kekuatan dua iterasi berikutnya. Dalam kode, merujuk ke parameter variasi untuk .a , b , m α , β , μ0a,b,mα,β,μ

masukkan deskripsi gambar di sini

masukkan deskripsi gambar di sini

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Aug 12 12:49:15 2018

@author: SeanEaster
"""

import numpy as np
from matplotlib import pylab as plt
from scipy.stats import t
from scipy.special import digamma 

# These are priors for mu, alpha and beta

def calc_rho(t, delay=16,forgetting=1.):
    return np.power(t + delay, -forgetting)

m_prior, alpha_prior, beta_prior = 0., 1., 1.
eta_0 = 2 * alpha_prior - 1
eta_1 = m_prior * (2 * alpha_prior - 1)
eta_2 = 2 *  beta_prior + np.power(m_prior, 2.) * (2 * alpha_prior - 1)

k = 3

eta_shape = (k,3)
eta_prior = np.ones(eta_shape)
eta_prior[:,0] = eta_0
eta_prior[:,1] = eta_1
eta_prior[:,2] = eta_2

np.random.seed(123) 
size = 1000
dummy_data = np.concatenate((
        np.random.normal(-1., scale=.25, size=size),
        np.random.normal(0.,  scale=.25,size=size),
        np.random.normal(1., scale=.25, size=size)
        ))
N = len(dummy_data)
S = 1

# randomly init global params
alpha = np.random.gamma(3., scale=1./3., size=k)
m = np.random.normal(scale=1, size=k)
beta = np.random.gamma(3., scale=1./3., size=k)

eta = np.zeros(eta_shape)
eta[:,0] = 2 * alpha - 1
eta[:,1] = m * eta[:,0]
eta[:,2] = 2. * beta + np.power(m, 2.) * eta[:,0]


phi = np.random.dirichlet(np.ones(k) / k, size = dummy_data.shape[0])

nrows, ncols = 4, 5
total_plots = nrows * ncols
total_iters = np.power(2, total_plots - 1)
iter_idx = 0

x = np.linspace(dummy_data.min(), dummy_data.max(), num=200)

while iter_idx < total_iters:

    if np.log2(iter_idx + 1) % 1 == 0:

        alpha = 0.5 * (eta[:,0] + 1)
        beta = 0.5 * (eta[:,2] - np.power(eta[:,1], 2.) / eta[:,0])
        m = eta[:,1] / eta[:,0]
        idx = int(np.log2(iter_idx + 1)) + 1

        f = plt.subplot(nrows, ncols, idx)
        s = np.zeros(x.shape)
        for _ in range(k):
            y = t.pdf(x, alpha[_], m[_], 2 * beta[_] / (2 * alpha[_] - 1))
            s += y
            plt.plot(x, y)
        plt.plot(x, s)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)

    # randomly sample data point, update parameters
    interm_eta = np.zeros(eta_shape)
    for _ in range(S):
        datum = np.random.choice(dummy_data, 1)

        # mean params for ease of calculating expectations
        alpha = 0.5 * ( eta[:,0] + 1)
        beta = 0.5 * (eta[:,2] - np.power(eta[:,1], 2) / eta[:,0])
        m = eta[:,1] / eta[:,0]

        exp_mu = m
        exp_tau = alpha / beta 
        exp_tau_m_sq = 1. / (2 * alpha - 1) + np.power(m, 2.) * alpha / beta
        exp_log_tau = digamma(alpha) - np.log(beta)


        like_term = datum * (exp_mu * exp_tau) - np.power(datum, 2.) * exp_tau / 2 \
            - (0.5 * exp_tau_m_sq - 0.5 * exp_log_tau)
        log_phi = np.log(1. / k) + like_term
        phi = np.exp(log_phi)
        phi = phi / phi.sum()

        interm_eta[:, 0] += phi
        interm_eta[:, 1] += phi * datum
        interm_eta[:, 2] += phi * np.power(datum, 2.)

    interm_eta = interm_eta * N / S
    interm_eta += eta_prior

    rho = calc_rho(iter_idx + 1)

    eta = (1 - rho) * eta + rho * interm_eta

    iter_idx += 1
Sean Easter
sumber