Metadata-Version: 2.1
Name: arlatentsde
Version: 0.2.0
Summary: Amortized Reparametrization for Continuous Time Autoencoders (ARCTA)
License: GNUv3
Author: Kevin Course
Requires-Python: >=3.10,<3.12
Classifier: License :: Other/Proprietary License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Dist: beartype (>=0.12.0,<0.13.0)
Requires-Dist: jaxtyping (>=0.2.14,<0.3.0)
Requires-Dist: pytorch-lightning (>=2.0.1.post0,<3.0.0)
Requires-Dist: tensorboard (>=2.12.1,<3.0.0)
Requires-Dist: torch (>=2.0.1,<3.0.0)
Requires-Dist: torchaudio (>=2.0.2,<3.0.0)
Requires-Dist: torchdiffeq (>=0.2.3,<0.3.0)
Requires-Dist: torchsde (>=0.2.5,<0.3.0)
Requires-Dist: torchvision (>=0.15.2,<0.16.0)
Description-Content-Type: text/markdown

# Amortized Reparametrization: Efficient and Scalable Variational Inference for Latent SDEs

Accompanying code for the [NeurIPS 2023
paper](https://openreview.net/forum?id=5yZiP9fZNv)
by Kevin Course and Prasanth B. Nair.

**Tutorials and documentation coming soon!**

## 1. Installation

### Installing the package

The package can be installed from PyPI:

```bash
pip install arlatentsde
```

### Reproducing the experiment environment

We ran experiments on a Linux machine with CUDA 11.8.
We used [poetry](https://github.com/python-poetry/poetry) to manage dependencies.

If you prefer a different environment manager, all dependencies are listed
in `pyproject.toml`.

To reproduce the experiment environment, first check out the branch named
`neurips-freeze`.
Then install all optional dependencies required to run the experiments:

```bash
poetry install --with dev,exps
```

We use [repopacker](https://github.com/coursekevin/repopacker) to download all pretrained models, datasets, and figures:

```bash
repopacker download models-data-figs.zip
repopacker unpack models-data-figs.zip
```
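The reproduction steps above can be bundled into a single shell function. This is a minimal sketch rather than a script shipped with the repository: `setup_experiment_env` is a hypothetical name, and it assumes you are at the root of an existing clone, that `git` and `poetry` are installed, and that `repopacker` is on your `PATH` once the dependencies are installed.

```bash
# Hypothetical helper (not shipped with arlatentsde): reproduce the experiment
# environment from the root of an existing clone of the repository.
# Assumes repopacker is on PATH after the dependencies are installed.
setup_experiment_env() {
  local tool
  # Fail early with a clear message if git or poetry is missing.
  for tool in git poetry; do
    if ! command -v "$tool" >/dev/null 2>&1; then
      echo "missing required tool: $tool" >&2
      return 1
    fi
  done
  # Switch to the branch frozen for the NeurIPS 2023 experiments.
  git checkout neurips-freeze || return 1
  # Install the optional dev and exps dependency groups.
  poetry install --with dev,exps || return 1
  # Fetch and unpack the pretrained models, datasets, and figures.
  repopacker download models-data-figs.zip || return 1
  repopacker unpack models-data-figs.zip
}
```

Each step returns early on failure, so a failed checkout or install does not leave you unpacking data into the wrong environment.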

## 2. Usage

The numerical studies can be rerun from the experiments
directory using the command-line script `main.py`. All numerical
studies follow the same basic structure:
(i) generate or download data,
(ii) train models, and
(iii) post-process results for plots and tables.

The script has the following syntax:

```bash
python main.py [experiment] [action]
```

The choices of experiments and actions are provided below:

- Experiments:
  - `predprey`: Orders of magnitude fewer NFEs experiment
  - `lorenz`: Adjoint instabilities experiment
  - `mocap`: Motion capture benchmark
  - `nsde-video`: Neural SDE from video experiment
  - `grad-variance`: Gradient variance experiment
- Actions:
  - `get-data`: Download / generate data
  - `train`: Train models
  - `post-process`: Post-process results for plots and tables

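Because every study runs the same three actions in the same order, a full study can be scripted. The following is a minimal sketch, not part of the package: `run_study` is a hypothetical name, and it assumes it is called from the experiments directory with the project environment active, so that `python` resolves to the interpreter where the dependencies are installed.

```bash
# Hypothetical helper (not part of arlatentsde): run every stage of one
# numerical study in order. Assumes the current directory contains main.py.
run_study() {
  local experiment="$1"
  # Accept only the experiment names listed in this README.
  case "$experiment" in
    predprey|lorenz|mocap|nsde-video|grad-variance) ;;
    *)
      echo "unknown experiment: $experiment" >&2
      return 1
      ;;
  esac
  # (i) get data, (ii) train, (iii) post-process; stop at the first failure.
  local action
  for action in get-data train post-process; do
    python main.py "$experiment" "$action" || return 1
  done
}
```

For example, `run_study lorenz` would run the data, training, and post-processing stages of the adjoint instabilities experiment, stopping at the first stage that exits with a non-zero status.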
## 3. Reference

Course, K., Nair, P.B. Amortized Reparametrization: Efficient and Scalable Variational Inference for Latent SDEs.  
In Proc. Advances in Neural Information Processing Systems, (2023).

```bibtex
@inproceedings{
course2023amortized,
title={Amortized Reparametrization: Efficient and Scalable Variational Inference for Latent {SDE}s},
author={Kevin Course and Prasanth B. Nair},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=5yZiP9fZNv}
}
```

