版本信息:
MegEngine版本:1.13.1
python版本:3.8
运行如下代码可以正常运行:
import megengine as mge
import numpy as np
import megengine.functional as F
embed_dim, num_heads = 768, 12
seq_length, batch_size = 1024, 2
query = F.ones((batch_size, seq_length, embed_dim)) * 1.0
multihead_attn = mge.module.MultiHeadAttention(embed_dim, num_heads, bias=True)
attn_output = multihead_attn(query, query, query)[0]
print(attn_output)
但是如果把bias设置成False,就会有如下报错:
请问是怎么回事?