Skip to content

Commit 2d5258a

Browse files
N8python and awni authored
Adds EXAONE architecture. (#1145)
* Adds EXAONE architecture.
* nits + format
* format
* clean up and fix rope
* clean up and fix rope
---------
Co-authored-by: Awni Hannun <[email protected]>
1 parent e98f427 commit 2d5258a

File tree

6 files changed

+312
-224
lines changed

6 files changed

+312
-224
lines changed

mlx_lm/models/exaone.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Copyright © 2024 Apple Inc.
2+
3+
from dataclasses import dataclass
4+
from typing import Any, Dict, Optional, Union
5+
6+
import mlx.core as mx
7+
import mlx.nn as nn
8+
9+
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
10+
from .rope_utils import initialize_rope
11+
12+
13+
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration consumed by the EXAONE modules defined in this file.

    Fields without a default must be supplied; the remaining fields are
    optional. Field order is significant because this is a dataclass.
    """

    model_type: str
    # Width of the embeddings, the RMSNorm layers and the projection inputs.
    hidden_size: int
    # Number of TransformerBlock instances stacked in ExaoneModel.
    num_layers: int
    # Inner width of the gated MLP.
    intermediate_size: int
    # Number of query heads.
    num_attention_heads: int
    # Number of rows in the token embedding (and in lm_head when untied).
    vocab_size: int
    # Forwarded to initialize_rope as the RoPE base.
    rope_theta: float
    # eps passed to every nn.RMSNorm.
    layer_norm_epsilon: float
    # Number of key/value heads.
    num_key_value_heads: int
    # Per-head width; when falsy, hidden_size // num_attention_heads is used.
    head_dim: Optional[int] = None
    # Forwarded unchanged to initialize_rope.
    max_position_embeddings: Optional[int] = None
    # Forwarded unchanged to initialize_rope.
    rope_traditional: bool = False
    # Forwarded unchanged to initialize_rope.
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    # When True the output projection reuses the embedding via as_linear;
    # otherwise a separate bias-free lm_head is created.
    tie_word_embeddings: bool = True
    # Whether the attention q/k/v/out projections carry a bias term.
    attention_bias: bool = False
    # Whether the MLP linear layers carry a bias term.
    mlp_bias: bool = False
31+
32+
33+
class AttentionModule(nn.Module):
    """Self-attention with rotary position embeddings and separate KV heads.

    Queries are projected to ``num_attention_heads`` heads and keys/values to
    ``num_key_value_heads`` heads, each of width ``head_dim``. The rotary
    embedding is built by ``initialize_rope`` and the attention itself is
    delegated to ``scaled_dot_product_attention`` (both imported; their
    internals are not visible in this file).
    """

    def __init__(self, args: ModelArgs):
        """Create the projections and the rotary embedding from ``args``."""
        super().__init__()
        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        # An explicit head_dim wins; otherwise split hidden_size across heads.
        self.head_dim = head_dim = args.head_dim or (dim // n_heads)
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
        self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)

        self.rope = initialize_rope(
            self.head_dim,
            args.rope_theta,
            args.rope_traditional,
            args.rope_scaling,
            args.max_position_embeddings,
        )

    def __call__(
        self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
    ) -> mx.array:
        """Apply attention to ``x``.

        Args:
            x: Input whose shape unpacks into three dimensions ``(B, L, D)``.
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.
            cache: Optional cache object. When given, its ``offset`` is used
                as the rotary offset and ``update_and_fetch`` supplies the
                keys and values used for attention.

        Returns:
            The result of ``out_proj`` applied to the merged attention heads.
        """
        B, L, _ = x.shape

        # Split the projections into heads: (B, L, H, hd) -> (B, H, L, hd).
        q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Rotate with the cache's current offset before the new keys and
            # values are handed to the cache.
            q = self.rope(q, offset=cache.offset)
            k = self.rope(k, offset=cache.offset)
            k, v = cache.update_and_fetch(k, v)
        else:
            q = self.rope(q)
            k = self.rope(k)

        out = scaled_dot_product_attention(
            q, k, v, cache=cache, scale=self.scale, mask=mask
        )
        # Merge heads back into one axis. The merged width is
        # n_heads * head_dim, which differs from the input width D whenever
        # args.head_dim is set independently of hidden_size, so let reshape
        # infer it instead of forcing D. out_proj maps it back to hidden_size.
        out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.out_proj(out)
76+
77+
78+
class Attention(nn.Module):
    """Container that holds an ``AttentionModule`` under ``self.attention``.

    This class defines no ``__call__``; ``TransformerBlock`` in this file
    invokes ``self.attn.attention(...)`` directly.
    """

    # NOTE(review): presumably the extra nesting level exists so parameter
    # paths read ``attn.attention.*`` to match the checkpoint — confirm.
    def __init__(self, args: ModelArgs):
        """Build the wrapped attention module from ``args``."""
        super().__init__()
        self.attention = AttentionModule(args)
82+
83+
84+
class MLP(nn.Module):
    """Gated feed-forward layer: ``c_proj(silu(c_fc_0(x)) * c_fc_1(x))``."""

    def __init__(self, args: ModelArgs):
        """Create the two input projections and the output projection."""
        super().__init__()
        in_features = args.hidden_size
        inner_features = args.intermediate_size
        use_bias = args.mlp_bias

        self.c_fc_0 = nn.Linear(in_features, inner_features, bias=use_bias)
        self.c_fc_1 = nn.Linear(in_features, inner_features, bias=use_bias)
        self.c_proj = nn.Linear(inner_features, in_features, bias=use_bias)

    def __call__(self, x: mx.array) -> mx.array:
        """Return the gated projection of ``x``."""
        # c_fc_0 feeds the SiLU gate; c_fc_1 is the value it multiplies.
        gate = nn.silu(self.c_fc_0(x))
        value = self.c_fc_1(x)
        return self.c_proj(gate * value)
95+
96+
97+
class TransformerBlock(nn.Module):
    """One layer: normalized attention and normalized MLP, each with a residual add."""

    def __init__(self, args: ModelArgs):
        """Create both RMSNorm layers, the attention container and the MLP."""
        super().__init__()
        width = args.hidden_size
        eps = args.layer_norm_epsilon

        self.ln_1 = nn.RMSNorm(width, eps=eps)
        self.attn = Attention(args)
        self.ln_2 = nn.RMSNorm(width, eps=eps)
        self.mlp = MLP(args)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run the layer on ``x``.

        Args:
            x: Layer input.
            mask: Optional mask passed through to the attention module.
            cache: Optional cache passed through to the attention module.

        Returns:
            ``x`` plus the attention output, plus the MLP output of that sum.
        """
        # Attention lives one level down, on the Attention container.
        attn_out = self.attn.attention(self.ln_1(x), mask, cache)
        after_attn = x + attn_out
        return after_attn + self.mlp(self.ln_2(after_attn))
114+
115+
116+
class ExaoneModel(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers and a final RMSNorm."""

    def __init__(self, args: ModelArgs):
        """Create the embedding, ``args.num_layers`` blocks and the final norm."""
        super().__init__()
        self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
        self.h = [TransformerBlock(args) for _ in range(args.num_layers)]
        self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Embed ``inputs``, run every block and return the normalized result.

        Args:
            inputs: Passed to the embedding layer.
            cache: Optional sequence with one entry per block. When ``None``,
                every block receives ``None`` as its cache.
        """
        hidden = self.wte(inputs)
        # The mask is built from the caller's cache argument, before the
        # per-layer placeholder list is substituted for a missing cache.
        mask = create_attention_mask(hidden, cache)

        per_layer_cache = [None] * len(self.h) if cache is None else cache

        for block, block_cache in zip(self.h, per_layer_cache):
            hidden = block(hidden, mask, cache=block_cache)

        return self.ln_f(hidden)
138+
139+
140+
class Model(nn.Module):
    """Top-level EXAONE model: the transformer plus an output projection."""

    def __init__(self, args: ModelArgs):
        """Build the transformer and, when embeddings are untied, ``lm_head``."""
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = ExaoneModel(args)
        if not args.tie_word_embeddings:
            # A separate head is only created when the embedding is not reused.
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Run the transformer and project its output.

        With tied embeddings the projection is the embedding's ``as_linear``;
        otherwise it is ``lm_head``.
        """
        hidden = self.transformer(inputs, cache)
        if not self.args.tie_word_embeddings:
            return self.lm_head(hidden)
        return self.transformer.wte.as_linear(hidden)

    @property
    def layers(self):
        """The list of transformer blocks (``self.transformer.h``)."""
        return self.transformer.h

mlx_lm/models/llama.py

Lines changed: 8 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import mlx.nn as nn
88

99
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
10+
from .rope_utils import initialize_rope
1011

1112

1213
@dataclass
@@ -32,117 +33,6 @@ def __post_init__(self):
3233
if self.num_key_value_heads is None:
3334
self.num_key_value_heads = self.num_attention_heads
3435

35-
if self.rope_scaling:
36-
if not "factor" in self.rope_scaling:
37-
raise ValueError(f"rope_scaling must contain 'factor'")
38-
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
39-
"rope_type"
40-
)
41-
if rope_type is None:
42-
raise ValueError(
43-
f"rope_scaling must contain either 'type' or 'rope_type'"
44-
)
45-
if rope_type not in ["linear", "dynamic", "llama3"]:
46-
raise ValueError(
47-
"rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
48-
)
49-
50-
51-
class DynamicNTKScalingRoPE(nn.Module):
    """Rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE.

    For ``rope_type == "llama3"`` a per-dimension ``freqs`` array is
    precomputed and passed to ``mx.fast.rope`` with ``base`` set to ``None``.
    For every other type ``freqs`` is ``None`` and ``base`` is passed through.
    """

    def __init__(
        self,
        dims: int,
        max_position_embeddings: int = 2048,
        traditional: bool = False,
        base: float = 10000,
        scale: float = 1.0,
        rope_type: str = "default",
        rope_scaling: dict = None,
    ):
        """Store the configuration and precompute the frequency table."""
        super().__init__()
        self.dims = dims
        self.max_position_embeddings = max_position_embeddings
        self.traditional = traditional
        self.scale = scale
        self.rope_type = rope_type
        self.rope_scaling = rope_scaling
        self.base = base
        self.compute_freqs()

    def compute_freqs(self):
        """Set ``self._freqs`` (and clear ``self.base``) for the llama3 type."""
        if self.rope_type != "llama3":
            self._freqs = None
            return

        cfg = self.rope_scaling
        factor = cfg["factor"]
        low_freq_factor = cfg.get("low_freq_factor", 1.0)
        high_freq_factor = cfg.get("high_freq_factor", 4.0)
        old_context_len = cfg.get("original_max_position_embeddings", 8192)

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        unscaled = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
        # Wavelengths are derived from the unscaled values only.
        wavelens = 2 * mx.pi * unscaled

        # Entries beyond the low-frequency wavelength are multiplied by factor.
        scaled = mx.where(wavelens > low_freq_wavelen, unscaled * factor, unscaled)
        # Entries strictly between the two thresholds get the smoothed value.
        in_medium_band = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
        smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        smoothed = scaled / ((1 - smooth_factors) / factor + smooth_factors)

        self._freqs = mx.where(in_medium_band, smoothed, scaled)
        # The explicit table is used in place of base from here on.
        self.base = None

    def extra_repr(self):
        """Return the configuration summary string for this module."""
        return (
            f"{self.dims}, traditional={self.traditional}, "
            f"max_position_embeddings={self.max_position_embeddings}, "
            f"scaling_factor={self.scale}, rope_type={self.rope_type}"
        )

    def __call__(self, x, offset: int = 0):
        """Apply ``mx.fast.rope`` to ``x`` starting at position ``offset``."""
        return mx.fast.rope(
            x,
            self.dims,
            traditional=self.traditional,
            base=self.base,
            scale=self.scale,
            offset=offset,
            freqs=self._freqs,
        )
118-
119-
120-
def initialize_rope(args: ModelArgs):
    """Build a ``DynamicNTKScalingRoPE`` from the model arguments.

    The rope type is read from ``rope_scaling["type"]`` or
    ``rope_scaling["rope_type"]`` and falls back to ``"default"``. Only the
    ``"linear"`` type changes the scale, to ``1 / factor``; every other type,
    including ``"llama3"``, keeps a scale of 1.0.
    """
    head_dim = args.head_dim or args.hidden_size // args.num_attention_heads

    scaling_cfg = args.rope_scaling
    rope_type, rope_scale = "default", 1.0

    if scaling_cfg is not None:
        rope_type = (
            scaling_cfg.get("type") or scaling_cfg.get("rope_type") or "default"
        )
        if rope_type == "linear":
            rope_scale = 1 / scaling_cfg["factor"]
        # "llama3" keeps the default scale of 1.0; its scaling is handled
        # inside DynamicNTKScalingRoPE.

    return DynamicNTKScalingRoPE(
        dims=head_dim,
        max_position_embeddings=args.max_position_embeddings,
        traditional=args.rope_traditional,
        base=args.rope_theta,
        scale=rope_scale,
        rope_type=rope_type,
        rope_scaling=scaling_cfg,
    )
145-
14636

14737
class Attention(nn.Module):
14838
def __init__(self, args: ModelArgs):
@@ -165,7 +55,13 @@ def __init__(self, args: ModelArgs):
16555
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
16656
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
16757

168-
self.rope = initialize_rope(args)
58+
self.rope = initialize_rope(
59+
self.head_dim,
60+
args.rope_theta,
61+
args.rope_traditional,
62+
args.rope_scaling,
63+
args.max_position_embeddings,
64+
)
16965

17066
def __call__(
17167
self,

0 commit comments

Comments
 (0)