
Commit b6ab352

iftenney authored and LIT team committed
Support different float precision for LM salience.

Also set watch_accessed_variables=False, because we don't need it.

PiperOrigin-RevId: 607091706
1 parent 1df3ba8 commit b6ab352

File tree: 3 files changed, +34 −24 lines

lit_nlp/examples/lm_salience_demo.py

Lines changed: 10 additions & 0 deletions

@@ -2,12 +2,14 @@
 
 from collections.abc import Sequence
 import functools
+import os
 import sys
 from typing import Optional
 
 from absl import app
 from absl import flags
 from absl import logging
+import keras
 from lit_nlp import dev_server
 from lit_nlp import server_flags
 from lit_nlp.api import layout
@@ -37,6 +39,10 @@
     ),
 )
 
+_KERAS_FLOATX = flags.DEFINE_string(
+    "keras_floatx", "bfloat16", "Floating-point type for Keras models."
+)
+
 # Custom frontend layout; see api/layout.py
 modules = layout.LitModuleName
 LM_LAYOUT = layout.LitCanonicalLayout(
@@ -109,6 +115,10 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
   if len(argv) > 1:
     raise app.UsageError("Too many command-line arguments.")
 
+  # Set Keras backend and floating-point precision.
+  os.environ["KERAS_BACKEND"] = "tensorflow"
+  keras.config.set_floatx(_KERAS_FLOATX.value)
+
   plaintextPrompts = functools.partial(  # pylint: disable=invalid-name
       lm_data.PlaintextSents, field_name="prompt"
   )
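For context, a standalone sketch of what the new flag controls (not part of the commit; it assumes Keras 3 with the TensorFlow backend installed, and the embedding layer here is just an arbitrary example):

import os

# Keras reads KERAS_BACKEND at import time, so set it before importing keras,
# just as the demo now does at the start of main().
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np

# Mirror the demo's default: float tensors created by Keras layers from here
# on default to bfloat16 instead of float32.
keras.config.set_floatx("bfloat16")

layer = keras.layers.Embedding(input_dim=8, output_dim=4)
embs = layer(np.array([[1, 2, 3]]))
print(embs.dtype)  # bfloat16; would be float32 without set_floatx()

Passing --keras_floatx=float32 to the demo restores full precision, at roughly twice the memory per activation of the bfloat16 default.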

lit_nlp/examples/models/instrumented_keras_lms.py

Lines changed: 15 additions & 15 deletions

@@ -68,8 +68,8 @@ def __init__(
         self.model.preprocessor.tokenizer.id_to_token
     )
 
-    # map ids: <tf.int64>[batch_size, num_tokens]
-    # to embs: <tf.float32>[batch_size, num_tokens, emb_dim]
+    # map ids: <tf.int>[batch_size, num_tokens]
+    # to embs: <tf.float>[batch_size, num_tokens, emb_dim]
     self.embedder = self.model.backbone.token_embedding
 
   @classmethod
@@ -114,7 +114,7 @@ def embed_texts(self, texts: Sequence[str]):
     processed_inputs = self.encode_inputs(
         texts, sequence_length=self.max_length
     )
-    # <tf.float32>[batch_size, num_tokens, emb_dim]
+    # <tf.float>[batch_size, num_tokens, emb_dim]
     embs = self.embedder(processed_inputs["token_ids"])
     # <tf.bool>[batch_size, num_tokens]
     mask = processed_inputs["padding_mask"]
@@ -123,13 +123,13 @@ def embed_texts(self, texts: Sequence[str]):
   def embed_and_mean_pool(self, texts: Sequence[str]):
     """Return a single vector for each text."""
     embs, mask = self.embed_texts(texts)
-    # <tf.float32>[batch_size, num_tokens, 1]
-    mask = tf.expand_dims(tf.cast(mask, dtype=tf.float32), axis=2)
-    # <tf.float32>[batch_size, 1, emb_dim]
+    # <tf.float>[batch_size, num_tokens, 1]
+    mask = tf.expand_dims(tf.cast(mask, dtype=embs.dtype), axis=2)
+    # <tf.float>[batch_size, 1, emb_dim]
     pooled_embs = tf.reduce_sum(
         mask * embs, axis=1, keepdims=True
     ) / tf.reduce_sum(mask, axis=1, keepdims=True)
-    # <tf.float32>[batch_size, emb_dim]
+    # <tf.float>[batch_size, emb_dim]
     return tf.squeeze(pooled_embs, axis=1)
 
   def predict_minibatch(
@@ -203,7 +203,7 @@ def __init__(self, *args, **kw):
 
   def _pred(self, input_ids, padding_mask, target_masks):
     """Predict a batch of tokenized text."""
-    # <tf.float32>[batch_size, num_tokens]; ignore the last one in each row.
+    # <tf.int>[batch_size, num_tokens]; ignore the last one in each row.
     target_ids = tf.roll(input_ids, shift=-1, axis=1)
 
     ##
@@ -226,13 +226,13 @@ def _pred(self, input_ids, padding_mask, target_masks):
         axis=0,
     )
 
-    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32)
+    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.bool)
     # Shift masks back so they align with target_ids.
     loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1)
 
     embeddings = None
 
-    with tf.GradientTape(watch_accessed_variables=True) as tape:
+    with tf.GradientTape(watch_accessed_variables=False) as tape:
 
       def layer_intercept_fn(x, i):
         if i == -1:
@@ -241,21 +241,21 @@ def layer_intercept_fn(x, i):
           tape.watch(embeddings)
         return x
 
-      # <tf.float32>[batch_size, num_tokens]
+      # <tf.float>[batch_size, num_tokens]
       per_token_loss = self.model.score(
          token_ids=input_ids,
          padding_mask=padding_mask,
          scoring_mode="loss",
          layer_intercept_fn=layer_intercept_fn,
          target_ids=target_ids,
      )
-      masked_loss = per_token_loss * loss_mask
+      masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)
 
-    # <tf.float32>[batch_size, num_tokens, hdim]
+    # <tf.float>[batch_size, num_tokens, hdim]
     grads = tape.gradient(masked_loss, embeddings)
-    # <tf.float32>[batch_size, num_tokens]
+    # <tf.float>[batch_size, num_tokens]
     grad_l2 = tf.norm(grads, axis=2)
-    # <tf.float32>[batch_size, num_tokens]
+    # <tf.float>[batch_size, num_tokens]
     grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2)
 
     batched_outputs = {
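The two behavioral changes above, watching only the embeddings and casting the boolean mask to the loss dtype, can be exercised in isolation with a toy tape (a minimal sketch, not LIT code; the tensors are placeholders for the real model outputs):

import tensorflow as tf

# Placeholders standing in for the model's input embeddings and target mask.
embeddings = tf.random.normal([2, 5, 8], dtype=tf.bfloat16)
loss_mask = tf.constant([[True, True, False, False, False],
                         [True, True, True, False, False]])

# With watch_accessed_variables=False the tape records nothing by default, so
# model weights touched during the forward pass are not tracked; only the
# tensor passed to tape.watch() (the embeddings) will have a gradient.
with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(embeddings)
  # Stand-in for the per-token loss the model would return; it inherits the
  # embeddings' dtype (bfloat16 here).
  per_token_loss = tf.reduce_sum(embeddings**2, axis=2)
  # Cast the bool mask to the loss dtype, as in the diff; a mask hard-coded to
  # tf.float32 would not match under reduced precision.
  masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)

# <bfloat16>[2, 5, 8], same shape and dtype as the watched embeddings.
grads = tape.gradient(masked_loss, embeddings)
grad_l2 = tf.norm(grads, axis=2)                            # [2, 5]
grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2)  # [2, 5]

Not watching the model variables means the tape does not have to record operations for their gradients, which saves memory; as the commit message notes, those gradients are not needed here.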

lit_nlp/examples/models/pretrained_lms.py

Lines changed: 9 additions & 9 deletions

@@ -490,7 +490,7 @@ def predict_minibatch(self, inputs):
     responses = self.tokenizer.batch_decode(
         outputs[:, -self.max_new_tokens :], skip_special_tokens=True
     )
-    # Input embeddings: <tf.float32>[batch_size, num_tokens, emb_dim]
+    # Input embeddings: <tf.float>[batch_size, num_tokens, emb_dim]
     embeddings = self.model.transformer.wte(outputs)
     batched_outputs = {
         "embs": embeddings,
@@ -532,7 +532,7 @@ def _pred(self, encoded_inputs, target_masks):
     """
     input_ids = encoded_inputs["input_ids"]
 
-    # <tf.float32>[batch_size, num_tokens]; ignore the last one in each row.
+    # <tf.int32>[batch_size, num_tokens]; ignore the last one in each row.
     target_ids = tf.roll(encoded_inputs["input_ids"], shift=-1, axis=1)
     ##
     # Process target masks
@@ -554,11 +554,11 @@ def _pred(self, encoded_inputs, target_masks):
         axis=0,
     )
 
-    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32)
+    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.bool)
     # Shift masks back so they align with target_ids.
     loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1)
 
-    with tf.GradientTape(watch_accessed_variables=True) as tape:
+    with tf.GradientTape(watch_accessed_variables=False) as tape:
       # We need to run the embedding layer ourselves so we can trace it.
       # See here for how the model normally does this:
       # http://google3/third_party/py/transformers/models/gpt2/modeling_tf_gpt2.py;l=450;rcl=578656271
@@ -574,18 +574,18 @@ def _pred(self, encoded_inputs, target_masks):
       loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
           from_logits=True, reduction="none"
       )
-      # <tf.float32>[batch_size, num_tokens]
+      # <tf.float>[batch_size, num_tokens]
       per_token_loss = loss_fn(target_ids, out.logits)
-      masked_loss = per_token_loss * loss_mask
+      masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)
 
     grads = tape.gradient(
         masked_loss, embs
-    )  # <tf.float32>[batch_size, num_tokens, hdim]
+    )  # <tf.float>[batch_size, num_tokens, hdim]
 
-    grad_l2 = tf.norm(grads, axis=2)  # <tf.float32>[batch_size, num_tokens]
+    grad_l2 = tf.norm(grads, axis=2)  # <tf.float>[batch_size, num_tokens]
     grad_dot_input = tf.reduce_sum(
         grads * embs, axis=2
-    )  # <tf.float32>[batch_size, num_tokens]
+    )  # <tf.float>[batch_size, num_tokens]
 
     batched_outputs = {
         "input_ids": encoded_inputs["input_ids"],
