
Commit

Initiate repo for Feature-aligned N-BEATS
leejoonhun committed Nov 27, 2023
0 parents commit 08eb2c2
Showing 16 changed files with 1,318 additions and 0 deletions.
167 changes: 167 additions & 0 deletions .gitignore
@@ -0,0 +1,167 @@
# Data
cache/

# Logs
wandb/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
84 changes: 84 additions & 0 deletions README.md
@@ -0,0 +1,84 @@
# Feature-aligned N-BEATS

Official PyTorch Implementation of [Feature-aligned N-BEATS with Sinkhorn divergence](https://arxiv.org/abs/2305.15196).

## Data

Data should be organized as `data/$SUPERDOMAIN/$DOMAIN.csv`, with three columns:

- `time` denotes the time index.
- `series` denotes the series index.
- `value` denotes the value of the time series at the given time index.
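
For example, a file `data/fred/exchange_rates.csv` (the superdomain and domain names here are hypothetical) might contain rows such as:

```csv
time,series,value
0,0,1.0823
1,0,1.0791
0,1,110.42
1,1,110.88
```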

### Source

The data used in the paper were obtained from the following sources:

- [FRED](https://fred.stlouisfed.org)
- [Commodities](https://fred.stlouisfed.org/categories/32217) category
- [National Income & Product Accounts](https://fred.stlouisfed.org/categories/18) category
- [Interest Rates](https://fred.stlouisfed.org/categories/22) category
- [Exchange Rates](https://fred.stlouisfed.org/categories/15) category
- [NCEI](https://ncei.noaa.gov) (only 2020s data are used)
- `"TEMP", "STP", "WDSP", "PRCP"` columns from [Global Surface Summary of the Day - GSOD](https://ncei.noaa.gov/metadata/geoportal/rest/metadata/item/gov.noaa.ncdc:C00516/html) dataset
- `"TAVG", "AWND", "PRCP"` columns from [Global Summary of the Month (GSOM), Version 1.0.3](https://ncei.noaa.gov/metadata/geoportal/rest/metadata/item/gov.noaa.ncdc:C00946/html) dataset
- `"TAVG", "AWND", "PRCP"` columns from [Global Summary of the Year (GSOY), Version 1](https://ncei.noaa.gov/metadata/geoportal/rest/metadata/item/gov.noaa.ncdc:C00947/html) dataset

## Usage

```shell
python main.py --source-domains $SOURCE_DOMAIN1 $SOURCE_DOMAIN2 ... \
               --target-domain $TARGET_DOMAIN \
               --forecast-horizon $FORECAST_HORIZON \
               --lookback-multiple $LOOKBACK_MULTIPLE \
               --model $MODEL \
               --loss $LOSS \
               --regularizer $REGULARIZER \
               --temperature $TEMPERATURE \
               --scaler $SCALER \
               --metric $METRIC \
               --learning-rate $LEARNING_RATE \
               --num-lr-cycles $NUM_LR_CYCLES \
               --batch-size $BATCH_SIZE \
               --num-iters $NUM_ITERS \
               --seed $SEED \
               --dtype $DTYPE \
               --data-size $DATA_SIZE
```

Detailed descriptions of the arguments are given below:

| Argument | Description | Default |
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------ |
| `source_domains` | Source domains $\{\mathcal{D}^k\}_k$ | |
| `target_domain` | Target domain $\mathcal{D}^T$ | |
| `forecast_horizon` | Forecast horizon $\alpha$ | `10` |
| `lookback_multiple` | Lookback multiple $\beta/\alpha$ | `5` |
| `model` | Model architecture $\mathfrak{F}$ | `"NHiTS"` |
| `loss` | Forecasting loss function $\mathcal{L}$ | `"SMAPE"` |
| `regularizer`       | Regularizer measure $\mathcal{L}_\mathrm{align}$ <br> NOTE: use `"None"` for the vanilla model                                                                                                       | `"Sinkhorn"` |
| `temperature` | Regularizing temperature $\lambda$ | `1.0` |
| `scaler` | Normalizing function $\sigma$ | `"softmax"` |
| `metric` | Evaluation metric for validation and test | `"SMAPE"` |
| `learning_rate` | Learning rate $\eta$ | `2e-5` |
| `num_lr_cycles` | Number of learning rate cycles<br>NOTE: `torch.optim.lr_scheduler.CyclicLR(mode="triangular2")` is used ([ref](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CyclicLR.html)) | `50` |
| `batch_size` | Batch size $B$ | `2**12` |
| `num_iters` | Number of iterations | `1000` |
| `seed` | Random seed | `0` |
| `dtype` | Data type used for `torch` and `numpy` | `"float32"` |
| `data_size`         | Fixed data size for each domain <br> NOTE: pass `"None"` to use all data                                                                                                                             | `75000`      |
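
For instance, a hypothetical invocation with two FRED source domains and otherwise default settings (the `$SUPERDOMAIN/$DOMAIN` names are illustrative and must match your `data/` layout):

```shell
python main.py --source-domains fred/commodities fred/interest_rates \
               --target-domain fred/exchange_rates \
               --forecast-horizon 10 \
               --lookback-multiple 5 \
               --model NHiTS \
               --loss SMAPE \
               --regularizer Sinkhorn
```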

## Citation

```bib
@article{lee2023fanbeats,
  title={Feature-aligned N-BEATS with Sinkhorn divergence},
  author={Lee, Joonhun and Jeon, Myeongho and Kang, Myungjoo and Park, Kyunghyun},
  journal={arXiv preprint arXiv:2305.15196},
  year={2023}
}
```

## Acknowledgement

We gratefully acknowledge [the official N-BEATS implementation](https://github.com/ServiceNow/N-BEATS); our models are built on their codebase.
72 changes: 72 additions & 0 deletions data/__init__.py
@@ -0,0 +1,72 @@
import pickle
from typing import List, Optional, Tuple

from torch import cuda
from torch.utils import data as dt

from .dataloader import InfiniteDataLoader
from .dataset import TimeSeriesDataset
from .utils import DATA_DIR

MODES = ["train", "valid", "test"]


def get_dataloaders(
    source_domains: List[str],
    target_domain: str,
    forecast_horizon: int,
    lookback_horizon: int,
    batch_size: int,
    dtype: str,
    fixed_data_size: Optional[int],
) -> Tuple[List[InfiniteDataLoader], List[dt.DataLoader], dt.DataLoader]:
    trainloaders, validloaders = [], []
    for domain in source_domains + [target_domain]:
        superdomain, domain = domain.split("/")
        cache_paths = {
            mode: DATA_DIR
            / superdomain
            / "cache"
            / f"{domain}_{lookback_horizon}x{forecast_horizon}_{mode}.pkl"
            for mode in MODES
        }
        # Reuse cached splits when all three exist; otherwise build and cache them.
        if all(cache_paths[mode].exists() for mode in MODES):
            datasets = []
            for mode in MODES:
                with open(cache_paths[mode], "rb") as f:
                    datasets.append(pickle.load(f))
        else:
            dataset = TimeSeriesDataset(
                superdomain,
                domain,
                forecast_horizon,
                lookback_horizon,
                dtype,
                fixed_data_size,
            )
            # 70% train / 10% valid / 20% test.
            datasets = dt.random_split(dataset, [0.7, 0.1, 0.2])
            for i, mode in enumerate(MODES):
                cache_paths[mode].parent.mkdir(parents=True, exist_ok=True)
                with open(cache_paths[mode], "wb") as f:
                    pickle.dump(datasets[i], f)
        trainloaders.append(
            InfiniteDataLoader(
                datasets[0], batch_size=batch_size, num_workers=cuda.device_count() * 4
            )
        )
        validloaders.append(
            dt.DataLoader(
                datasets[1],
                batch_size=batch_size,
                shuffle=False,
                num_workers=cuda.device_count() * 4,  # 4 workers per GPU
            )
        )
    # The target domain is used only for testing: drop its train/valid loaders.
    _ = trainloaders.pop(), validloaders.pop()
    # `datasets` still holds the target domain's splits (last loop iteration).
    testloader = dt.DataLoader(
        datasets[2],
        batch_size=batch_size,
        shuffle=False,
        num_workers=cuda.device_count() * 4,
    )
    return trainloaders, validloaders, testloader
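
A minimal sketch of how these loaders might be consumed (the `fred/...` domain names and the training step are hypothetical; `lookback_horizon` is `forecast_horizon * lookback_multiple` per the README):

```python
from data import get_dataloaders

trainloaders, validloaders, testloader = get_dataloaders(
    source_domains=["fred/commodities", "fred/interest_rates"],  # hypothetical names
    target_domain="fred/exchange_rates",
    forecast_horizon=10,
    lookback_horizon=50,  # forecast_horizon * lookback_multiple (10 * 5)
    batch_size=2**12,
    dtype="float32",
    fixed_data_size=75_000,
)

# Each source domain has its own infinite train loader; zip() draws one
# batch per domain per step, so the loop is bounded by num_iters alone.
for step, domain_batches in enumerate(zip(*trainloaders)):
    if step >= 1000:  # num_iters
        break
    ...  # forward/backward pass over the per-domain batches
```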
36 changes: 36 additions & 0 deletions data/dataloader.py
@@ -0,0 +1,36 @@
import torch
from torch.utils import data as dt


class InfiniteDataLoader:
    """A DataLoader wrapper whose iterator never raises StopIteration."""

    def __init__(self, dataset: dt.Subset, batch_size: int, num_workers: int):
        # Sample `batch_size` indices with replacement, batch them, and cycle forever.
        batch_sampler = InfiniteSampler(
            dt.BatchSampler(
                dt.RandomSampler(dataset, replacement=True, num_samples=batch_size),
                batch_size,
                drop_last=True,
            )
        )
        self.iterator = iter(
            dt.DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                num_workers=num_workers,
            )
        )

    def __iter__(self):
        while True:
            yield next(self.iterator)

    def __len__(self) -> float:
        # Marks the loader as unbounded; note that the builtin len() expects an
        # int, so call __len__() directly rather than len().
        return torch.inf


class InfiniteSampler:
    """Cycles through the wrapped sampler indefinitely."""

    def __init__(self, sampler: dt.Sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
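
A small, self-contained sketch of the intended behavior, using a toy dataset (all names below are illustrative):

```python
import torch
from torch.utils import data as dt

from data.dataloader import InfiniteDataLoader

toy = dt.TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(-1))
loader = InfiniteDataLoader(dt.Subset(toy, range(10)), batch_size=4, num_workers=0)

it = iter(loader)
for _ in range(5):  # keeps yielding well past a single pass over the 10 samples
    (batch,) = next(it)
    print(batch.shape)  # torch.Size([4, 1])
```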