MegEngine 能否实现全局平均池化

请问一下megengine能否实现类似pytorch nn.AdaptiveAvgPool2d(1) 的方法
例如:

import megengine.module as M 
import megengine as mge
import numpy as np
pool = M.pooling.AvgPool2d((6, 8), 1, 0)
x = mge.tensor(np.random.rand(3, 3, 6, 8).astype('float32'))
ret = pool(x) # ret.shape (3, 3, 1, 1)

但是pooling 的 size和输入特征图强相关的,如果特征图更换一个尺度就不能够使用了。
是否应该通过F.mean()实现呢?

out = F.mean((F.mean(x, 2, keepdims=True)), 3, keepdims=True)

但是在判断的时候

ret[:, :, 0, 0] == out[:, :, 0, 0]
#Tensor([[1. 1. 1.]
# [0. 0. 1.]
# [0. 0. 0.]])
#

不确定是浮点数精度还是我实现的方法有问题,
请问F.mean()实现的方法是否和pool2d计算方式一致,能否以这种方式实现每个通道的平均值。

测试浮点数是否一样,不会用 ==
试试这个
import numpy as np
np.testing.assert_allclose(ret.numpy(), out.numpy(), 6)

1赞

我试了一下,ret - out的输出均为0

>>> ret[:, :, 0,0] - out[:, :, 0, 0]
Tensor([[ 0. -0.  0.]
 [-0.  0.  0.]
 [-0. -0.  0.]])

global pooling可以用mean来实现,目前还不支持adaptive pooling,你的写法没有问题,也可以用下面的写法

>>> y = x.mean(axis=3).mean(axis=2)
1赞

@wangjingyi @felixfan 谢谢,我发现了新的问题,每次前传显存都会累积。
代码如下,首先通过全局池化写Channel Attention Layer。

N, C, H, W -> N, C, 1, 1 -> N, C//4, 1, 1 -> N, C, 1, 1 * N, C, H, W 

class CALayer(M.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()

        self.conv_du = M.Sequential(
            M.Conv2d(channel, channel // reduction, kernel_size=1, padding=0, bias=True),
            M.ReLU(),
            M.Conv2d(channel // reduction, channel, kernel_size=1, padding=0, bias=True),
            M.Sigmoid()
        )

    def forward(self, inputs):
        # B C H W -> B C 1 1
#         y = F.mean(F.mean(inputs, 2, keepdims=True), 3, keepdims=True)
        y = x.mean(axis=3, keepdims=True).mean(axis=2, keepdims=True)
        # B, C, H, W = inputs.shape
        # y = F.avg_pool2d(inputs, (H, W))
        y = self.conv_du(y)
        return inputs * y

构建模块

import megengine as mge
import numpy as np
ca = CALayer(64)

推理:

for i in range(100):
    x = mge.tensor(np.random.randn(8, 64, 64, 64).astype(np.float32))
    out = ca(x)
print(out.shape)  # 输出: (8, 64, 64, 64)

循环过过程显存一再涨。
MegEngine写attention的正确姿势是什么…