Add SwiGLU for auto Llama
From00 committed Mar 1, 2024
commit 1b113be6308dc166c643cdb48bda03ceb8d2d3fe
16 changes: 13 additions & 3 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -32,6 +32,16 @@
 except ImportError:
     fused_rotary_position_embedding = None
 
+try:
+    from paddle.incubate.nn.functional import swiglu
+except ImportError:
+
+    def swiglu(x, y=None):
+        if y is None:
+            x, y = paddle.chunk(x, chunks=2, axis=-1)
+        return F.silu(x) * y
+
+
 from paddlenlp.transformers.conversion_utils import (
     StateDictNameMapping,
     init_name_mappings,
@@ -228,10 +238,10 @@ def __init__(self, config, ipp: Optional[int] = None):
 
     def forward(self, x):
         if self.fuse_attention_ffn:
-            gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
-            out = self.down_proj(F.silu(gate_out) * up_out)
+            x = swiglu(self.gate_up_fused_proj(x))
         else:
-            out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+            x = swiglu(self.gate_proj(x), self.up_proj(x))
+        out = self.down_proj(x)
         return out


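The fallback keeps the two call shapes interchangeable: `swiglu(fused)` splits a concatenated gate/up tensor in half along the last axis, while `swiglu(gate, up)` takes the two projections separately; both reduce to `F.silu(gate) * up`, which is exactly what the pre-PR code computed. A minimal sketch of that equivalence, assuming only that `paddle` is installed (tensor shapes here are illustrative, not from the model config):

```python
import paddle
import paddle.nn.functional as F

try:
    # Fused kernel, available in newer Paddle builds.
    from paddle.incubate.nn.functional import swiglu
except ImportError:
    # Pure-Python fallback from the diff: split the fused tensor in half.
    def swiglu(x, y=None):
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y

gate = paddle.randn([2, 8])
up = paddle.randn([2, 8])
fused = paddle.concat([gate, up], axis=-1)

reference = F.silu(gate) * up                        # the pre-PR computation
print(paddle.allclose(swiglu(gate, up), reference))  # two-argument form
print(paddle.allclose(swiglu(fused), reference))     # single fused-tensor form
```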
16 changes: 13 additions & 3 deletions paddlenlp/transformers/llama/modeling_auto_static.py
@@ -31,6 +31,16 @@
 except ImportError:
     fused_rotary_position_embedding = None
 
+try:
+    from paddle.incubate.nn.functional import swiglu
+except ImportError:
+
+    def swiglu(x, y=None):
+        if y is None:
+            x, y = paddle.chunk(x, chunks=2, axis=-1)
+        return F.silu(x) * y
+
+
 from paddlenlp.transformers.conversion_utils import (
     StateDictNameMapping,
     init_name_mappings,
@@ -242,10 +252,10 @@ def forward(self, x):
         fleet.auto.shard_tensor(self.down_proj.weight, *get_dist_attr(["mp", None], self.ipp))
 
         if self.fuse_attention_ffn:
-            gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
-            out = self.down_proj(F.silu(gate_out) * up_out)
+            x = swiglu(self.gate_up_fused_proj(x))
         else:
-            out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+            x = swiglu(self.gate_proj(x), self.up_proj(x))
+        out = self.down_proj(x)
         return out


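In both files the fused branch relies on `gate_up_fused_proj` storing the gate and up weights concatenated along the output dimension, so that `swiglu`'s internal `chunk` recovers the same halves the old code split out explicitly. A rough sketch of that layout assumption follows; the layer sizes and the `[gate, up]` ordering are illustrative, inferred from the chunk order in the diff rather than taken from the PaddleNLP config:

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

hidden_size, intermediate_size = 8, 16  # illustrative sizes

# Separate projections: the fuse_attention_ffn=False path.
gate_proj = nn.Linear(hidden_size, intermediate_size, bias_attr=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias_attr=False)

# Fused projection: weights concatenated as [gate, up] on the output axis,
# so chunking its output along axis=-1 yields (gate_out, up_out).
gate_up_fused_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias_attr=False)
gate_up_fused_proj.weight.set_value(
    paddle.concat([gate_proj.weight, up_proj.weight], axis=-1)
)

x = paddle.randn([2, hidden_size])
gate_out, up_out = paddle.chunk(gate_up_fused_proj(x), chunks=2, axis=-1)
print(paddle.allclose(F.silu(gate_out) * up_out,
                      F.silu(gate_proj(x)) * up_proj(x)))
```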