Metadata-Version: 2.4
Name: bloom-torch
Version: 2.0.0
Summary: Bloom-filter–accelerated clustering and set structures in PyTorch (not related to the BLOOM language model).
Author: bloom-torch contributors
License: MIT
Project-URL: Homepage, https://github.com/your-org/bloom-torch
Project-URL: Documentation, https://github.com/your-org/bloom-torch#readme
Project-URL: Repository, https://github.com/your-org/bloom-torch
Project-URL: Issues, https://github.com/your-org/bloom-torch/issues
Keywords: bloom-filter,pytorch,clustering,k-means,bloom-matrix
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: ruff>=0.4; extra == "dev"
Dynamic: license-file

# bloom-torch

[![PyPI version](https://img.shields.io/pypi/v/bloom-torch.svg)](https://pypi.org/project/bloom-torch/)
[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)

**bloom-torch** is a compact PyTorch library for:

- **Bloom filters** — universal-hash family and batched hashing on CPU/CUDA/MPS (`BloomHasher`).
- **Bloom matrices** — bit-packed **m × |E|** encodings with vectorised AND/OR lookup (`TorchBloomMatrix`).
- **BloomKMeans** — Lloyd-style K-means where assignment can use **top‑k cluster candidates** when **K** is large, plus a **`TorchBloomMatrix`** routing index over cluster → elements (`BloomKMeans`).

Typical uses: fast approximate **set membership**, **label → element** routing matrices, and **accelerated clustering** over large **n × K** assignment steps.

---

## Installation

```bash
pip install bloom-torch
```

**Requirements:** Python **≥ 3.10**, PyTorch **≥ 2.0**.

Optional dev tools:

```bash
pip install bloom-torch[dev]
```

---

## `TorchBloomMatrix` — Bloom-encoded sets (start here)

A **Bloom matrix** is an **m × E** bit table. **E = `num_elements`** is the universe of **element IDs** (columns, e.g. item or token indices). Each **label** (an int) maps through **k** hash functions to **k** rows; inserting **(label, element)** **OR**s bits at those positions. A **lookup** **AND**s the **k** rows for that label → a conservative **`[E]`** bool mask over elements (**no false negatives**; controlled **false positives** via **m**, **k**, and load).
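The following is a minimal pure-Python sketch of an *unpacked* Bloom matrix that makes the OR-on-insert and AND-on-lookup mechanics concrete. The class `ToyBloomMatrix` and the helper `row_indices` are hypothetical, and `hashlib.blake2b` is only a stand-in hash. Neither is the library's `TorchBloomMatrix` or `BloomHasher`, which operate on tensors.

```python
import hashlib


def row_indices(label: int, k: int, m: int, seed: int = 0) -> list[int]:
    """k row positions for a label (illustrative blake2b hash, NOT BloomHasher)."""
    rows = []
    for i in range(k):
        digest = hashlib.blake2b(f"{seed}:{i}:{label}".encode(), digest_size=8).digest()
        rows.append(int.from_bytes(digest, "little") % m)
    return rows


class ToyBloomMatrix:
    """Hypothetical unpacked m x E bool table: rows = hash slots, columns = element IDs."""

    def __init__(self, num_elements: int, m: int, k: int, seed: int = 0):
        self.E, self.m, self.k, self.seed = num_elements, m, k, seed
        self.bits = [[False] * num_elements for _ in range(m)]

    def add(self, label: int, elements: list[int]) -> None:
        # Insert: OR the element bits into each of the label's k rows.
        for r in row_indices(label, self.k, self.m, self.seed):
            for e in elements:
                self.bits[r][e] = True

    def lookup_mask(self, label: int) -> list[bool]:
        # Lookup: AND the label's k rows column by column -> [E] candidate mask.
        rows = [self.bits[r] for r in row_indices(label, self.k, self.m, self.seed)]
        return [all(row[e] for row in rows) for e in range(self.E)]


bm = ToyBloomMatrix(num_elements=100, m=32, k=4)
bm.add(0, [1, 5, 9])
bm.add(1, [5, 10, 11])
mask = bm.lookup_mask(0)
# Always contains 1, 5, 9 (no false negatives); may also contain 10 or 11 if rows collide.
print(sorted(e for e in range(100) if mask[e]))
```

Every inserted element is guaranteed to survive the AND because all k of its bits were set. An element that was never inserted under this label only appears when each of the label's k rows was set for it by *other* labels, which is the false-positive case.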

### Example

Combine two “labels” (e.g. routes or topics) and ask which elements are allowed by **either** label:

```python
import torch
from bloom_torch import TorchBloomMatrix

# Label → set of element indices in universe [0, 100)
label_to_elements = {
    0: [1, 5, 9],
    1: [5, 10, 11],
    2: [0, 99],
}

bm = TorchBloomMatrix.from_label_index(
    label_to_elements,
    num_elements=100,
    p=0.05,
    seed=0,
    device="cpu",
)

# At query time: activate labels 0 and 2, OR their candidate sets
active_labels = torch.tensor([0, 2], dtype=torch.int64)
mask = bm.batch_lookup_mask(active_labels)   # [100] bool — True ≈ "allowed"
print("candidates (approx.):", int(mask.sum().item()))
```

`from_label_index` picks **m** and **k** from the average set size and target FP rate **p**, inserts every set, then **`pack()`** so lookups use fast **int64** bitwise ops.
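For intuition about how **m** and **k** follow from **p**, the sketch below uses the textbook Bloom-filter formulas m = ⌈−n·ln p / (ln 2)²⌉, k ≈ −log₂ p, and the false-positive approximation (1 − e^(−kn/m))^k. Here n is the per-filter load, taken to be the average set size as the paragraph above describes. The exact rounding inside `from_label_index` / `BloomHasher.optimal_m` is not shown in this README, so these helper functions are an assumption and not the library's code.

```python
import math


def optimal_m(n: int, p: float) -> int:
    """Textbook bit count for n insertions at target false-positive rate p."""
    return max(1, math.ceil(-n * math.log(p) / (math.log(2) ** 2)))


def optimal_k(p: float) -> int:
    """Textbook hash count minimising the FP rate: k = -log2(p), rounded."""
    return max(1, round(-math.log2(p)))


def expected_fp(n: int, m: int, k: int) -> float:
    """Standard approximation of the false-positive rate: (1 - e^{-kn/m})^k."""
    return (1.0 - math.exp(-k * n / m)) ** k


n, p = 3, 0.05  # average set size and target FP rate from the example above
m, k = optimal_m(n, p), optimal_k(p)
print(m, k, round(expected_fp(n, m, k), 4))
```

Halving **p** adds roughly n·ln 2 / (ln 2)² ≈ 1.44·n bits and one more hash function, which is why tight FP targets grow **m** only logarithmically.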

### How a matrix is built (manual path)

1. **Constructor** — `TorchBloomMatrix(num_elements, m, k, seed=..., device=...)`: allocates a bool matrix **`[m, E]`** and a `BloomHasher`.
2. **Inserts**
   - **`add(label_id, element_ids)`** — one label, 1-D tensor or list of element indices.
   - **`batch_add(label_ids, element_ids)`**: `label_ids` shape **`[L]`**, `element_ids` shape **`[L, T]`** (T slots per label). If a label has fewer than T elements, you must make the unused slots out-of-range IDs or mask them yourself before the insert.
3. **`pack()`** — folds the bool matrix into **`matrix_packed`** shape **`[m, ⌈E/64⌉]`** int64, frees the bool buffer. After **`pack()`**, you cannot **`add()`** again unless you rebuild.
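The packing step can be illustrated with a small pure-Python sketch that folds `[m, E]` bool rows into `[m, ⌈E/64⌉]` 64-bit words. The layout (element `e` stored at bit `e % 64` of word `e // 64`) is an assumption, because the README does not document the library's bit order. The helpers `pack_rows` and `unpack_row` are hypothetical.

```python
import math


def pack_rows(bits: list[list[bool]]) -> list[list[int]]:
    """Fold an [m, E] bool table into [m, ceil(E/64)] 64-bit words.

    Element e goes to word e // 64, bit e % 64 (assumed layout). Words are kept
    as unsigned Python ints; a torch int64 tensor would hold the same bit
    pattern, with bit 63 appearing as the sign bit.
    """
    E = len(bits[0])
    n_words = math.ceil(E / 64)
    packed = []
    for row in bits:
        words = [0] * n_words
        for e, bit in enumerate(row):
            if bit:
                words[e // 64] |= 1 << (e % 64)
        packed.append(words)
    return packed


def unpack_row(words: list[int], E: int) -> list[bool]:
    """Inverse of pack_rows for a single row."""
    return [bool((words[e // 64] >> (e % 64)) & 1) for e in range(E)]


E = 100
row = [e in (0, 63, 64, 99) for e in range(E)]
packed = pack_rows([row])
print(len(packed[0]))  # 2 words, since ceil(100 / 64) = 2
assert unpack_row(packed[0], E) == row
```

After packing, one bitwise AND processes 64 elements at once. This is where the speedup for decode-style lookups comes from, and it explains why the bool buffer can be dropped.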

The **factory** **`TorchBloomMatrix.from_label_index(mapping, num_elements, p=0.01, ...)`** wraps: build empty matrix → **`add`** for each `(label, elements)` → **`pack()`** → return.

### Lookups

- **`lookup_mask(label_id)`** — single label → **`[E]`** bool.
- **`batch_lookup_mask(label_ids)`** — **`[L]`** labels → **OR** of their per-label AND results → **`[E]`** bool (common for “top‑k active labels” at runtime).
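The "OR of per-label ANDs" semantics of `batch_lookup_mask` can be shown on packed words with the pure-Python sketch below. For a small hand-checkable example, the row assignments in `rows_of` are fixed by hand with k = 2 (a hypothetical stand-in for what `BloomHasher` would produce), and `batch_lookup_words` is a hypothetical helper.

```python
def to_word(elements: list[int]) -> int:
    """Pack element IDs (< 64) into one 64-bit word."""
    w = 0
    for e in elements:
        w |= 1 << e
    return w


def from_word(w: int, E: int) -> set[int]:
    return {e for e in range(E) if (w >> e) & 1}


def batch_lookup_words(packed: list[list[int]], label_rows: list[list[int]]) -> list[int]:
    """OR over active labels of the AND over each label's k packed rows."""
    W = len(packed[0])
    full = (1 << 64) - 1
    out = [0] * W
    for rows in label_rows:
        acc = [full] * W  # identity for AND
        for r in rows:
            acc = [a & w for a, w in zip(acc, packed[r])]
        out = [o | a for o, a in zip(out, acc)]  # OR across labels
    return out


# Hypothetical k=2 row assignments (in the library these come from hashing).
rows_of = {"A": [0, 2], "B": [1, 3], "C": [0, 3]}
sets = {"A": [1, 5, 9], "B": [5, 10, 11], "C": [2]}

m, E = 4, 64  # one 64-bit word per row
packed = [[0] for _ in range(m)]
for label, elems in sets.items():  # insert: OR element bits into each of the label's rows
    for r in rows_of[label]:
        packed[r][0] |= to_word(elems)

print(sorted(from_word(batch_lookup_words(packed, [rows_of["C"]])[0], E)))  # [2, 5]
print(sorted(from_word(batch_lookup_words(packed, [rows_of["A"], rows_of["C"]])[0], E)))  # [1, 2, 5, 9]
```

Label C only inserted element 2. Element 5 still shows up because row 0 (shared with A) and row 3 (shared with B) both had bit 5 set, which is a false positive. OR-ing labels can only widen the candidate set.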

---

## `BloomKMeans` — and how it uses `TorchBloomMatrix`

`BloomKMeans` clusters rows of a matrix **`X`** (`[n, d]`) into **K** centroids. After **`fit`**, each row has a cluster id **`assignments_[i]`**. The library’s main link to **`TorchBloomMatrix`** is the **routing index**: **cluster id → set of row indices** (e.g. vocabulary tokens in that cluster), encoded as a Bloom matrix for compact storage and fast OR-of-AND queries over several active clusters.

### Example

**`BloomKMeans(...)` does not take a `TorchBloomMatrix` (or any other matrix object).** Its `__init__` accepts only **hyperparameters**: `n_clusters`, `topk_cache`, `bm_fp_rate`, `routing_fp_rate`, `seed`. The **`TorchBloomMatrix` is constructed later**, inside **`build_routing_bloom`**, from the **`assignments_`** produced by **`fit`**. `routing_fp_rate` sets the false‑positive budget for **that** matrix when it is built; you never pass a matrix in.

```python
import torch
from bloom_torch import BloomKMeans

X = torch.randn(10_000, 128, dtype=torch.float32)

# No TorchBloomMatrix here — only K-means / Bloom *rates* and top-k cache size.
km = BloomKMeans(
    n_clusters=256,
    topk_cache=16,
    bm_fp_rate=0.01,
    routing_fp_rate=0.01,   # used when build_routing_bloom() calls from_label_index(..., p=...)
    seed=0,
)
km.fit(
    X,
    max_iters=20,
    use_bm_after=1,
    allow_bm_assign_small_k=False,
)

# HERE the TorchBloomMatrix is created (internally: TorchBloomMatrix.from_label_index(...))
routing_bm = km.build_routing_bloom(vocab_size=X.shape[0])
assert routing_bm is km.routing_bloom  # same object
# Decode: km.routing_bloom.batch_lookup_mask(cluster_ids)  -> [vocab_size] bool
```

### How `BloomKMeans` wires in `TorchBloomMatrix` (excerpts from `bloom_kmeans.py`)

`BloomKMeans` **does not subclass** `TorchBloomMatrix`. Instead, it **imports** it, keeps **`self.routing_bloom`**, and **constructs** matrices where needed. All snippets below are excerpts from **`src/bloom_torch/bloom_kmeans.py`** as of **v2.0.0**. If the code has moved in a later version, search the file by function name.

**1. Import**

```python
from .bloom_hash import BloomHasher
from .bloom_matrix import TorchBloomMatrix
```

**2. Field on the module — routing Bloom matrix (cluster → tokens), filled after `build_routing_bloom`**

```python
        # Populated after fit()
        self.centroids: Optional[Tensor] = None
        self.assignments_: Optional[Tensor] = None

        # Populated after build_routing_bloom()
        self.routing_bloom: Optional[TorchBloomMatrix] = None
```

**3. Main path — `build_routing_bloom`: invert `assignments_`, then `TorchBloomMatrix.from_label_index`**

Cluster id = Bloom **label**, token index = **element**. Result is stored on **`self.routing_bloom`** and returned.

```python
    def build_routing_bloom(self, vocab_size: int) -> TorchBloomMatrix:
        if self.assignments_ is None:
            raise RuntimeError("Call fit() before build_routing_bloom().")

        K = self.n_clusters
        assignments = self.assignments_
        if assignments.shape[0] > vocab_size:
            assignments = assignments[:vocab_size]
        device = assignments.device

        label_to_elements: dict[int, list[int]] = {c: [] for c in range(K)}
        for t in range(assignments.shape[0]):
            label_to_elements[int(assignments[t])].append(t)

        bm = TorchBloomMatrix.from_label_index(
            label_to_elements,
            num_elements=vocab_size,
            p=self.routing_fp_rate,
            seed=self.seed,
            device=device,
        )
        self.routing_bloom = bm
        return bm
```

**4. Secondary — `_build_kmeans_cache`: a different `TorchBloomMatrix` (token → top‑k cluster columns)**

Here **`num_elements=K`** (clusters) and each **label** is a **token row**, so **`batch_add`** encodes each row's nearest-cluster ids. The current **`fit`** loop usually uses **`_assign_vectorized`** instead of consulting this object at every step, but the snippet shows the same **`TorchBloomMatrix(...)` + `batch_add`** pattern:

```python
    def _build_kmeans_cache(self, X: Tensor, centroids: Tensor) -> TorchBloomMatrix:
        topk_clusters = self._compute_topk(X, centroids)           # [n, topk]
        n, K = X.shape[0], centroids.shape[0]
        actual_topk = topk_clusters.shape[1]
        m = BloomHasher.optimal_m(actual_topk, self.bm_fp_rate)
        k = BloomHasher.optimal_k(self.bm_fp_rate)
        bm = TorchBloomMatrix(
            num_elements=K, m=m, k=k, seed=self.seed, device=X.device
        )
        token_ids = torch.arange(n, dtype=torch.int64, device=X.device)
        bm.batch_add(token_ids, topk_clusters)
        return bm
```

**5. At decode time (your code)** — use the matrix returned / stored by **`build_routing_bloom`**:

```python
# km.build_routing_bloom(vocab_size)  # already called after fit
cluster_ids = torch.tensor([3, 41, 107], device=km.routing_bloom.matrix_packed.device)  # example top-k clusters
mask = km.routing_bloom.batch_lookup_mask(cluster_ids)  # [vocab_size] bool
```

During **`fit`**, when **K** is large and the top‑k assignment path runs, **`_compute_topk`** + **`_assign_vectorized`** avoid a full **`n × K`** distance matrix each iteration; the **persisted** `TorchBloomMatrix` for **routing** is still the one from **`build_routing_bloom`** above.
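The idea behind top‑k assignment is to pay for a full scan once, cache each point's nearest **topk** centroid ids, and then restrict later assignments to those candidates. The pure-Python sketch below illustrates this. It is not the library's `_compute_topk` / `_assign_vectorized`, whose internals and cache-refresh policy are not shown here. The helper names are hypothetical.

```python
import heapq


def sqdist(a: list[float], b: list[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def topk_candidates(X, C, topk):
    """Full pass (n x K work): ids of the topk nearest centroids per point."""
    return [heapq.nsmallest(topk, range(len(C)), key=lambda c: sqdist(x, C[c])) for x in X]


def assign_from_candidates(X, C, cand):
    """Cheap pass (n x topk work): nearest centroid among each point's cached candidates."""
    return [min(cs, key=lambda c: sqdist(x, C[c])) for x, cs in zip(X, cand)]


X = [[0.0, 0.0], [10.0, 10.0], [0.9, 0.0]]
C = [[0.0, 0.0], [1.0, 0.0], [10.0, 10.0], [9.0, 9.0]]
cand = topk_candidates(X, C, topk=2)  # [[0, 1], [2, 3], [1, 0]]

# Centroids drift a little after an update; reuse the cached candidates.
C2 = [[0.0, 0.0], [0.5, 0.0], [10.0, 10.0], [9.5, 9.5]]
print(assign_from_candidates(X, C2, cand))  # [0, 2, 1]
```

The cheap pass matches a full `argmin` as long as each point's true nearest centroid stays in its cached candidate list. When centroids move far, the candidates have to be refreshed or the result becomes approximate.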

---

## Applications that benefit

Use **bloom-torch** when your problem has **many discrete IDs** (tokens, items, users, keys) and you want **space‑efficient sets** or **cheaper nearest‑cluster work** inside a **PyTorch** stack (CPU / CUDA / MPS).

| Application pattern | What you use | Why it helps |
|---------------------|--------------|--------------|
| **Constrained decoding / logits masking** (e.g. LLMs) | `BloomKMeans` + `build_routing_bloom` | Partition the vocabulary (or items) into **K** clusters once; at each step, activate only **top‑k** clusters and **OR** their Bloom‑encoded token sets → small candidate set before full softmax or matmul. |
| **Large‑K K‑means on embeddings** | `BloomKMeans` (`allow_bm_assign_small_k` / large **K**) | Avoids materialising full **n × K** distances every iteration when **K** is huge; uses **top‑k** candidates per point plus vectorised distance chunks. |
| **Many‑to‑many label → item indices** | `TorchBloomMatrix` | Encode **millions** of “label → allowed elements” edges in a compact bit matrix; **batch_lookup_mask** composes several labels with **OR‑of‑AND** (multifilter style). |
| **GPU‑resident filters** | `BloomHasher` + `TorchBloomMatrix` | Same math as classic Bloom filters but **tensor** ops—no Python hot loops, same device as your model or data pipeline. |
| **Prototyping routing / gating** | `BloomKMeans` + routing BM | Fast iteration when the exact router can change often; Bloom **false positives** only widen candidates (safe if a downstream step re‑scores or clips). |
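The constrained-decoding row can be illustrated with a toy pure-Python sketch: score clusters, activate the top‑k, OR their token sets, and mask everything else to −∞ before choosing a token. Exact Python sets stand in for the Bloom-encoded sets here, and `allowed_tokens` / `masked_argmax` are hypothetical helpers. With the library, the union would come from `routing_bloom.batch_lookup_mask(...)` and could include a few false-positive tokens.

```python
import math


def allowed_tokens(cluster_scores, cluster_tokens, top_k):
    """Union (OR) of the token sets of the top_k highest-scoring clusters."""
    order = sorted(range(len(cluster_scores)), key=lambda c: cluster_scores[c], reverse=True)
    allowed = set()
    for c in order[:top_k]:
        allowed |= cluster_tokens[c]
    return allowed


def masked_argmax(logits, allowed):
    """Greedy pick restricted to allowed token ids; all others get -inf."""
    masked = [x if t in allowed else -math.inf for t, x in enumerate(logits)]
    return max(range(len(masked)), key=masked.__getitem__)


cluster_tokens = {0: {0, 1}, 1: {2, 3}, 2: {4, 5}}  # toy vocabulary of 6 tokens
cluster_scores = [0.1, 2.0, 1.5]
logits = [9.0, 0.5, 1.0, 3.0, 2.0, 0.1]

allowed = allowed_tokens(cluster_scores, cluster_tokens, top_k=2)  # {2, 3, 4, 5}
print(masked_argmax(logits, allowed))  # 3: token 0 has the top logit but is filtered out
```

Because Bloom false positives only add tokens to `allowed`, the downstream scoring step still sees every truly allowed token.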

### When something else may be better

- **Very small K or n** — plain `argmin` over all centroids is simpler; the Bloom / top‑k machinery adds little.
- **You need exact cluster–token sets with zero extra candidates** — use a dense index or hash map; Bloom routing trades a controlled **false‑positive rate** for memory.
- **Non‑vector / non‑torch pipelines** — classic **mmh3** + CPU Bloom libraries may be enough; **bloom-torch** pays off when tensors and accelerators are already in the loop.

---

> **Not the BLOOM LLM**  
> This package is **not** related to the [**BLOOM**](https://huggingface.co/bigscience/bloom) multilingual language model or other “BLOOM” names in Hugging Face. Here **Bloom** means [**Bloom filters**](https://en.wikipedia.org/wiki/Bloom_filter) (Burton H. Bloom, 1970).

---

## Public API (v2.0.0)

| Symbol | Role |
|--------|------|
| `BloomHasher` | **k** independent hash functions over int64 IDs; tensor-native. |
| `TorchBloomMatrix` | Add / lookup with optional int64 **packing** for smaller memory and faster decode-style ANDs. |
| `BloomKMeans` | `fit`, `predict`, `build_routing_bloom`; optional `allow_bm_assign_small_k` and `use_bm_after` control when full `[n, K]` Lloyd runs vs top‑k assignment. |
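The pure-Python sketch below shows what "k independent hash functions over int64 IDs" from a universal family can look like. It uses the classic Carter–Wegman construction h_i(x) = ((a_i·x + b_i) mod P) mod m with the Mersenne prime P = 2⁶¹ − 1. Whether `BloomHasher` uses this exact family is an assumption, and `ToyUniversalHasher` is a hypothetical name.

```python
import random

P = (1 << 61) - 1  # Mersenne prime; IDs are assumed to lie in [0, P)


class ToyUniversalHasher:
    """Hypothetical k-function Carter-Wegman family: h_i(x) = ((a_i*x + b_i) mod P) mod m."""

    def __init__(self, k: int, m: int, seed: int = 0):
        rng = random.Random(seed)  # seeded so the same IDs always map to the same rows
        self.m = m
        self.params = [(rng.randrange(1, P), rng.randrange(0, P)) for _ in range(k)]

    def __call__(self, ids: list[int]) -> list[list[int]]:
        """Batched hashing: returns a [len(ids)][k] list of row indices in [0, m)."""
        return [[((a * x + b) % P) % self.m for a, b in self.params] for x in ids]


h = ToyUniversalHasher(k=4, m=32, seed=0)
rows = h([0, 1, 2, 12345])
print(rows)
```

Python integers never overflow, so `a * x` is exact here. A tensor-native int64 version has to reduce the product differently (for example, with split multiplications), which is one reason a library implementation may not match this sketch line by line.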

Import the package and print its version:

```python
import bloom_torch
print(bloom_torch.__version__)
```

---

## Reference

The Bloom matrix / multi-filter idea used by `TorchBloomMatrix` is related to:

Francesco Concas, Pengfei Xu, Mohammad A. Hoque, Jiaheng Lu, and Sasu Tarkoma. 2020. **Multiple Set Matching with Bloom Matrix and Bloom Vector**. *ACM Transactions on Knowledge Discovery from Data*. [https://dl.acm.org/doi/fullHtml/10.1145/3372409](https://dl.acm.org/doi/fullHtml/10.1145/3372409)

---

## Development

From a clone of this repository:

```bash
pip install -e ".[dev]"
python -c "from bloom_torch import BloomKMeans; import torch; print('ok')"
```

Build wheels / sdist locally:

```bash
pip install build
python -m build
```

Releases are intended to be published from GitHub Actions (see `.github/workflows/publish.yml`) using **PyPI trusted publishing**. Set `[project.urls]` in `pyproject.toml` to your real repository before tagging.

---

## Versioning

- **2.0.0** — fixes a critical `BloomKMeans` clustering bug where centroid updates used cluster sums instead of cluster means, which could cause non-convergence and cluster collapse on larger vocabularies. It also removes the unfinished dense-routing API from `BloomKMeans`, keeps routing on `build_routing_bloom`, and documents the Bloom Matrix reference.
- **1.0.1** — clarified the BloomKMeans ↔ `TorchBloomMatrix` API flow.
- **1.0.0** — stabilised the three-module surface.
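The 2.0.0 centroid fix is easiest to see on a toy example. With sums, a centroid lands at a point that scales with cluster size and can leave the cluster entirely; with means, it stays inside. The pure-Python sketch below shows the mean-based Lloyd update. Keeping the old centroid for empty clusters is one common convention and an assumption here; the README does not say what `BloomKMeans` does in that case. `update_centroids` is a hypothetical helper.

```python
def update_centroids(X, assignments, old_centroids):
    """Lloyd update: each centroid becomes the MEAN of its assigned points.

    Empty clusters keep their previous centroid (assumed convention).
    """
    K, d = len(old_centroids), len(old_centroids[0])
    sums = [[0.0] * d for _ in range(K)]
    counts = [0] * K
    for x, c in zip(X, assignments):
        counts[c] += 1
        for j in range(d):
            sums[c][j] += x[j]
    return [
        [s / counts[c] for s in sums[c]] if counts[c] else list(old_centroids[c])
        for c in range(K)
    ]


X = [[0.0, 0.0], [2.0, 0.0], [4.0, 0.0], [10.0, 10.0]]
assignments = [0, 0, 0, 1]
old = [[5.0, 5.0], [5.0, 5.0], [7.0, 7.0]]
print(update_centroids(X, assignments, old))  # [[2.0, 0.0], [10.0, 10.0], [7.0, 7.0]]
# A sum-based update would put cluster 0 at [6.0, 0.0], outside its points' range [0, 4].
```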

Semver bumps signal API or behaviour changes, so read the changelog before upgrading (add a `CHANGELOG.md` once you start cutting releases).

---

## License

MIT — see [`LICENSE`](LICENSE).
