This repository was archived by the owner on Jan 15, 2024. It is now read-only.

sliding window self-attention cell #1395

Open
ZiyueHuang wants to merge 2 commits into dmlc:master from ZiyueHuang:sw_atten_cell

Conversation

Member

@ZiyueHuang commented Oct 20, 2020

Description

The AttentionCell for the sliding window self-attention, including the support for multi-headed dilation and the causal attention mode, described in Longformer: The Long-Document Transformer.

cc @sxjscience @szhengac
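
For readers unfamiliar with the pattern, the dense reference below (pure numpy, illustrative only and not part of the proposed API) shows which key positions a query position may attend to for a window of size w, with symmetric switching between bidirectional and causal attention and d giving the dilation; it mirrors the gen_sliding_window_mask_full helper in the benchmark script later in this thread.

import numpy as np

def sliding_window_mask(seq_length, w, symmetric=True, d=1):
    # mask[i, j] == 1 iff query position i may attend to key position j.
    mask = np.zeros((seq_length, seq_length), dtype=np.int64)
    for i in range(seq_length):
        # Causal mode only looks w*d steps back; symmetric mode also looks forward.
        end = (i + 1 + w * d) if symmetric else (i + 1)
        for j in range(i - w * d, end, d):
            if 0 <= j < seq_length:
                mask[i, j] = 1
    return mask

print(sliding_window_mask(6, w=1, symmetric=False))      # causal window
print(sliding_window_mask(6, w=1, symmetric=True, d=2))  # dilated symmetric window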

Checklist

Essentials

  • PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

cc @dmlc/gluon-nlp-team

@ZiyueHuang requested a review from a team as a code owner on October 20, 2020 12:20
Member Author

Waiting for apache/mxnet#19387 to be merged.

Member

Is it possible for us to revise the interface to be similar to https://www.deepspeed.ai/tutorials/sparse-attention/?
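
For concreteness, the linked DeepSpeed tutorial wraps the sparsity pattern in a small config object that is handed to a generic attention module; a GluonNLP analogue might look roughly like the sketch below. The class and argument names here are hypothetical (they exist neither in this PR nor in GluonNLP) and only illustrate the shape of such an interface.

# Hypothetical sketch only: SlidingWindowSparsityConfig and SparseSelfAttentionCell
# do not exist in GluonNLP or in this PR; they illustrate a DeepSpeed-style
# "sparsity config object + generic cell" interface.
class SlidingWindowSparsityConfig:
    def __init__(self, num_heads, window, symmetric=True, dilation=1):
        self.num_heads = num_heads
        self.window = window
        self.symmetric = symmetric
        self.dilation = dilation

# Usage would then look like:
#   cfg = SlidingWindowSparsityConfig(num_heads=12, window=128)
#   attn_cell = SparseSelfAttentionCell(sparsity_config=cfg)
#   out, _ = attn_cell(query, key, value, valid_length)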

Member Author

Benchmark script:


import time

import numpy as np
from numpy.testing import assert_allclose
import mxnet as mx
from gluonnlp.attention_cell import masked_softmax, MultiHeadAttentionCell, MultiHeadSlidingWindowAttentionCell


def test_multi_head_sliding_window_dot_attention_cell():
    def gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d):
        """Generate the sliding-window attention mask for the full attention matrix (seq_len^2)."""
        mask_np = np.zeros((batch_size, seq_length, seq_length))
        for i in range(seq_length):
            end = (i + 1 + w * d) if symmetric else (i + 1)
            for j in range(i - w * d, end, d):
                if 0 <= j < seq_length:
                    mask_np[:, i, j] = 1
        return mask_np

    def test_selfatten(batch_size, seq_length, num_heads, num_head_units, w, symmetric, d):
        """Time forward + backward of the dense cell fed the equivalent full mask."""
        attn_cell = MultiHeadAttentionCell()
        # Generate the data
        ctx = mx.gpu(0)
        # ctx = mx.cpu()
        query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        mask = gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d)
        mask = mx.np.array(mask, ctx=ctx, dtype=np.float32)
        query = mx.np.array(query, ctx=ctx, dtype=np.float32)
        key = mx.np.array(key, ctx=ctx, dtype=np.float32)
        value = mx.np.array(value, ctx=ctx, dtype=np.float32)
        query.attach_grad()
        key.attach_grad()
        value.attach_grad()
        mx.npx.waitall()
        tic = time.time()
        with mx.autograd.record():
            out, _ = attn_cell(query, key, value, mask)
            out.backward()
        mx.npx.waitall()
        toc = time.time()
        return toc - tic

    def test_swatten(batch_size, seq_length, num_heads, num_head_units, w, symmetric, d):
        """Time forward + backward of the sliding-window attention cell."""
        sw_attn_cell = MultiHeadSlidingWindowAttentionCell(w, symmetric)
        # Generate the data
        ctx = mx.gpu(0)
        # ctx = mx.cpu()
        query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        query = mx.np.array(query, ctx=ctx, dtype=np.float32)
        key = mx.np.array(key, ctx=ctx, dtype=np.float32)
        value = mx.np.array(value, ctx=ctx, dtype=np.float32)
        query.attach_grad()
        key.attach_grad()
        value.attach_grad()
        # Per-head dilation and per-sample valid length, as expected by the cell.
        dilation = mx.np.zeros((num_heads,))
        dilation[:] = d
        dilation = mx.np.array(dilation, ctx=ctx, dtype=np.int32)
        valid_length = np.zeros((batch_size,))
        valid_length[:] = seq_length
        valid_length = mx.np.array(valid_length, ctx=ctx, dtype=np.int32)
        mx.npx.waitall()
        tic = time.time()
        with mx.autograd.record():
            sw_out, _ = sw_attn_cell(query, key, value, dilation, valid_length)
            sw_out.backward()
        mx.npx.waitall()
        toc = time.time()
        return toc - tic

    # Skip the first two iterations as warm-up and average the remaining three.
    num_repeat = 5
    for seq_length in [512, 1024, 2048, 4096]:
        dur = 0.
        w = seq_length // 8
        for i in range(num_repeat):
            tmp_dur = test_selfatten(1, seq_length, 12, 64, w, True, 1)
            if i > 1:
                dur += tmp_dur
        dur /= 3.
        print('seq_length={}, w={}, time={:.3f}'.format(seq_length, w, dur))
        dur = 0.
        for i in range(num_repeat):
            tmp_dur = test_swatten(1, seq_length, 12, 64, w, True, 1)
            if i > 1:
                dur += tmp_dur
        dur /= 3.
        print('sliding-window-attention seq_length={}, w={}, time={:.3f}'.format(seq_length, w, dur))


test_multi_head_sliding_window_dot_attention_cell()
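
As a sanity check separate from the timing above, the two cells can also be compared numerically on identical inputs (which is what assert_allclose is imported for). The sketch below is only illustrative: it assumes gen_sliding_window_mask_full from the script is available at module scope and reuses the same call signatures as the benchmark.

def check_equivalence(batch_size=2, seq_length=128, num_heads=4,
                      num_head_units=16, w=8, symmetric=True, d=1):
    # The sliding-window cell should agree with the dense cell fed the
    # equivalent full mask when the same query/key/value are used.
    ctx = mx.cpu()
    shape = (batch_size, seq_length, num_heads, num_head_units)
    query = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
    key = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
    value = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
    mask = mx.np.array(
        gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d),
        ctx=ctx, dtype=np.float32)
    dilation = mx.np.array(np.full((num_heads,), d), ctx=ctx, dtype=np.int32)
    valid_length = mx.np.array(np.full((batch_size,), seq_length), ctx=ctx, dtype=np.int32)
    out, _ = MultiHeadAttentionCell()(query, key, value, mask)
    sw_out, _ = MultiHeadSlidingWindowAttentionCell(w, symmetric)(
        query, key, value, dilation, valid_length)
    assert_allclose(out.asnumpy(), sw_out.asnumpy(), rtol=1e-4, atol=1e-4)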

Member

Is there any update on this PR?

Member

szhengac commented Dec 2, 2020

@sxjscience it seems the error AttributeError: module 'mxnet.ndarray.numpy_extension' has no attribute 'sldwin_atten_score' occurs because the installed MXNet version is not the latest.
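
One quick way to confirm whether the installed nightly build already ships the new kernels from apache/mxnet#19387 (the module path below is taken from the error message itself):

import importlib
import mxnet as mx

print(mx.__version__)
npx_nd = importlib.import_module('mxnet.ndarray.numpy_extension')
print(hasattr(npx_nd, 'sldwin_atten_score'))  # False on builds without the new operator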

Member

Yes, we can merge master so that the test will be retriggered.

Member

Do we have any update on this? @ZiyueHuang, would you have time to rebase the code?


Reviewers

@szhengac left review comments
@sxjscience left review comments

At least 1 approving review is required to merge this pull request.

