Skip to content
This repository has been archived by the owner on Apr 4, 2024. It is now read-only.

Commit

Permalink
Centralize binary search
Browse files Browse the repository at this point in the history
  • Loading branch information
hssahota2 committed Apr 3, 2024
1 parent fe16c26 commit 1845272
Showing 1 changed file with 55 additions and 62 deletions.
117 changes: 55 additions & 62 deletions python/selfie-lib/selfie_lib/ArrayMap.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,41 @@
from collections.abc import Set, Iterator, Mapping
from typing import List, TypeVar, Union, Any
from typing import List, TypeVar, Union, Any, Callable, Optional, Generator
from abc import abstractmethod, ABC
from functools import total_ordering

T = TypeVar("T")
V = TypeVar("V")
K = TypeVar("K")


@total_ordering
class Comparable:
def __init__(self, value):
self.value = value
class BinarySearchUtil:
@staticmethod
def binary_search(
data, item, compare_func: Optional[Callable[[Any, Any], int]] = None
) -> int:
low, high = 0, len(data) - 1
while low <= high:
mid = (low + high) // 2
mid_val = data[mid] if not isinstance(data, ListBackedSet) else data[mid]
comparison = (
compare_func(mid_val, item)
if compare_func
else (mid_val > item) - (mid_val < item)
)

def __lt__(self, other: Any) -> bool:
if not isinstance(other, Comparable):
return NotImplemented
return self.value < other.value
if comparison < 0:
low = mid + 1
elif comparison > 0:
high = mid - 1
else:
return mid # item found
return -(low + 1) # item not found

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Comparable):
return NotImplemented
return self.value == other.value
@staticmethod
def default_compare(a: Any, b: Any) -> int:
"""Default comparison function for binary search, with special handling for strings."""
if isinstance(a, str) and isinstance(b, str):
a, b = a.replace("/", "\0"), b.replace("/", "\0")
return (a > b) - (a < b)


class ListBackedSet(Set[T], ABC):
Expand All @@ -31,25 +45,14 @@ def __len__(self) -> int: ...
@abstractmethod
def __getitem__(self, index: Union[int, slice]) -> Union[T, List[T]]: ...

@abstractmethod
def __iter__(self) -> Iterator[T]: ...

def __contains__(self, item: Any) -> bool:
return self._binary_search(item) >= 0

def _binary_search(self, item: Any) -> int:
low = 0
high = len(self) - 1
while low <= high:
mid = (low + high) // 2
try:
mid_val = self[mid]
if mid_val < item:
low = mid + 1
elif mid_val > item:
high = mid - 1
else:
return mid # item found
except TypeError:
raise ValueError(f"Cannot compare items due to a type mismatch.")
return -(low + 1) # item not found
return BinarySearchUtil.binary_search(self, item)


class ArraySet(ListBackedSet[K]):
Expand Down Expand Up @@ -80,59 +83,49 @@ def __getitem__(self, index: Union[int, slice]) -> Union[K, List[K]]:
return self.__data[index]

def plusOrThis(self, element: K) -> "ArraySet[K]":
if element in self:
index = self._binary_search(element)
if index >= 0:
return self
else:
insert_at = -(index + 1)
new_data = self.__data[:]
new_data.append(element)
new_data.sort(key=Comparable)
new_data.insert(insert_at, element)
return ArraySet.__create(new_data)


class ArrayMap(Mapping[K, V]):
def __init__(self, data=None):
if data is None:
self.__data = []
else:
self.__data = data
__data: List[Union[K, V]]

def __init__(self):
raise NotImplementedError("Use ArrayMap.empty() or other class methods instead")

@classmethod
def __create(cls, data: List[Union[K, V]]) -> "ArrayMap[K, V]":
instance = cls.__new__(cls)
instance.__data = data
return instance

@classmethod
def empty(cls) -> "ArrayMap[K, V]":
if not hasattr(cls, "__EMPTY"):
cls.__EMPTY = cls([])
cls.__EMPTY = cls.__create([])
return cls.__EMPTY

def __getitem__(self, key: K) -> V:
index = self._binary_search_key(key)
if index >= 0:
return self.__data[2 * index + 1]
return self.__data[2 * index + 1] # type: ignore
raise KeyError(key)

def __iter__(self) -> Iterator[K]:
return (self.__data[i] for i in range(0, len(self.__data), 2))
return (self.__data[i] for i in range(0, len(self.__data), 2)) # type: ignore

def __len__(self) -> int:
return len(self.__data) // 2

def _binary_search_key(self, key: K) -> int:
def compare(a, b):
"""Comparator that puts '/' first in strings."""
if isinstance(a, str) and isinstance(b, str):
a, b = a.replace("/", "\0"), b.replace("/", "\0")
return (a > b) - (a < b)

low, high = 0, len(self.__data) // 2 - 1
while low <= high:
mid = (low + high) // 2
mid_key = self.__data[2 * mid]
comparison = compare(mid_key, key)
if comparison < 0:
low = mid + 1
elif comparison > 0:
high = mid - 1
else:
return mid # key found
return -(low + 1) # key not found
keys = [self.__data[i] for i in range(0, len(self.__data), 2)]
return BinarySearchUtil.binary_search(keys, key)

def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
index = self._binary_search_key(key)
Expand All @@ -142,12 +135,12 @@ def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
new_data = self.__data[:]
new_data.insert(insert_at * 2, key)
new_data.insert(insert_at * 2 + 1, value)
return ArrayMap(new_data)
return ArrayMap.__create(new_data)

def minus_sorted_indices(self, indices: List[int]) -> "ArrayMap[K, V]":
new_data = self.__data[:]
adjusted_indices = [i * 2 for i in indices] + [i * 2 + 1 for i in indices]
adjusted_indices.sort()
for index in reversed(adjusted_indices):
adjusted_indices.sort(reverse=True)
for index in adjusted_indices:
del new_data[index]
return ArrayMap(new_data)
return ArrayMap.__create(new_data)

0 comments on commit 1845272

Please sign in to comment.