Skip to content

feat: Adan optimizer integration #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: update adan transformation
  • Loading branch information
Benjamin-eecs committed Jul 23, 2023
commit 53d2bd02f4f7ca2190466e9abff7fde98d53f8f6
134 changes: 94 additions & 40 deletions torchopt/transform/scale_by_adan.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,40 @@

from torchopt import pytree
from torchopt.base import GradientTransformation
from torchopt.transform.utils import update_moment
from torchopt.typing import OptState, Updates
from torchopt.transform.utils import inc_count, tree_map_flat, update_moment
from torchopt.typing import OptState, Params, Updates


__all__ = [
'scale_by_adan',
]


class ScaleByAdanState(NamedTuple):
"""State for the Adan algorithm."""

count: OptState
mu: Updates
nu: Updates
delta: Updates
grad_tm1: Updates
count: OptState


def _adan_bias_correction(
moment: Updates,
decay: float,
count: OptState,
*,
already_flattened: bool = False,
) -> Updates:
"""Perform bias correction. This becomes a no-op as count goes to infinity."""

def f(t: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return t.div(1 - pow(decay, c))

if already_flattened:
return tree_map_flat(f, moment, count)
return pytree.tree_map(f, moment, count)


def scale_by_adan(
Expand All @@ -55,8 +77,10 @@ def scale_by_adan(
b1 (float, optional): Decay rate for the exponentially weighted average of gradients.
(default: :const:`0.98`)
b2 (float, optional): Decay rate for the exponentially weighted average of difference of
gradients.
b3: Decay rate for the exponentially weighted average of the squared term.
gradients.
(default: :const:`0.92`)
b3 (float, optional): Decay rate for the exponentially weighted average of the squared term.
(default: :const:`0.99`)
eps (float, optional): Term added to the denominator to improve numerical stability.
(default: :const:`1e-8`)
eps_root (float, optional): Term added to the denominator inside the square-root to improve
Expand Down Expand Up @@ -134,61 +158,95 @@ def init_fn(params: Params) -> OptState:
params,
)
nu = tree_map( # second moment
torch.zeros_like,
lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
params,
)
delta = tree_map( # EWA of Difference of gradients
lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
params,
)
grad_tm1 = tree_map(
torch.zeros_like,
lambda t: torch.zeros_like(
t,
),
params,
) # Previous gradient
return ScaleByAdanState(
count=torch.zeros([], torch.int32),
mu=mu,
nu=nu,
delta=delta,
grad_tm1=grad_tm1,
count=zero,
)

def update_fn(updates, state, params=None):
del params
def update_fn(
updates: Updates,
state: OptState,
*,
params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
) -> tuple[Updates, OptState]:
diff = pytree.lax.cond(
state.count != 0,
lambda X, Y: pytree.tree_map(lambda x, y: x - y, X, Y),
lambda X, _: pytree.tree_map(torch.zeros_like, X),
lambda X, Y: tree_map(lambda x, y: x - y, X, Y),
lambda X, _: tree_map(torch.zeros_like, X),
updates,
state.grad_tm1,
)

grad_prime = pytree.tree_map(lambda g, d: g + b2 * d, updates, diff)
grad_prime = tree_map(lambda g, d: g + b2 * d, updates, diff)

mu = update_moment(updates, state.mu, b1, 1)
delta = update_moment(diff, state.delta, b2, 1)
mu = update_moment.impl(
updates,
state.mu,
b1,
order=1,
inplace=inplace,
already_flattened=already_flattened,
)
delta = update_moment.impl(
diff,
state.delta,
b2,
1,
)
nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2)

count_inc = numerics.safe_int32_increment(state.count)
mu_hat = utils.cast_tree(bias_correction(mu, b1, count_inc), fo_dtype)
delta_hat = utils.cast_tree(bias_correction(delta, b2, count_inc), fo_dtype)
nu_hat = bias_correction(nu, b3, count_inc)
new_updates = pytree.tree_map(
lambda m, d, n: (m + b2 * d) / (torch.sqrt(n + eps_root) + eps),
mu_hat,
delta_hat,
nu_hat,
)
count_inc = inc_count.impl(updates, state.count, already_flattened=already_flattened) # type: ignore[attr-defined]
mu_hat = _adan_bias_correction(mu, b1, count_inc, already_flattened=already_flattened)
delta_hat = _adan_bias_correction(delta, b2, count_inc, already_flattened=already_flattened)
nu_hat = _adan_bias_correction(nu, b3, count_inc, already_flattened=already_flattened)

return new_updates, ScaleByAdanState(
count=count_inc,
mu=mu,
nu=nu,
delta=delta,
grad_tm1=updates,
if inplace:

def f(
m: torch.Tensor,
d: torch.Tensor,
n: torch.Tensor,
) -> torch.Tensor:
return (m + b2 * d).div_(torch.sqrt(n + eps_root).add(eps))

else:

def f(
m: torch.Tensor,
d: torch.Tensor,
n: torch.Tensor,
) -> torch.Tensor:
return (m + b2 * d).div(torch.sqrt(n + eps_root).add(eps))

# lambda m, d, n: (m + b2 * d) / (torch.sqrt(n + eps_root) + eps),
updates = pytree.tree_map(f, mu_hat, delta_hat, nu_hat)

return updates, ScaleByAdanState(
count=count_inc, mu=mu, nu=nu, delta=delta, grad_tm1=updates,
)

return base.GradientTransformation(init_fn, update_fn)
return GradientTransformation(init_fn, update_fn)


scale_by_adan.flat = _scale_by_adan_flat # type: ignore[attr-defined]
scale_by_adan.impl = _scale_by_adan # type: ignore[attr-defined]


# def scale_by_proximal_adan(
Expand Down Expand Up @@ -243,9 +301,9 @@ def update_fn(updates, state, params=None):
# nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2)

# count_inc = numerics.safe_int32_increment(state.count)
# mu_hat = bias_correction(mu, b1, count_inc)
# delta_hat = bias_correction(delta, b2, count_inc)
# nu_hat = bias_correction(nu, b3, count_inc)
# mu_hat = _adan_bias_correction(mu, b1, count_inc)
# delta_hat = _adan_bias_correction(delta, b2, count_inc)
# nu_hat = _adan_bias_correction(nu, b3, count_inc)

# if callable(learning_rate):
# lr = learning_rate(state.count)
Expand Down Expand Up @@ -276,8 +334,4 @@ def update_fn(updates, state, params=None):
# mu=mu, nu=nu, delta=delta,
# grad_tm1=updates)

# return base.GradientTransformation(init_fn, update_fn)


scale_by_adan.flat = _scale_by_adan_flat # type: ignore[attr-defined]
scale_by_adan.impl = _scale_by_adan # type: ignore[attr-defined]
# return GradientTransformation(init_fn, update_fn)