|
1 | 1 | # Copyright 2024 Google Inc. |
2 | | -# |
3 | 2 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 3 | # you may not use this file except in compliance with the License. |
5 | 4 | # You may obtain a copy of the License at |
@@ -59,6 +58,7 @@ class TrainState(flax.struct.PyTreeNode): # pytype: disable=invalid-function-de |
59 | 58 | opt_state: optax.OptState |
60 | 59 | params: flax.core.FrozenDict[str, Any] |
61 | 60 | batch_stats: Any |
| 61 | + ema_params: Any = None |
62 | 62 |
|
63 | 63 |
|
64 | 64 | DataIterator = TypeVar( |
@@ -99,6 +99,7 @@ def create_train_state( |
99 | 99 | opt_state=tx.init(params), |
100 | 100 | batch_stats=variables.get('batch_stats', None), |
101 | 101 | params=params, |
| 102 | + ema_params=params if config.get('ema_decay', 0.0) > 0.0 else None, |
102 | 103 | ), |
103 | 104 | ) |
104 | 105 |
|
@@ -221,11 +222,24 @@ def loss_fn(params): |
221 | 222 | (state.params, state.opt_state), |
222 | 223 | ) |
223 | 224 |
|
| 225 | + ema_decay = config.get('ema_decay', 0.0) |
| 226 | + new_ema_params = state.ema_params |
| 227 | + if ema_decay > 0.0: |
| 228 | + if new_ema_params is None: |
| 229 | + new_ema_params = new_params |
| 230 | + else: |
| 231 | + decay = jnp.array(ema_decay, dtype=jnp.float32) |
| 232 | + decay = jnp.where(state.step == 0, 0.0, decay) |
| 233 | + new_ema_params = optax.incremental_update( |
| 234 | + new_params, new_ema_params, step_size=1.0 - decay |
| 235 | + ) |
| 236 | + |
224 | 237 | new_state = state.replace( # pytype: disable=attribute-error |
225 | 238 | step=step, |
226 | 239 | params=new_params, |
227 | 240 | opt_state=new_opt_state, |
228 | 241 | batch_stats=new_batch_stats, |
| 242 | + ema_params=new_ema_params, |
229 | 243 | ) |
230 | 244 |
|
231 | 245 | lr = schedule(state.opt_state.count) # pytype: disable=attribute-error |
|
0 commit comments