Metadata-Version: 2.1
Name: bayesnf
Version: 0.1.1
Summary: Scalable spatiotemporal prediction with Bayesian neural fields
Keywords: 
Author-email: bayesnf authors <bayesnf@google.com>
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Intended Audience :: Science/Research
Requires-Dist: flax
Requires-Dist: jax>=0.4.6
Requires-Dist: jaxtyping
Requires-Dist: numpy
Requires-Dist: optax
Requires-Dist: pandas
Requires-Dist: tensorflow-probability[jax]>=0.19.0
Requires-Dist: pytest ; extra == "dev"
Requires-Dist: pytest-xdist ; extra == "dev"
Requires-Dist: pylint>=2.6.0 ; extra == "dev"
Requires-Dist: pyink ; extra == "dev"
Requires-Dist: tqdm ; extra == "dev"
Requires-Dist: mkdocs==1.5.3 ; extra == "docs"
Requires-Dist: mkdocstrings[python]>=0.18 ; extra == "docs"
Requires-Dist: mkdocs-material==8.2.6 ; extra == "docs"
Requires-Dist: pymdown-extensions==9.4 ; extra == "docs"
Requires-Dist: mknotebooks==0.7.1 ; extra == "docs"
Requires-Dist: pytkdocs_tweaks==0.0.8 ; extra == "docs"
Requires-Dist: jinja2==3.0.3 ; extra == "docs"
Requires-Dist: nbconvert==6.5.0 ; extra == "docs"
Requires-Dist: nbformat==5.4.0 ; extra == "docs"
Requires-Dist: pygments==2.14.0 ; extra == "docs"
Project-URL: changelog, https://github.com/google/bayesnf/blob/main/CHANGELOG.md
Project-URL: homepage, https://github.com/google/bayesnf
Project-URL: repository, https://github.com/google/bayesnf
Provides-Extra: dev
Provides-Extra: docs

# Bayesian Neural Fields for Spatiotemporal Prediction

[![Unittests](https://github.com/google/bayesnf/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google/bayesnf/actions/workflows/pytest_and_autopublish.yml)
[![PyPI version](https://badge.fury.io/py/bayesnf.svg)](https://badge.fury.io/py/bayesnf)

*This is not an officially supported Google product.*

Spatially referenced time series (i.e., spatiotemporal) datasets are
ubiquitous in scientific, engineering, and business-intelligence
applications. This repository contains an implementation of the Bayesian
Neural Field (BayesNF), a novel spatiotemporal modeling method that
integrates hierarchical probabilistic modeling for accurate uncertainty
estimation with deep neural networks for high-capacity function
approximation.

Bayesian Neural Fields infer joint probability distributions over field
values at arbitrary points in time and space, which makes the model
suitable for many data-analysis tasks including spatial interpolation,
temporal forecasting, and variography. Posterior inference is conducted
using variationally learned surrogates trained via mini-batch stochastic
gradient descent for handling large-scale data.

## Installation

`bayesnf` can be installed from the Python Package Index
([PyPI](https://pypi.org/project/bayesnf/)) using:

```
python -m pip install bayesnf
```

The typical install time is 1 minute. This software is tested on Python 3.9
with a standard Debian GNU/Linux setup. The large-scale experiments in
`scripts/` were run using TPU v3-8 accelerators. For running BayesNF
locally on medium to large-scale data, a GPU is required at minimum.

## Documentation and Tutorials

Please visit <https://google.github.io/bayesnf>.

## Quick start

```python
import jax
import pandas as pd
from bayesnf.spatiotemporal import BayesianNeuralFieldMAP

# Load a dataframe with "long" format spatiotemporal data.
df_train = pd.read_csv('chickenpox.5.train.csv',
  index_col=0, parse_dates=['datetime'])

# Build a BayesianNeuralFieldMAP estimator.
model = BayesianNeuralFieldMAP(
  width=256,
  depth=2,
  freq='W',
  seasonality_periods=['M', 'Y'],
  num_seasonal_harmonics=[2, 10],
  feature_cols=['datetime', 'latitude', 'longitude'],
  target_col='chickenpox',
  observation_model='NORMAL',
  timetype='index',
  standardize=['latitude', 'longitude'],
  interactions=[(0, 1), (0, 2), (1, 2)])

# Fit the model. The ensemble size and epoch count below are
# illustrative values; tune them for your data and hardware.
ensemble_size = 64
num_epochs = 5000
model = model.fit(
  df_train,
  seed=jax.random.PRNGKey(0),
  ensemble_size=ensemble_size,
  num_epochs=num_epochs)

# Make predictions of means and quantiles on test data.
df_test = pd.read_csv('chickenpox.5.test.csv',
  index_col=0, parse_dates=['datetime'])

yhat, yhat_quantiles = model.predict(df_test, quantiles=(0.025, 0.5, 0.975))
```
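
The quick start expects "long" format data, meaning one row per (time, location) observation rather than one column per location. The sketch below shows one way to reshape a small wide-format table into that layout, using only the Python standard library. The readings, coordinates, the `location` column, and the helper `to_long_rows` are all hypothetical and serve only as illustration; the `datetime`, `latitude`, `longitude`, and `chickenpox` column names mirror those used in the quick start.

```python
import csv
import io

# Hypothetical wide-format readings: one entry per week, one value per location.
wide = {
    "2020-01-06": {"A": 3, "B": 7},
    "2020-01-13": {"A": 5, "B": 2},
}

# Hypothetical (latitude, longitude) coordinates for each location.
coords = {"A": (47.5, 19.0), "B": (46.25, 20.15)}


def to_long_rows(wide, coords, target_col="chickenpox"):
  """Returns one dict per (datetime, location) observation."""
  rows = []
  for datetime, values in wide.items():
    for location, value in values.items():
      latitude, longitude = coords[location]
      rows.append({
          "location": location,
          "datetime": datetime,
          "latitude": latitude,
          "longitude": longitude,
          target_col: value,
      })
  return rows


long_rows = to_long_rows(wide, coords)

# Serialize to CSV text with a header row followed by one row per observation.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=list(long_rows[0]))
writer.writeheader()
writer.writerows(long_rows)
long_csv = buffer.getvalue()
```

Each resulting row carries its own timestamp and coordinates, so observations at irregular locations or with missing time steps need no placeholder entries.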

