Metadata-Version: 2.4
Name: siglip-kernel
Version: 0.1.1
Summary: Fused Triton kernels for memory-efficient SigLIP and SigLIP+ training
License: Apache-2.0
Project-URL: Repository, https://github.com/avocardio/siglip-kernel
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: torch>=2.1.0
Requires-Dist: triton>=3.0.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"

# siglip-kernel

Fused Triton kernels for memory-efficient SigLIP training. This is the first public fused GPU kernel for the sigmoid contrastive loss. It computes the loss tile by tile and never materializes the `B × B` logit matrix, which reduces peak loss memory from `O(B²)` to `O(B · D)` and enables per-device batch sizes that the standard implementation cannot allocate.
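The tiling idea can be illustrated with a minimal pure-Python sketch. It assumes the SigLIP paper's loss (per-pair sigmoid losses with label +1 on the diagonal and -1 elsewhere, summed and divided by `B`) and an already-positive temperature `t`. The function names are hypothetical and this is not the package's Triton kernel: it only shows that reducing one `chunk × B` tile of logits at a time gives the same value as reducing the full matrix. It covers the forward arithmetic only; a naive autograd version of this loop would still keep every tile for the backward pass.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def softplus(v: float) -> float:
    # Numerically stable log(1 + exp(v)); note -log sigmoid(u) == softplus(-u).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def dot(a: List[float], b: List[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


def siglip_loss_full(x: Matrix, y: Matrix, t: float, b: float) -> float:
    """Reference: materializes all B x B logits before reducing."""
    B = len(x)
    logits = [[t * dot(x[i], y[j]) + b for j in range(B)] for i in range(B)]
    total = 0.0
    for i in range(B):
        for j in range(B):
            z = 1.0 if i == j else -1.0  # +1 for the matching pair, -1 otherwise
            total += softplus(-z * logits[i][j])
    return total / B


def siglip_loss_chunked(x: Matrix, y: Matrix, t: float, b: float, chunk: int) -> float:
    """Same loss, but only one chunk x B tile of logits exists at a time."""
    B = len(x)
    total = 0.0
    for start in range(0, B, chunk):
        rows = range(start, min(start + chunk, B))
        tile = [[t * dot(x[i], y[j]) + b for j in range(B)] for i in rows]
        for r, i in enumerate(rows):
            for j in range(B):
                z = 1.0 if i == j else -1.0  # diagonal tracked via the global row index i
                total += softplus(-z * tile[r][j])
        # `tile` is rebound on the next iteration; no B x B buffer is ever kept
    return total / B


random.seed(0)
B, D = 10, 4
x = [[random.gauss(0, 1) for _ in range(D)] for _ in range(B)]
y = [[random.gauss(0, 1) for _ in range(D)] for _ in range(B)]
print(siglip_loss_full(x, y, t=10.0, b=-10.0))
print(siglip_loss_chunked(x, y, t=10.0, b=-10.0, chunk=3))
```

The only state carried across chunks is the scalar running sum, so the chunk size trades tile memory against the number of passes without changing the result.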

On Hopper (GH200/H100) at `B=65,536` in BF16, the kernel runs the full loss forward+backward in 48.8 ms versus 102.0 ms for the eager cuBLAS reference (2.1× faster), using 1.4 GB versus 33 GB; the minimum-memory configuration (`chunk_size=1024`) uses 640 MB (52× less) at 54.3 ms. It also outperforms a `torch.compile`d reference for `B ≥ 32,768`, and reaches `B=262,144` on 5.6 GB where both references OOM. An opt-in FP8 forward (`backend="chunked_bf16_fp8"`) adds ~5% with gradient error below the BF16 floor. See the [paper](https://github.com/avocardio/fusedsiglip_paper) for full benchmarks and validation (3-epoch CC12M pre-training, 500-step LiT, parity tests).
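For a sense of scale, the arithmetic below sizes a single dense `B × B` logit buffer against a single `chunk × B` tile at the benchmark batch size. It counts one buffer of each kind and is not a model of the measured peaks above, which also include other tensors; `D=768` is an assumption borrowed from the multi-GPU table.

```python
B = 65_536     # batch size from the benchmark above
D = 768        # embedding width (assumed, taken from the multi-GPU table)
CHUNK = 1024   # minimum-memory chunk size

MiB, GiB = 2**20, 2**30


def nbytes(rows: int, cols: int, bytes_per_elem: int) -> int:
    """Size in bytes of a dense rows x cols buffer."""
    return rows * cols * bytes_per_elem


full_bf16 = nbytes(B, B, 2)      # one dense B x B logit matrix, BF16
full_fp32 = nbytes(B, B, 4)      # the same matrix, FP32
tile_bf16 = nbytes(CHUNK, B, 2)  # one CHUNK x B logit tile, BF16
feat_bf16 = nbytes(B, D, 2)      # one B x D feature matrix, BF16

print(f"B x B logits (BF16):   {full_bf16 / GiB:.0f} GiB")
print(f"B x B logits (FP32):   {full_fp32 / GiB:.0f} GiB")
print(f"chunk x B tile (BF16): {tile_bf16 / MiB:.0f} MiB")
print(f"B x D features (BF16): {feat_bf16 / MiB:.0f} MiB")
```

A dense BF16 logit matrix alone is 8 GiB (16 GiB in FP32) at `B=65,536`, while a `1024 × B` tile is 128 MiB in BF16 and a `B × 768` BF16 feature matrix is 96 MiB.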

## Install

```bash
git clone https://github.com/avocardio/siglip-kernel.git
cd siglip-kernel
pip install -e .
```

Requires Python ≥ 3.10, PyTorch ≥ 2.1, and Triton ≥ 3.0.

## Usage

```python
import torch
from siglip_kernel import fused_siglip_loss

img = torch.randn(8192, 768, device="cuda", dtype=torch.bfloat16)
txt = torch.randn(8192, 768, device="cuda", dtype=torch.bfloat16)
img = torch.nn.functional.normalize(img, dim=-1)
txt = torch.nn.functional.normalize(txt, dim=-1)
log_temp = torch.tensor(2.302, device="cuda", requires_grad=True)  # ln(10)
bias     = torch.tensor(-10.0, device="cuda", requires_grad=True)

loss = fused_siglip_loss(img, txt, log_temp, bias)
loss.backward()
```
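The initial values in the usage example match the SigLIP paper's initialization (`t' = log 10`, `b = -10`). Assuming, as the name and the `# ln(10)` comment suggest, that the loss scales logits by `t = exp(log_temp)`, the short calculation below shows what these values mean for unit-norm features: a perfectly aligned pair starts near probability 0.5 and unrelated pairs start near zero, which suits a batch with `B - 1` negatives per positive.

```python
import math

log_temp = 2.302   # value from the usage example (ln 10 = 2.302585...)
bias = -10.0

# Assumption: the loss scales logits by t = exp(log_temp).
t = math.exp(log_temp)


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


# Unit-norm features give cosine similarity in [-1, 1], so at initialization
# the logit t * cos + b lies in roughly [-20, 0].
for cos in (1.0, 0.0, -1.0):
    logit = t * cos + bias
    print(f"cos={cos:+.0f}  logit={logit:+7.3f}  p(match)={sigmoid(logit):.2e}")
```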

A dtype-aware router selects the right backend automatically: chunked-BF16 (cuBLAS GEMM + Triton fused BCE with the temperature/bias affine folded in-kernel) for `bfloat16`/`float16`, fully-fused Triton for `float32`. Pass `backend=` to override (`"chunked_bf16"`, `"chunked_bf16_fp8"`, `"chunked_epilogue"`, `"fused"`, `"hopper"`, `"fused_v2"`, or `"reference"`). The chunk size is adaptive (speed-optimal `c=4096` on SM90+); pass `chunk_size=1024` to the loss functions directly for the minimum-memory configuration. The forward path is free of host syncs and CUDA-graph capturable.
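The routing rule described above can be restated as a small function. This is a hypothetical sketch, not the package's router: `select_backend`, passing dtype names as strings, and treating `"fused"` as the name of the fully-fused `float32` path are assumptions. The adaptive chunk-size choice is not modeled.

```python
from typing import Optional

# Backend names as listed in the paragraph above.
BACKENDS = frozenset({
    "chunked_bf16", "chunked_bf16_fp8", "chunked_epilogue",
    "fused", "hopper", "fused_v2", "reference",
})


def select_backend(dtype: str, backend: Optional[str] = None) -> str:
    """Hypothetical restatement of the documented dtype-based routing."""
    if backend is not None:
        # An explicit `backend=` override always wins.
        if backend not in BACKENDS:
            raise ValueError(f"unknown backend {backend!r}")
        return backend
    if dtype in ("bfloat16", "float16"):
        # cuBLAS GEMM + Triton fused BCE with the temperature/bias affine folded in.
        return "chunked_bf16"
    if dtype == "float32":
        # Fully-fused Triton path (assumed to be the "fused" backend).
        return "fused"
    raise ValueError(f"no default backend for dtype {dtype!r}")


print(select_backend("bfloat16"))                      # chunked_bf16
print(select_backend("float32"))                       # fused
print(select_backend("bfloat16", "chunked_bf16_fp8"))  # chunked_bf16_fp8
```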

## SigLIP+ (multi-positive)

`fused_siglip_plus_loss` is the same chunked fused kernel applied to SigLIP+, the multi-positive formulation where all image-text pairs sharing a group identifier are positives and positive mass is normalized per anchor (1/k weighting, both directions averaged when symmetric). Logits may be rectangular (`M` captions for `N` images); with unique group ids it reduces to standard SigLIP.

```python
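import torch
from siglip_kernel import fused_siglip_plus_loss, SigLIPPlusLoss

# img, txt, log_temp, bias as in the example above. Illustrative group ids:
# pair (i, j) is positive when img_gid[i] == txt_gid[j] (here, 2 per group).
img_gid = torch.arange(img.shape[0], device="cuda") // 2
txt_gid = torch.arange(txt.shape[0], device="cuda") // 2

# functional: pre-normalized features, group ids mark positives
loss = fused_siglip_plus_loss(img, txt, log_temp, bias, img_gid, txt_gid)

# or the drop-in module (own logit_scale/bias, reference-compatible dict output);
# img_raw / txt_raw are placeholders for your encoder outputs
loss_fn = SigLIPPlusLoss(temperature=0.07)
out = loss_fn([img_raw, txt_raw], img_gid=img_gid, txt_gid=txt_gid)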
```

Verified to match the reference implementation exactly (FP32 gradients agree to 2e-7) across symmetric and asymmetric modes, soft labels, rectangular shapes, and rows without positives. The eager reference materializes the logit, mask, and weight matrices; the fused path stays `O(chunk · M)`.
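To make the group-id semantics concrete, here is a minimal pure-Python sketch of one plausible reading of the SigLIP+ description. Assumptions not confirmed by the text above: negatives keep weight 1, each direction is averaged over its anchors, and soft labels are not modeled; `siglip_plus_loss` and its helper are hypothetical names, not the package's API. The positives of each anchor share a total weight of 1 (`1/k` each), an anchor with no positives contributes only negative terms, and with unique ids on a square batch the result reduces to the standard SigLIP loss.

```python
import math
from typing import Hashable, List, Sequence


def softplus(v: float) -> float:
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def _anchored_loss(logits: List[List[float]], anchor_gid: Sequence[Hashable],
                   other_gid: Sequence[Hashable]) -> float:
    """Each row of `logits` is one anchor; columns sharing its group id are positives."""
    total = 0.0
    for a, row in enumerate(logits):
        is_pos = [anchor_gid[a] == g for g in other_gid]
        k = sum(is_pos)  # number of positives for this anchor (may be 0)
        for s, p in zip(row, is_pos):
            # positive: -log sigmoid(s) with weight 1/k; negative: -log(1 - sigmoid(s))
            total += softplus(-s) / k if p else softplus(s)
    return total / len(logits)


def siglip_plus_loss(logits: List[List[float]], img_gid: Sequence[Hashable],
                     txt_gid: Sequence[Hashable], symmetric: bool = True) -> float:
    """logits[i][j] = t * <img_i, txt_j> + b for N images and M captions."""
    img_side = _anchored_loss(logits, img_gid, txt_gid)
    if not symmetric:
        return img_side
    transposed = [list(col) for col in zip(*logits)]  # M x N, captions as anchors
    txt_side = _anchored_loss(transposed, txt_gid, img_gid)
    return 0.5 * (img_side + txt_side)


# Two images, four captions: captions 0-1 share group 0 with image 0; captions 2-3
# are group 1, which has no image here; image 1 (group 7) has no positives.
logits = [[2.0, 1.5, -3.0, -2.5],
          [-1.0, 0.5, -2.0, -4.0]]
print(siglip_plus_loss(logits, img_gid=[0, 7], txt_gid=[0, 0, 1, 1]))
```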

## Multi-GPU

`DistributedSigLIPLoss` is a drop-in for OpenCLIP's `SigLipLoss` signature (it receives the already-exponentiated `logit_scale`). The ring strategy exchanges text blocks peer-to-peer, runs every `local_B × local_B` block through the fused kernel, and in the backward pass propagates text gradients back around the ring. Communication is overlapped with block compute in both passes (`overlap=True`, the default):

```python
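import torch.distributed as dist
from siglip_kernel.distributed import DistributedSigLIPLoss

# Assumes the process group is already initialized (e.g. launched with torchrun).
rank, world_size = dist.get_rank(), dist.get_world_size()
loss_fn = DistributedSigLIPLoss(strategy="ring", rank=rank, world_size=world_size)
# Local per-rank features; logit_scale is already exponentiated (OpenCLIP convention).
loss = loss_fn(image_features, text_features, logit_scale, logit_bias)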
```
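The ring schedule can be simulated in a single process to check that summing per-block losses recovers the full-batch loss. This sketch assumes equal shard sizes, the SigLIP paper's normalization (divide by the global batch size), and an already-exponentiated `t` as in OpenCLIP. The function names are hypothetical, and it models neither communication, overlap, nor the backward pass. It also shows why diagonal placement matters: only the block a rank computes against its own text shard contains positives.

```python
import math
import random
from typing import List

Shard = List[List[float]]


def softplus(v: float) -> float:
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def block_loss_sum(img: Shard, txt: Shard, t: float, b: float, diagonal: bool) -> float:
    """Sum of per-pair sigmoid losses for one local_B x local_B block.
    `diagonal` marks the block that holds the matching pairs (i == j)."""
    total = 0.0
    for i, x in enumerate(img):
        for j, y in enumerate(txt):
            s = t * sum(p * q for p, q in zip(x, y)) + b
            z = 1.0 if (diagonal and i == j) else -1.0
            total += softplus(-z * s)
    return total


def ring_siglip_loss(img_shards: List[Shard], txt_shards: List[Shard],
                     t: float, b: float) -> float:
    """At step k, rank r holds the text shard that started on rank (r - k) mod W."""
    W = len(img_shards)
    held = list(txt_shards)  # held[r]: text shard currently on rank r
    src = list(range(W))     # src[r]: rank that shard originated on
    partial = [0.0] * W
    for _ in range(W):
        for r in range(W):
            partial[r] += block_loss_sum(img_shards[r], held[r], t, b,
                                         diagonal=(src[r] == r))
        # Pass shards along the ring: rank r receives from rank r - 1.
        held = [held[(r - 1) % W] for r in range(W)]
        src = [src[(r - 1) % W] for r in range(W)]
    global_b = sum(len(s) for s in img_shards)
    return sum(partial) / global_b  # a real run would all-reduce the partial sums


def rand_vecs(n: int, d: int) -> Shard:
    return [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]


random.seed(0)
W, local_B, D = 4, 3, 8
img_shards = [rand_vecs(local_B, D) for _ in range(W)]
txt_shards = [rand_vecs(local_B, D) for _ in range(W)]
print(ring_siglip_loss(img_shards, txt_shards, t=10.0, b=-10.0))
```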

Verified to match a single-process full-batch reference exactly (FP32 gradients agree to 1.5e-6) on 4 and 8 GH200s. Loss step time (forward + backward, D=768, BF16):

| global B | 1 GPU | 4 GPU (NVLink) | 8 GPU (2 nodes) |
|----------|-------|----------------|------------------|
| 65,536   | 46.9 ms | 13.1 ms | 10.1 ms |
| 131,072  | 193.4 ms | — | 31.5 ms |
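The table can also be read as speed-ups over the single-GPU time. The short calculation below only rearranges the numbers above; it adds no new measurements.

```python
# Loss step times in ms, copied from the table above (BF16, D=768).
one_gpu = {65_536: 46.9, 131_072: 193.4}
multi_gpu = {(65_536, 4): 13.1, (65_536, 8): 10.1, (131_072, 8): 31.5}

for (global_b, gpus), ms in multi_gpu.items():
    speedup = one_gpu[global_b] / ms
    print(f"B={global_b:>7,}  {gpus} GPUs: {speedup:.2f}x speed-up, "
          f"{speedup / gpus:.0%} parallel efficiency")

# On one GPU, doubling B should roughly quadruple the O(B^2) loss work.
ratio = one_gpu[131_072] / one_gpu[65_536]
print(f"1-GPU step-time ratio, 131,072 vs 65,536: {ratio:.2f}")
```

This gives about 3.6× on 4 GPUs and 4.6× on 8 GPUs (2 nodes) at `B=65,536`, and 6.1× on 8 GPUs at `B=131,072`. On one GPU, doubling `B` multiplies the step time by about 4.1, close to the 4× expected from `O(B²)` work. The better 8-GPU scaling at the larger batch is plausibly because per-block compute grows with `local_B²` while each exchanged text block grows only with `local_B`.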

Overlapping communication with compute improves step time by ~5% on NVLink and by 13–21% across nodes. On Slurm/HPE-Slingshot clusters, load the `aws-ofi-nccl` plugin; without it, NCCL falls back to TCP and multi-node ring steps run ~7× slower.

## GPU support

| Architecture | SM | Status |
|--------------|----|--------|
| Ampere (A100) | SM80 | Tested, autotuned tiles |
| Ada (RTX 4090, L40) | SM89 | Tested, FP8 path available |
| Hopper (H100, H200) | SM90 | Tested, persistent kernel + TMA |
| Blackwell (B200) | SM100 | Tested via Triton 3.7, no code changes |

## OpenCLIP integration

For OpenCLIP users, a dependency-free, pure-PyTorch chunked variant of `SigLipLoss` has been proposed upstream as [PR #1145](https://github.com/mlfoundations/open_clip/pull/1145). That PR addresses memory only; to get the speed-up as well, install this package and use `fused_siglip_loss` directly.

## Tests

```bash
pip install -e ".[dev]"
pytest tests/
```

119 tests cover correctness (forward + backward parity vs FP32 reference), gradient stability, diagonal-offset placement for distributed blocks, FP8/epilogue backends, and edge shapes (odd batches, uneven chunks, offsets crossing chunk boundaries). The multi-process NCCL paths are additionally validated against a full-batch reference on 4 and 8 GH200s.

## Citation

Paper source at [avocardio/fusedsiglip_paper](https://github.com/avocardio/fusedsiglip_paper); preprint forthcoming.

## License

Apache-2.0.
