Skip to content

Commit

Permalink
FunMC: Add the AIS kernel for use with SMC.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721109723
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Jan 29, 2025
1 parent 55d191e commit 0a56cac
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 2 deletions.
193 changes: 193 additions & 0 deletions spinoffs/fun_mc/fun_mc/smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Callable, Generic, Protocol, TypeVar, runtime_checkable

from fun_mc import backend
from fun_mc import fun_mc_lib as fun_mc
from fun_mc import types

jax = backend.jax
Expand All @@ -34,11 +35,16 @@
# Local shorthands for types declared in `fun_mc.types`.
BoolScalar = types.BoolScalar
IntScalar = types.IntScalar
FloatScalar = types.FloatScalar
PotentialFn = types.PotentialFn

# Type variables used by the generic signatures in this module.
State = TypeVar('State')
Extra = TypeVar('Extra')
KernelExtra = TypeVar('KernelExtra')
T = TypeVar('T')

__all__ = [
'annealed_importance_sampling_kernel',
'AnnealedImportanceSamplingKernelExtra',
'conditional_systematic_resampling',
'effective_sample_size_predicate',
'ParticleGatherFn',
Expand Down Expand Up @@ -518,6 +524,193 @@ def dont_resample(
return smc_state, smc_extra


@runtime_checkable
class AnnealedImportanceSamplingMCMCKernel(Protocol[State, Extra, KernelExtra]):
  """Inner MCMC kernel used by `annealed_importance_sampling_kernel`."""

  def __call__(
      self,
      state: State,
      step: IntScalar,
      target_log_prob_fn: PotentialFn[Extra],
      seed: Seed,
  ) -> tuple[State, KernelExtra]:
    """Moves the state with an MCMC kernel targeting `target_log_prob_fn`.

    Args:
      state: State at step `t`.
      step: The timestep, `t`.
      target_log_prob_fn: Target distribution corresponding to `t`.
      seed: PRNG seed.

    Returns:
      new_state: New state, targeting `target_log_prob_fn`.
      extra: Extra information from the kernel.
    """


@util.dataclass
class AnnealedImportanceSamplingKernelExtra(Generic[KernelExtra, Extra]):
  """Extra outputs from the AIS kernel.

  Attributes:
    kernel_extra: Extra outputs from the inner kernel.
    cur_state_extra: Extra output from the current step's target log prob
      function, evaluated at the state returned by the inner kernel.
    next_state_extra: Extra output from the next step's target log prob
      function, evaluated at the state returned by the inner kernel.
  """

  kernel_extra: KernelExtra
  cur_state_extra: Extra
  next_state_extra: Extra


@types.runtime_typed
def annealed_importance_sampling_kernel(
    state: State,
    step: IntScalar,
    seed: Seed,
    kernel: AnnealedImportanceSamplingMCMCKernel[State, Extra, KernelExtra],
    make_target_log_probability_fn: Callable[[IntScalar], PotentialFn[Extra]],
) -> tuple[
    State,
    tuple[
        Float[Array, 'num_particles'],
        AnnealedImportanceSamplingKernelExtra[KernelExtra, Extra],
    ],
]:
  """SMC kernel that implements Annealed Importance Sampling.

  Annealed Importance Sampling (AIS)[1] can be interpreted as a special case of
  SMC with a particular choice of forward and reverse kernels:

  ```none
  r_t = k_t(x_{t + 1} | x_t) p_t(x_t) / p_t(x_{t + 1})
  q_t = k_{t - 1}(x_t | x_{t - 1})
  ```

  where `k_t` is an MCMC kernel that has `p_t` invariant. This causes the
  incremental weight equation to be particularly simple:

  ```none
  iw_t = p_t(x_t) / p_{t - 1}(x_t)
  ```

  Unfortunately, the reverse kernel is not optimal, so the annealing schedule
  needs to be fine. The original formulation from [1] does not do resampling,
  but enabling it will usually reduce the variance of the estimator.

  Args:
    state: The previous particle state, `x_{t - 1}^{1:K}`.
    step: The previous timestep, `t - 1`.
    seed: PRNG seed.
    kernel: The inner MCMC kernel. It takes the current state, the timestep, the
      target distribution and the seed and generates an approximate sample from
      `p_t` where `t` is the passed-in timestep.
    make_target_log_probability_fn: A function that, given a timestep, returns
      the target distribution `p_t` where `t` is the passed-in timestep.

  Returns:
    state: The new particles, `x_t^{1:K}`.
    extra: A 2-tuple of:
      incremental_log_weights: The incremental log weight at timestep t,
        `iw_t^{1:K}`.
      ais_extra: An `AnnealedImportanceSamplingKernelExtra` holding the extra
        output of the inner kernel together with the extra outputs of the
        current and next target log prob functions evaluated at the new
        particles.

  #### Example

  In this example we estimate the normalizing constant ratio between `tlp_1`
  and `tlp_2`.

  ```python
  def tlp_1(x):
    return -(x**2) / 2.0, ()

  def tlp_2(x):
    return -((x - 2) ** 2) / 2 / 16.0, ()

  @jax.jit
  def kernel(smc_state, seed):
    smc_seed, seed = jax.random.split(seed, 2)

    def inner_kernel(state, stage, tlp_fn, seed):
      f = jnp.array(stage, state.dtype) / num_steps
      hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn)
      hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step(
          hmc_state,
          tlp_fn,
          step_size=f * 4.0 + (1.0 - f) * 1.0,
          num_integrator_steps=1,
          seed=seed,
      )
      return hmc_state.state, ()

    smc_state, _ = smc.sequential_monte_carlo_step(
        smc_state,
        kernel=functools.partial(
            smc.annealed_importance_sampling_kernel,
            kernel=inner_kernel,
            make_target_log_probability_fn=functools.partial(
                fun_mc.geometric_annealing_path,
                num_stages=num_steps,
                initial_target_log_prob_fn=tlp_1,
                final_target_log_prob_fn=tlp_2,
            ),
        ),
        seed=smc_seed,
    )
    return (smc_state, seed), ()

  num_steps = 100
  num_particles = 400
  init_seed, seed = jax.random.split(jax.random.PRNGKey(0))
  init_state = jax.random.normal(init_seed, [num_particles])

  (smc_state, _), _ = fun_mc.trace(
      (
          smc.sequential_monte_carlo_init(
              init_state,
              weight_dtype=init_state.dtype,
          ),
          seed,
      ),
      kernel,
      num_steps,
  )

  weights = jnp.exp(smc_state.log_weights)
  # Should be close to 4.
  print('estimated z2/z1', weights.mean())
  # Should be close to 2.
  print('estimated mean', (jax.nn.softmax(smc_state.log_weights)
                           * smc_state.state).sum())
  ```

  #### References

  [1]: Neal, Radford M. (1998) Annealed Importance Sampling.
       https://arxiv.org/abs/physics/9803008
  """
  # The current-step target is needed both by the inner kernel and as the
  # denominator of the incremental weight, so build it once.
  cur_target_log_prob_fn = make_target_log_probability_fn(step)
  new_state, kernel_extra = kernel(state, step, cur_target_log_prob_fn, seed)

  # Incremental log weight: log p_{step + 1}(x) - log p_{step}(x), with both
  # targets evaluated at the freshly moved particles.
  tlp_num, num_extra = fun_mc.call_potential_fn(
      make_target_log_probability_fn(step + 1), new_state
  )
  tlp_denom, denom_extra = fun_mc.call_potential_fn(
      cur_target_log_prob_fn, new_state
  )
  extra = AnnealedImportanceSamplingKernelExtra(
      kernel_extra=kernel_extra,
      cur_state_extra=denom_extra,
      next_state_extra=num_extra,
  )
  return new_state, (
      tlp_num - tlp_denom,
      extra,
  )


def _smart_cond(
pred: BoolScalar,
true_fn: Callable[..., T],
Expand Down
67 changes: 67 additions & 0 deletions spinoffs/fun_mc/fun_mc/smc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,73 @@ def kernel(smc_state, seed):
self.assertAllClose(gt_log_evidence, log_evidence, rtol=0.01)
self.assertAllClose(gt_log_evidence, log_evidence, atol=0.2)

def test_annealed_importance_sampling(self):
  """Checks AIS between an N(0, 1)-shaped and an N(2, 16)-shaped target."""
  num_steps = 1000
  num_particles = 1000

  def initial_tlp(x):
    # Unnormalized log density with unit variance, centered at 0.
    return -0.5 * x**2, ()

  def final_tlp(x):
    # Unnormalized log density with variance 16, centered at 2.
    return (-0.5 * (x - 2) ** 2) / 16.0, ()

  def hmc_inner_kernel(state, step, tlp_fn, seed):
    # Interpolate the HMC step size from 1 to 4 as annealing progresses.
    frac = jnp.array(step, state.dtype) / num_steps
    step_size = frac * 4.0 + (1.0 - frac) * 1.0
    hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn)
    hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step(
        hmc_state,
        tlp_fn,
        step_size=step_size,
        num_integrator_steps=1,
        seed=seed,
    )
    return hmc_state.state, ()

  annealing_path = functools.partial(
      fun_mc.geometric_annealing_path,
      num_stages=num_steps,
      initial_target_log_prob_fn=initial_tlp,
      final_target_log_prob_fn=final_tlp,
  )
  ais_kernel = functools.partial(
      smc.annealed_importance_sampling_kernel,
      kernel=hmc_inner_kernel,
      make_target_log_probability_fn=annealing_path,
  )

  @jax.jit
  def kernel(smc_state, seed):
    step_seed, next_seed = util.split_seed(seed, 2)
    new_smc_state, _ = smc.sequential_monte_carlo_step(
        smc_state,
        kernel=ais_kernel,
        seed=step_seed,
    )
    return (new_smc_state, next_seed), ()

  init_seed, trace_seed = util.split_seed(_test_seed(), 2)
  init_state = util.random_normal([num_particles], self._dtype, init_seed)
  init_smc_state = smc.sequential_monte_carlo_init(
      init_state,
      weight_dtype=self._dtype,
  )

  (final_smc_state, _), _ = fun_mc.trace(
      (init_smc_state, trace_seed),
      kernel,
      num_steps,
  )

  log_weights = final_smc_state.log_weights
  particles = final_smc_state.state

  # 4 because the final target has stddev of 4 while the initial target has
  # stddev of 1.
  weights = jnp.exp(log_weights)
  self.assertAllClose(4.0, jnp.mean(weights), atol=0.1)

  normed_weights = jax.nn.softmax(log_weights)
  mean = jnp.sum(normed_weights * particles)
  variance = jnp.sum(normed_weights * (particles - mean) ** 2)
  self.assertAllClose(2.0, mean, atol=0.3)
  self.assertAllClose(16.0, variance, rtol=0.2)


@test_util.multi_backend_test(globals(), 'smc_test')
class SMCTest32(SMCTest):
Expand Down
21 changes: 19 additions & 2 deletions spinoffs/fun_mc/fun_mc/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ============================================================================
"""Various types used in FunMC."""

from typing import Callable, TypeAlias, TypeVar
from typing import Callable, Protocol, TypeAlias, TypeVar, runtime_checkable

import jaxtyping
from fun_mc import backend
Expand All @@ -29,6 +29,7 @@
'FloatScalar',
'Int',
'IntScalar',
'PotentialFn',
'runtime_typed',
'Seed',
]
Expand All @@ -42,8 +43,24 @@
# Scalars may be given either as Python values or as arrays annotated with an
# empty shape spec (`''`).
BoolScalar: TypeAlias = bool | Bool[Array, '']
IntScalar: TypeAlias = int | Int[Array, '']
FloatScalar: TypeAlias = float | Float[Array, '']

# Type variable bound to callables; lets a decorator return the type it is
# given.
F = TypeVar('F', bound=Callable)
# Type of the extra (second) output of a `PotentialFn`.
_Extra = TypeVar('_Extra')


@runtime_checkable
class PotentialFn(Protocol[_Extra]):
  """Maps state to an array of float, plus an extra output.

  If the state has leading dimension, the same dimension is present in the
  returned values as well.
  """

  def __call__(
      self,
      *args,
      **kwargs,
  ) -> tuple[Float[Array, '...'], _Extra]:
    """Potential function.

    Args:
      *args: Positional arguments forming the state.
      **kwargs: Keyword arguments forming the state.

    Returns:
      A 2-tuple of the float array of values and the extra output.
    """


def runtime_typed(f: F) -> F:
Expand Down

0 comments on commit 0a56cac

Please sign in to comment.