Hi,
我有一个关于自动混合精度AMP的问题想请教一下。
假设我有一个model
Input -> conv2d_1 -> relu_1 -> conv2d_2 -> relu_2 -> conv2d_3 -> relu_3 -> conv2d_4 -> relu4 -> output
我是否可以通过一些接口来指定每个op用float16还是fp32. 比如 conv2d_1, conv2d_3 用bf16, conv2d_2, conv2d_4 用fp32.
期待大家的恢复, 谢谢!
Hi,
我有一个关于自动混合精度AMP的问题想请教一下。
假设我有一个model
Input -> conv2d_1 -> relu_1 -> conv2d_2 -> relu_2 -> conv2d_3 -> relu_3 -> conv2d_4 -> relu4 -> output
我是否可以通过一些接口来指定每个op用float16还是fp32. 比如 conv2d_1, conv2d_3 用bf16, conv2d_2, conv2d_4 用fp32.
期待大家的恢复, 谢谢!
是可以的您可以对想转换的改变下low_prec_dtype跟high_prec_dtype
具体在哪里设置可以简单说明一下吗
在接口设计上由于我们没有类似白名单一样的操作,所以只能手动去指定部分模块的精度类型
举个例子:
from megengine import amp
class Module(M.Module):
def __init__(self):
super().__init__()
self.conv = M.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn = M.BatchNorm2d(16)
def forward(self, x):
with amp.autocast(low_prec_dtype="float32"):
conv1 = self.conv(x)
out = F.relu(self.bn(conv1))
return out
@amp.autocast(enabled=True)
def train_step(image):
with gm:
logits = m(image)
.....
opt.step().clear_grad()
return logits
在您需要保留float32的地方去挂一个context-manager,将其low_dtype设置为float32
您看下这样是否可以解决您的问题那?