Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump version #15

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cascades/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# A new PyPI release will be pushed every time `__version__` is increased
# When changing this, also update the CHANGELOG.md
__version__ = '0.3.2'
__version__ = '0.4.0'

from cascades._src.distributions.base import UniformCategorical
from cascades._src.distributions.gpt import GPT
Expand Down
64 changes: 64 additions & 0 deletions cascades/_src/distributions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
import math
from typing import Any, Optional, Tuple, Union

import cachetools
import jax
import jax.numpy as jnp
import numpy as np
from numpyro import distributions as np_dists
import shortuuid

DEFAULT_TIMEOUT = 60

Expand Down Expand Up @@ -234,3 +237,64 @@ def log_prob(self, value):

def support(self):
return self.options


# Higher order distributions


def _rng_hash(rng):
"""Hash an rng key."""
if isinstance(rng, int):
return hash(rng)
else:
return hash(tuple(np.asarray(rng)))


def _mem_sample_key(self, rng):
  """Cache key for `Mem.sample`.

  Args:
    self: The distribution instance whose `sample` is being memoized. Must be
      hashable.
    rng: The rng passed to `sample`; reduced to an int via `_rng_hash`.

  Returns:
    A hashable `(rng_hash, self)` tuple.
  """
  # Return the tuple itself rather than `hash(tuple)`: the cache then compares
  # keys by equality, so two different (rng, instance) pairs whose combined
  # hashes happen to collide no longer share a cache entry.
  return (_rng_hash(rng), self)


def _mem_score_key(self, value):
"""Cache key for Mem distribution."""
h = hash((value, self))
return h


# TODO(ddohan): Consider sharing cache across instances.
@dataclasses.dataclass(eq=True, frozen=True)
class Mem(Distribution):
  """Memoizing wrapper around another distribution.

  `sample` and `score` delegate to `dist`. Each method has its own LRU cache
  holding up to 100,000 entries. The caches are created once, when the class
  body is evaluated, and the cache keys are built from the instance together
  with the call argument.
  """
  # The wrapped distribution whose results get memoized.
  dist: Optional[Distribution] = None

  # Short random identifier generated per instance, so that two Mem wrappers
  # compare equal only when they are really the same one.
  uid: str = dataclasses.field(default_factory=lambda: shortuuid.uuid()[:8])

  @cachetools.cached(
      key=_mem_sample_key,
      cache=cachetools.LRUCache(maxsize=100_000))
  def sample(self, rng):
    """Return `dist.sample(rng=rng)`, reusing a cached result if present."""
    wrapped = self.dist
    return wrapped.sample(rng=rng)

  @cachetools.cached(
      key=_mem_score_key,
      cache=cachetools.LRUCache(maxsize=100_000))
  def score(self, value):
    """Return `dist.score(value)`, reusing a cached result if present."""
    wrapped = self.dist
    return wrapped.score(value)


@dataclasses.dataclass(frozen=True, eq=True)
class Lambda(Distribution):
  """Wrap a function as distribution.

  Sampling calls the wrapped zero-argument function and returns its result
  as a `RandomSample` with no log-probability attached. Scoring is not
  supported.
  """
  fn: Any = dataclasses.field(
      default=None, hash=None)  # TODO(ddohan): Add type of callable.

  def sample(self, rng):
    """Call `fn()` and wrap the result; `rng` is ignored."""
    del rng  # `fn` takes no arguments, so the rng is unused.
    value = self.fn()
    return RandomSample(value=value, dist=self, log_p=None)

  def score(self, value=None):
    """Always raises: a wrapped function has no density to evaluate.

    Args:
      value: Ignored. Accepted so the signature matches the `score(value)`
        form used by the other distributions here; wrappers such as
        `Mem.score` call `dist.score(value)`, which would otherwise fail with
        a `TypeError` before reaching the intended error below.

    Raises:
      NotImplementedError: Always.
    """
    del value
    raise NotImplementedError('Scoring from Lambda is not available.')
74 changes: 74 additions & 0 deletions cascades/_src/distributions/base_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2022 The cascades Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for base distributions."""

import dataclasses
import random

from absl.testing import absltest
from cascades._src.distributions import base


@dataclasses.dataclass(eq=True, frozen=True)
class RandomFactor(base.Distribution):
  """Test-only distribution whose score is a fresh random integer per call."""

  def score(self, value):
    """Ignore `value` and return a random int in [0, 100_000_000]."""
    del value
    return random.randint(0, 100_000_000)

  def sample(self, rng):
    """Ignore `rng`; return a sample with no value and a random log_p."""
    del rng
    random_log_p = self.score(None)
    return base.RandomSample(value=None, log_p=random_log_p)


class BaseTest(absltest.TestCase):
  """Checks for the Lambda and Mem distribution wrappers."""

  def test_lambda(self):
    """A Lambda sample carries the wrapped function's return value."""
    constant_fn = lambda: 5
    wrapped = base.Lambda(fn=constant_fn)
    result = wrapped.sample(0)
    self.assertEqual(5, result.value)

  def test_mem_lambda(self):
    """Test memoizing a lambda distribution."""
    random_fn = lambda: random.randint(0, 100_000_000)
    memoized = base.Mem(dist=base.Lambda(fn=random_fn))

    # Same rng twice, then a different rng.
    first = memoized.sample(0).value
    repeated = memoized.sample(0).value
    different = memoized.sample(1).value

    self.assertEqual(first, repeated)
    self.assertNotEqual(first, different)

  def test_mem_sample_and_score(self):
    """Test memoizing a randomized sample & score distribution."""
    memoized = base.Mem(dist=RandomFactor())

    # Sampling: a repeated rng reuses the result, a new rng does not.
    first = memoized.sample(0).score
    repeated = memoized.sample(0).score
    different = memoized.sample(1).score
    self.assertEqual(first, repeated)
    self.assertNotEqual(first, different)

    # Scoring: a repeated value reuses the result, a new value does not.
    first = memoized.score('abc')
    repeated = memoized.score('abc')
    different = memoized.score('xyz')
    self.assertEqual(first, repeated)
    self.assertNotEqual(first, different)


# Run the absltest runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers = [
keywords = []

# pip dependencies of the project
dependencies = ["jax[cpu]", "numpyro", "immutabledict", "openai"]
dependencies = ["cachetools", "shortuuid", "jax[cpu]", "numpyro", "immutabledict", "openai"]

# This is set automatically by flit using `cascades.__version__`
dynamic = ["version"]
Expand Down