我的模型有两个输入,会得到两个 feature map,中间需要对这两个 feature map 做互相关操作,用 F.conv2d 实现。trace 之后转为 ONNX 时,其中一个分支会丢失;而如果只提取特征(即不使用 F.conv2d),则不会出现这个问题。
测试代码如下:
import megengine as mge
import megengine.module as M
import numpy as np
import megengine.functional as F
def xcorr_depthwise(x, kernel):
    """Cross-correlate ``x`` with ``kernel`` plane by plane using a grouped conv.

    ``x`` is folded into a single sample with ``B*C`` channels and ``kernel``
    into ``B*C`` single-channel filters, so one ``F.conv2d`` call with
    ``groups=B*C`` correlates each (batch, channel) plane of ``x`` with the
    matching plane of ``kernel``.

    Args:
        x: input feature map, laid out as [B, C, H, W].
        kernel: feature map used as the conv weights, laid out as [B, C, h, w].

    Returns:
        The conv2d output reshaped back to [B, C, out_h, out_w].
    """
    # Shapes are converted to plain Python ints before being used in reshape.
    n_batch, n_channel = int(kernel.shape[0]), int(kernel.shape[1])
    k_h, k_w = int(kernel.shape[2]), int(kernel.shape[3])
    x_h, x_w = int(x.shape[2]), int(x.shape[3])
    n_groups = n_batch * n_channel

    folded_x = x.reshape((1, n_groups, x_h, x_w))
    # 5-D weight: (groups, out_per_group, in_per_group, h, w).
    filters = kernel.reshape(n_groups, 1, 1, k_h, k_w)

    response = F.conv2d(folded_x, filters, groups=n_groups)
    return response.reshape(
        n_batch, n_channel, int(response.shape[2]), int(response.shape[3])
    )
class AlexNet_stride8(M.Module):
    """AlexNet-style backbone with an overall spatial stride of 8.

    Five conv+BN blocks interleaved with two max-pools. ``conv1`` and both
    pools use stride 2; the remaining convs use stride 1. The last block is a
    ``ConvBn2d`` rather than a ``ConvBnRelu2d``.

    Args:
        in_cha: number of input channels.
        ch: number of output channels; must be even because ``conv1``
            produces ``ch // 2`` channels.
    """

    def __init__(self, in_cha, ch=48):
        super().__init__()
        assert ch % 2 ==0, "channel nums should % 2 = 0"
        narrow = ch // 2
        wide = ch * 2
        self.conv1 = M.conv_bn.ConvBnRelu2d(
            in_cha, narrow, kernel_size=11, stride=2, padding=5
        )
        self.pool1 = M.MaxPool2d(3, 2, 1)
        self.conv2 = M.conv_bn.ConvBnRelu2d(narrow, ch, 5, 1, 2)
        self.pool2 = M.MaxPool2d(3, 2, 1)
        self.conv3 = M.conv_bn.ConvBnRelu2d(ch, wide, 3, 1, 1)
        self.conv4 = M.conv_bn.ConvBnRelu2d(wide, wide, 3, 1, 1)
        self.conv5 = M.conv_bn.ConvBn2d(wide, ch, 3, 1, 1)

    def forward(self, x):
        """Run ``x`` through every stage in order and return the feature map.

        Example spatial sizes for an 800x512 input:
        conv1 -> 400x256, pool1 -> 200x128, pool2 -> 100x64 (convs 2-5 keep size).
        """
        stages = (
            self.conv1,
            self.pool1,
            self.conv2,
            self.pool2,
            self.conv3,
            self.conv4,
            self.conv5,
        )
        for stage in stages:
            x = stage(x)
        return x
class AlexNet_stride4(M.Module):
    """AlexNet-style backbone with an overall spatial stride of 4.

    Five conv+BN blocks and a single max-pool. Only ``conv1`` and ``pool1``
    use stride 2. The last block is a ``ConvBn2d`` rather than a
    ``ConvBnRelu2d``.

    The per-layer numbers in ``__init__`` are cost estimates of the form
    ``2 * C_in * k*k * H * W * C_out`` (presumably for an 800x800 input with
    ``in_cha=1`` and ``ch=48`` — TODO confirm).

    Args:
        in_cha: number of input channels.
        ch: number of output channels; must be even because ``conv1``
            produces ``ch // 2`` channels.
    """

    def __init__(self, in_cha, ch=48):
        super().__init__()
        assert ch % 2 ==0, "channel nums should % 2 = 0"
        narrow = ch // 2
        # 2*1*121*400*400*24 = 929280000
        self.conv1 = M.conv_bn.ConvBnRelu2d(
            in_cha, narrow, kernel_size=11, stride=2, padding=5
        )
        # 2*24*25*400*400*48 = 9216000000
        self.conv2 = M.conv_bn.ConvBnRelu2d(narrow, ch, 5, 1, 2)
        self.pool1 = M.MaxPool2d(3, 2, 1)
        # conv3, conv4 and conv5: 2*48*9*200*200*48 = 1658880000 each
        self.conv3 = M.conv_bn.ConvBnRelu2d(ch, ch, 3, 1, 1)
        self.conv4 = M.conv_bn.ConvBnRelu2d(ch, ch, 3, 1, 1)
        self.conv5 = M.conv_bn.ConvBn2d(ch, ch, 3, 1, 1)

    def forward(self, x):
        """Run ``x`` through every stage in order and return the feature map.

        Example spatial sizes for an 800x512 input:
        conv1/conv2 -> 400x256, pool1 and later -> 200x128.
        """
        stages = (
            self.conv1,
            self.conv2,
            self.pool1,
            self.conv3,
            self.conv4,
            self.conv5,
        )
        for stage in stages:
            x = stage(x)
        return x
class Net(M.Module):
    """Two-input network: one backbone per input, then depthwise cross-correlation.

    Args:
        backbone_1: module applied to ``input1``; its output is used as the
            correlation kernel.
        backbone_2: module applied to ``input2``; its output is the map being
            correlated.
    """

    def __init__(self, backbone_1, backbone_2):
        super().__init__()
        self.backbone_1 = backbone_1
        self.backbone_2 = backbone_2

    def forward(self, input1, input2):
        """Return two cross-correlation maps computed from the same feature pair.

        Both outputs come from separate but identical ``xcorr_depthwise``
        calls, so the network has two output branches.
        """
        kernel_feat = self.backbone_1(input1)
        map_feat = self.backbone_2(input2)
        # With the 512x512 / 800x800 inputs of the script below and the
        # stride-8 backbone, each map is 37x37 (100 - 64 + 1).
        first = xcorr_depthwise(map_feat, kernel_feat)
        second = xcorr_depthwise(map_feat, kernel_feat)
        return first, second
if __name__ == "__main__":
    from megengine.jit import trace

    # Two independent stride-8 backbones, one per network input.
    net = Net(AlexNet_stride8(in_cha=1), AlexNet_stride8(in_cha=1))
    net.eval()

    @trace(symbolic=True, capture_as_const=True)
    def inference_func(a, b, *, model):
        """Traced forward pass; returns the model's two outputs as a tuple."""
        return model(a, b)

    # Random single-channel inputs: a 512x512 one and an 800x800 one.
    small_input = mge.tensor(
        np.random.random([1, 1, 512, 512]).astype(np.float32)
    )
    large_input = mge.tensor(
        np.random.random([1, 1, 800, 800]).astype(np.float32)
    )

    # The traced function is called once before dumping the graph.
    inference_func(small_input, large_input, model=net)
    inference_func.dump(
        "test_corr.mge",
        arg_names=["input1", "input2"],
        output_names=["corr1", "corr2"],
    )
使用 mgeconvert 转换,命令如下:
convert onnx -i test_corr.mge -o test_corr.onnx
正常转换结果如下:
出错结果如下: