Metadata-Version: 2.1
Name: adamw_bf16
Version: 0.0.2
Summary: AdamW Optimizer for bfloat16
Author-email: APJC <apjc@usa.com>
Project-URL: Homepage, https://github.com/AmericanPresidentJimmyCarter/adamw-bf16
Project-URL: Issues, https://github.com/AmericanPresidentJimmyCarter/adamw-bf16/issues
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: GNU Affero General Public License v3
Classifier: Operating System :: OS Independent
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE.txt
Requires-Dist: torch >=2.1.0

# AdamW optimizer for bfloat16 in PyTorch

This is an AdamW optimizer for PyTorch that keeps the model weights in bfloat16 yet, in ViT training tests, matches the results of training with float32 weights (whether the operations themselves run in float32 or in bfloat16 via autocast). Storing the weights in bfloat16 roughly halves the memory they would otherwise occupy. It achieves this using [stochastic rounding and a correction term](https://arxiv.org/pdf/2010.06192.pdf).
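To illustrate the core idea, here is a minimal sketch of stochastic rounding from float32 to bfloat16. The helper name and implementation are illustrative only, not part of this package's API: bfloat16 keeps the top 16 bits of a float32, so adding uniform random noise to the low 16 bits before truncating rounds up with probability proportional to the discarded fraction, making the rounding unbiased in expectation.

```python
import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: stochastically round float32 -> bfloat16.

    bfloat16 is the top 16 bits of a float32. Adding a random integer
    in [0, 2**16) to the raw bits before zeroing the low 16 bits rounds
    up with probability equal to the discarded fraction, so the result
    is unbiased in expectation (unlike truncation or round-to-nearest).
    """
    bits = x.contiguous().view(torch.int32)        # reinterpret the bits
    noise = torch.randint_like(bits, 0, 1 << 16)   # uniform low-bit noise
    rounded = (bits + noise) & ~0xFFFF             # carry up, then truncate
    return rounded.view(torch.float32).to(torch.bfloat16)

# 1.0 + 2**-12 is not representable in bfloat16; stochastic rounding
# maps it to 1.0 most of the time and to the next value up occasionally,
# preserving the mean across many updates.
x = torch.full((8,), 1.0 + 2**-12)
print(stochastic_round_to_bf16(x))
```

Averaged over many steps, the small update that plain truncation would silently drop still moves the weight, which is why bfloat16 training with this trick can track float32 training.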

There is a small performance hit (roughly 10-20%, depending on your hardware).

To install:

```sh
pip install adamw-bf16
```

To use:

```py
import torch

from adamw_bf16 import AdamWBF16

model = model.to(dtype=torch.bfloat16)
optimizer = AdamWBF16(model.parameters(), ...)

# Train your model
```

This repository was created using code from the following two projects; it turned out that combining insights from both matches the performance of training with the model weights stored in float32.

- [adamw_bfloat16](https://github.com/arogozhnikov/adamw_bfloat16)
- [OneTrainer](https://github.com/Nerogar/OneTrainer)
