Merged
paddlenlp/transformers/llama/fusion_ops.py: 26 changes (18 additions, 8 deletions)
@@ -214,16 +214,26 @@
     )
 else:
     if attn_mask_startend_row_indices is not None:
-        assert alibi is None, "flash_attention_with_sparse_mask not support alibi"
+        assert alibi is None, "flashmask_attention or flash_attention_with_sparse_mask not support alibi"
         if len(attn_mask_startend_row_indices.shape) == 2:
             attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)
-        attn_output = F.flash_attention_with_sparse_mask(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask_start_row_indices=attn_mask_startend_row_indices,
-            is_causal=True,
-        )
+
+        if hasattr(F, "flashmask_attention"):
+            attn_output = F.flashmask_attention(
+                query_states,
+                key_states,
+                value_states,
+                startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
+                causal=True,
+            )
+        else:
+            attn_output = F.flash_attention_with_sparse_mask(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask_start_row_indices=attn_mask_startend_row_indices,
+                is_causal=True,
+            )
     else:
         attn_output = F.scaled_dot_product_attention(
             query_states,

Codecov (codecov/patch): added lines 217, 221-222, and 230 of paddlenlp/transformers/llama/fusion_ops.py were not covered by tests.
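The new branch prefers `F.flashmask_attention` when the installed Paddle build exposes it and otherwise falls back to `F.flash_attention_with_sparse_mask`. Before either call, a 2-D index tensor is unsqueezed at axis 1, and `flashmask_attention` additionally receives the indices with a trailing axis from `.unsqueeze(-1)`. The sketch below reproduces only that control flow and the resulting index shapes so it runs without Paddle. It is an illustration under stated assumptions, not the PaddleNLP code. `masked_attention`, `normalize_indices_shape`, `flashmask_indices_shape`, and the `fake_F_new`/`fake_F_old` stand-ins for `paddle.nn.functional` are hypothetical helpers, shapes are plain tuples rather than tensors, and reading `(2, 1024)` as `[batch, seq_len]` is an assumption.

```python
# Sketch only: mirrors the control flow of the new branch without importing Paddle.
# Hypothetical names: normalize_indices_shape, flashmask_indices_shape,
# masked_attention, fake_F_new, fake_F_old. Shapes are tuples, not tensors.
from types import SimpleNamespace


def normalize_indices_shape(shape):
    """Like paddle.unsqueeze(x, axis=1) on a 2-D index tensor: insert a size-1 axis at position 1."""
    shape = tuple(shape)
    if len(shape) == 2:
        shape = shape[:1] + (1,) + shape[1:]
    return shape


def flashmask_indices_shape(shape):
    """Like x.unsqueeze(-1): append a trailing size-1 axis."""
    return tuple(shape) + (1,)


def masked_attention(F, q, k, v, indices_shape):
    """Prefer F.flashmask_attention if present, else fall back to
    F.flash_attention_with_sparse_mask (feature detection via hasattr)."""
    indices_shape = normalize_indices_shape(indices_shape)
    if hasattr(F, "flashmask_attention"):
        return F.flashmask_attention(
            q,
            k,
            v,
            startend_row_indices=flashmask_indices_shape(indices_shape),
            causal=True,
        )
    return F.flash_attention_with_sparse_mask(
        q,
        k,
        v,
        attn_mask_start_row_indices=indices_shape,
        is_causal=True,
    )


# Hypothetical stand-ins for paddle.nn.functional: each fake kernel returns
# its own name and the index shape it received.
def _fake_flashmask(q, k, v, startend_row_indices, causal):
    return ("flashmask", startend_row_indices)


def _fake_sparse_mask(q, k, v, attn_mask_start_row_indices, is_causal):
    return ("sparse_mask", attn_mask_start_row_indices)


fake_F_new = SimpleNamespace(
    flashmask_attention=_fake_flashmask,
    flash_attention_with_sparse_mask=_fake_sparse_mask,
)
fake_F_old = SimpleNamespace(flash_attention_with_sparse_mask=_fake_sparse_mask)

print(masked_attention(fake_F_new, None, None, None, (2, 1024)))  # ('flashmask', (2, 1, 1024, 1))
print(masked_attention(fake_F_old, None, None, None, (2, 1024)))  # ('sparse_mask', (2, 1, 1024))
```

Checking with `hasattr` rather than comparing version strings lets one code path run on Paddle builds with and without the newer kernel. The cost is that the two branches must be tested separately.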