Skip to content

Commit

Permalink
add marginal and conditional kernel for gauss
Browse files Browse the repository at this point in the history
  • Loading branch information
mschauer committed Mar 19, 2024
1 parent 4ea2062 commit 04bcdc5
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/Mitosis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import StatsBase.sample
include("mt.jl")


export Gaussian, Copy, fuse, weighted
export Gaussian, conditional, marginal, Copy, fuse, weighted
export Traced, BFFG, left′, right′, forward, backward, backwardfilter, forwardsampler
export BF, density, logdensity, , kernel, correct, Kernel, WGaussian, Gaussian, ConstantMap, AffineMap, LinearMap, GaussKernel

Expand Down
15 changes: 14 additions & 1 deletion src/gauss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,24 @@ Base.:*(M, p::Gaussian{P}) where {P} = Gaussian{P}(M * mean(p), Σ = M * cov(p)

"""
conditional(p::Gaussian, A, B, xB)
conditional(p::Gaussian, A, B)
Conditional distribution of `X[i for i in A]` given
`X[i for i in B] == xB` if ``X ~ P``.
`X[i for i in B] == xB` if ``X ~ P``. The version without
the argument `xB` returns a kernel mapping `xB` to the conditional.
"""
function conditional(p::Gaussian{(:μ, :Σ)}, A, B, xB)
Z = p.Σ[A,B]*inv(p.Σ[B,B])
Gaussian{(:μ, :Σ)}(p.μ[A] + Z*(xB - p.μ[B]), p.Σ[A,A] - Z*p.Σ[B,A])
end

function conditional(p::Gaussian{(:μ, :Σ)}, A, B)
Z = p.Σ[A,B]*inv(p.Σ[B,B])
β = p.μ[A] - Z*p.μ[B]
Σ = p.Σ[A,A] - Z*p.Σ[B,A]
kernel(Gaussian; μ=AffineMap(Z, β), Σ=ConstantMap(Σ))
end

function marginal(p::Gaussian{(:μ, :Σ)}, A)
Gaussian{(:μ, :Σ)}(p.μ[A], p.Σ[A,A])
end
12 changes: 12 additions & 0 deletions test/gauss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,16 @@ p3 = convert(Gaussian{(:μ,:Σ)}, p2)
@test 10/sqrt(K) > norm(m - mean(rand(G) for x in 1:K))
@test 10/sqrt(K) > norm(Q - cov([rand(G) for x in 1:K]))
end
end

@testset "conditional" begin
using Mitosis
d = 5
d1 = 2
μ = randn(d)
A = randn(d,d)
q = Gaussian{(:μ,:Σ)}(μ, A*A')
π = marginal(q, 1:d1)
k = conditional(q, d1+1:d, 1:d1)
@test k(π) marginal(q, d1+1:d)
end

0 comments on commit 04bcdc5

Please sign in to comment.