# half_prec.py (forked from zxjzxj9/PyTorchIntroduction)
"""This code is provided only as a reference for half-precision
(mixed-precision) model training with NVIDIA Apex AMP.
"""
import torch
from apex.fp16_utils import *
from apex import amp, optimizers
# NOTE: this is an illustrative fragment, not a runnable script. `Model`,
# `args` (presumably parsed command-line options), `criterion`, and the
# per-batch `output` / `target` tensors are assumed to be defined elsewhere
# -- TODO confirm against the full training script this was excerpted from.
# NOTE(review): apex.amp is deprecated upstream; new code should prefer
# PyTorch's native mixed precision (torch.amp autocast + GradScaler).

# Build the model and move it to the GPU before calling amp.initialize
# (the Apex AMP docs require the model to already be on the CUDA device).
model = Model()
model = model.cuda()

# Create the optimizer from the already-.cuda() parameters so that it can be
# handed to amp.initialize together with the model.
optimizer = torch.optim.SGD(
    model.parameters(),
    args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

# amp.initialize configures model/optimizer for mixed precision according to
# opt_level and returns the objects that must be used from here on, which is
# why both names are rebound.
model, optimizer = amp.initialize(
    model,
    optimizer,
    opt_level=args.opt_level,
    keep_batchnorm_fp32=args.keep_batchnorm_fp32,
    loss_scale=args.loss_scale,
)

# ... inside the training loop: load a batch and run the forward pass,
# e.g. `output = model(inputs)` (elided in this reference snippet).
loss = criterion(output, target)
optimizer.zero_grad()
# Backpropagate through the *scaled* loss so that small half-precision
# gradients do not underflow; per the Apex docs, leaving the context
# unscales the gradients before optimizer.step() applies them.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()