sliding window self-attention cell #1395
Conversation
Waiting for apache/mxnet#19387 to be merged.
The documentation website for preview: http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR1395/sw_atten_cell/index.html
Is it possible for us to revise the interface to be similar to https://www.deepspeed.ai/tutorials/sparse-attention/?
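For context, DeepSpeed wraps the sparsity pattern in a configuration object rather than passing raw arguments. A rough sketch of what that could look like here (purely illustrative; the class and argument names below are hypothetical and not part of this PR, which currently exposes MultiHeadSlidingWindowAttentionCell(w, symmetric)):

# Hypothetical config-object interface inspired by DeepSpeed's sparse attention.
# Names are illustrative only.
class SlidingWindowConfig:
    def __init__(self, num_heads, window_size, symmetric=True, dilation=1):
        self.num_heads = num_heads
        self.window_size = window_size   # one-sided window size per head
        self.symmetric = symmetric       # False would correspond to the causal mode
        self.dilation = dilation         # per-head dilation factor

# The cell would then be constructed from the config instead of positional arguments:
# attn_cell = MultiHeadSlidingWindowAttentionCell(sparsity_config=SlidingWindowConfig(12, 128))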
Benchmark script:
import numpy as np
from numpy.testing import assert_allclose
import mxnet as mx
from gluonnlp.attention_cell import masked_softmax, MultiHeadAttentionCell, MultiHeadSlidingWindowAttentionCell
import time


def test_multi_head_sliding_window_dot_attention_cell():
    def gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d):
        """Generate the sliding-window attention mask over the full (seq_length x seq_length) attention matrix."""
        mask_np = np.zeros((batch_size, seq_length, seq_length))
        for i in range(seq_length):
            end = (i + 1 + w * d) if symmetric else (i + 1)
            for j in range(i - w * d, end, d):
                if j >= 0 and j < seq_length:
                    mask_np[:, i, j] = 1
        return mask_np

    def test_selfatten(batch_size, seq_length, num_heads, num_head_units, w, symmetric, d):
        """Time full attention with a dense sliding-window mask."""
        attn_cell = MultiHeadAttentionCell()
        # Generate the data
        ctx = mx.gpu(0)
        # ctx = mx.cpu()
        query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        mask = gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d)
        mask = mx.np.array(mask, ctx=ctx, dtype=np.float32)
        query = mx.np.array(query, ctx=ctx, dtype=np.float32)
        key = mx.np.array(key, ctx=ctx, dtype=np.float32)
        value = mx.np.array(value, ctx=ctx, dtype=np.float32)
        query.attach_grad()
        key.attach_grad()
        value.attach_grad()
        mx.npx.waitall()
        tic = time.time()
        with mx.autograd.record():
            out, _ = attn_cell(query, key, value, mask)
        out.backward()
        mx.npx.waitall()
        toc = time.time()
        return toc - tic

    def test_swatten(batch_size, seq_length, num_heads, num_head_units, w, symmetric, d):
        """Time the dedicated sliding-window attention cell."""
        sw_attn_cell = MultiHeadSlidingWindowAttentionCell(w, symmetric)
        # Generate the data
        ctx = mx.gpu(0)
        # ctx = mx.cpu()
        query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        query = mx.np.array(query, ctx=ctx, dtype=np.float32)
        key = mx.np.array(key, ctx=ctx, dtype=np.float32)
        value = mx.np.array(value, ctx=ctx, dtype=np.float32)
        query.attach_grad()
        key.attach_grad()
        value.attach_grad()
        # Per-head dilation and per-sample valid length
        dilation = mx.np.array(np.full((num_heads,), d), ctx=ctx, dtype=np.int32)
        valid_length = mx.np.array(np.full((batch_size,), seq_length), ctx=ctx, dtype=np.int32)
        mx.npx.waitall()
        tic = time.time()
        with mx.autograd.record():
            sw_out, _ = sw_attn_cell(query, key, value, dilation, valid_length)
        sw_out.backward()
        mx.npx.waitall()
        toc = time.time()
        return toc - tic

    num_repeat = 5
    for seq_length in [512, 1024, 2048, 4096]:
        dur = 0.
        w = seq_length // 8
        for i in range(num_repeat):
            tmp_dur = test_selfatten(1, seq_length, 12, 64, w, True, 1)
            if i > 1:  # skip the first two runs as warm-up
                dur += tmp_dur
        dur /= 3.
        print('seq_length={}, w={}, time={:.3f}'.format(seq_length, w, dur))
        dur = 0.
        for i in range(num_repeat):
            tmp_dur = test_swatten(1, seq_length, 12, 64, w, True, 1)
            if i > 1:  # skip the first two runs as warm-up
                dur += tmp_dur
        dur /= 3.
        print('sliding-window-attention seq_length={}, w={}, time={:.3f}'.format(seq_length, w, dur))


test_multi_head_sliding_window_dot_attention_cell()
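Note that assert_allclose is imported above but unused in the timing code. A minimal consistency check along the following lines could accompany the benchmark; this is a sketch only: it mirrors the call signatures used above, assumes the sliding-window kernels run on GPU, and the tolerances are arbitrary.

def check_sliding_window_consistency(batch_size=2, seq_length=128, num_heads=4,
                                     num_head_units=16, w=16, symmetric=True, d=1):
    """Cross-check the sliding-window cell against masked full attention on identical inputs."""
    ctx = mx.gpu(0)  # the sliding-window kernels may be GPU-only
    # Build the equivalent dense sliding-window mask (same logic as gen_sliding_window_mask_full).
    mask_np = np.zeros((batch_size, seq_length, seq_length))
    for i in range(seq_length):
        end = (i + 1 + w * d) if symmetric else (i + 1)
        for j in range(i - w * d, end, d):
            if 0 <= j < seq_length:
                mask_np[:, i, j] = 1
    shape = (batch_size, seq_length, num_heads, num_head_units)
    query = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
    key = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
    value = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
    mask = mx.np.array(mask_np, ctx=ctx, dtype=np.float32)
    dilation = mx.np.array(np.full((num_heads,), d), ctx=ctx, dtype=np.int32)
    valid_length = mx.np.array(np.full((batch_size,), seq_length), ctx=ctx, dtype=np.int32)
    out, _ = MultiHeadAttentionCell()(query, key, value, mask)
    sw_out, _ = MultiHeadSlidingWindowAttentionCell(w, symmetric)(query, key, value,
                                                                  dilation, valid_length)
    # The two paths should agree up to numerical precision.
    assert_allclose(sw_out.asnumpy(), out.asnumpy(), rtol=1e-4, atol=1e-4)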
Is there any update on this PR?
@sxjscience it seems the error AttributeError: module 'mxnet.ndarray.numpy_extension' has no attribute 'sldwin_atten_score' is caused by the MXNet version not being the latest.
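A quick way to check whether the installed MXNet build ships the new operator (the module path is taken from the error message above; the kernels come from apache/mxnet#19387):

import mxnet as mx

print(mx.__version__)
# True only if the build includes the sliding-window attention kernels from apache/mxnet#19387.
print(hasattr(mx.ndarray.numpy_extension, 'sldwin_atten_score'))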
Yes, we can merge master so that the tests are retriggered.
Do we have any update on this? @ZiyueHuang would you have time to rebase the code?
Description
This PR adds an AttentionCell for sliding-window self-attention, including support for multi-head dilation and the causal attention mode, as described in Longformer: The Long-Document Transformer.
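A minimal usage sketch (argument order follows the benchmark script in this thread; consult the preview documentation for the exact signature):

import numpy as np
import mxnet as mx
from gluonnlp.attention_cell import MultiHeadSlidingWindowAttentionCell

batch_size, seq_length, num_heads, num_head_units = 1, 1024, 12, 64
w, symmetric = 128, True  # one-sided window size; symmetric=False restricts attention to previous tokens (causal mode)

ctx = mx.gpu(0)
cell = MultiHeadSlidingWindowAttentionCell(w, symmetric)
shape = (batch_size, seq_length, num_heads, num_head_units)
query = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
key = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
value = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
dilation = mx.np.array(np.ones((num_heads,)), ctx=ctx, dtype=np.int32)          # per-head dilation
valid_length = mx.np.array(np.full((batch_size,), seq_length), ctx=ctx, dtype=np.int32)
out, _ = cell(query, key, value, dilation, valid_length)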
cc @sxjscience @szhengac
Checklist
Essentials
Changes
Comments
cc @dmlc/gluon-nlp-team