是否可以为每个op指定精度?

Hi,
我有一个关于自动混合精度AMP的问题想请教一下。

假设我有一个model

Input -> conv2d_1 -> relu_1 -> conv2d_2 -> relu_2 -> conv2d_3 -> relu_3 -> conv2d_4 -> relu4 -> output

我是否可以通过一些接口来指定每个op用float16还是fp32. 比如 conv2d_1, conv2d_3 用bf16, conv2d_2, conv2d_4 用fp32.

期待大家的恢复, 谢谢!

是可以的您可以对想转换的改变下low_prec_dtype跟high_prec_dtype

具体在哪里设置可以简单说明一下吗

在接口设计上由于我们没有类似白名单一样的操作,所以只能手动去指定部分模块的精度类型

举个例子:

from megengine import amp

class Module(M.Module):
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = M.BatchNorm2d(16)
    def forward(self, x):
        with amp.autocast(low_prec_dtype="float32"):
            conv1 = self.conv(x)
        out = F.relu(self.bn(conv1))
        return out

@amp.autocast(enabled=True)
def train_step(image):
    with gm:
        logits = m(image)
        .....
    opt.step().clear_grad()
    return logits

在您需要保留float32的地方去挂一个context-manager,将其low_dtype设置为float32
您看下这样是否可以解决您的问题那?

1赞