Skip to content

Commit dccf620

Browse files
mjanusz authored and copybara-github committed
Add the option to track EMA weights in FFN training.
PiperOrigin-RevId: 875053651
1 parent 8f876db commit dccf620

1 file changed

Lines changed: 15 additions & 1 deletion

File tree

ffn/jax/train.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024 Google Inc.
2-
#
32
# Licensed under the Apache License, Version 2.0 (the "License");
43
# you may not use this file except in compliance with the License.
54
# You may obtain a copy of the License at
@@ -59,6 +58,7 @@ class TrainState(flax.struct.PyTreeNode): # pytype: disable=invalid-function-de
5958
opt_state: optax.OptState
6059
params: flax.core.FrozenDict[str, Any]
6160
batch_stats: Any
61+
ema_params: Any = None
6262

6363

6464
DataIterator = TypeVar(
@@ -99,6 +99,7 @@ def create_train_state(
9999
opt_state=tx.init(params),
100100
batch_stats=variables.get('batch_stats', None),
101101
params=params,
102+
ema_params=params if config.get('ema_decay', 0.0) > 0.0 else None,
102103
),
103104
)
104105

@@ -221,11 +222,24 @@ def loss_fn(params):
221222
(state.params, state.opt_state),
222223
)
223224

225+
ema_decay = config.get('ema_decay', 0.0)
226+
new_ema_params = state.ema_params
227+
if ema_decay > 0.0:
228+
if new_ema_params is None:
229+
new_ema_params = new_params
230+
else:
231+
decay = jnp.array(ema_decay, dtype=jnp.float32)
232+
decay = jnp.where(state.step == 0, 0.0, decay)
233+
new_ema_params = optax.incremental_update(
234+
new_params, new_ema_params, step_size=1.0 - decay
235+
)
236+
224237
new_state = state.replace( # pytype: disable=attribute-error
225238
step=step,
226239
params=new_params,
227240
opt_state=new_opt_state,
228241
batch_stats=new_batch_stats,
242+
ema_params=new_ema_params,
229243
)
230244

231245
lr = schedule(state.opt_state.count) # pytype: disable=attribute-error

0 commit comments

Comments (0)