@@ -68,8 +68,8 @@ def __init__(
         self.model.preprocessor.tokenizer.id_to_token
     )
 
-    # map ids: <tf.int64>[batch_size, num_tokens]
-    # to embs: <tf.float32>[batch_size, num_tokens, emb_dim]
+    # map ids: <tf.int>[batch_size, num_tokens]
+    # to embs: <tf.float>[batch_size, num_tokens, emb_dim]
     self.embedder = self.model.backbone.token_embedding
 
   @classmethod
@@ -114,7 +114,7 @@ def embed_texts(self, texts: Sequence[str]):
     processed_inputs = self.encode_inputs(
         texts, sequence_length=self.max_length
     )
-    # <tf.float32>[batch_size, num_tokens, emb_dim]
+    # <tf.float>[batch_size, num_tokens, emb_dim]
     embs = self.embedder(processed_inputs["token_ids"])
     # <tf.bool>[batch_size, num_tokens]
     mask = processed_inputs["padding_mask"]
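Note on the hunk above: `self.embedder` is just the backbone's token-embedding layer, so it maps integer ids to float vectors in whatever precision the backbone runs in, which is why the comments now say `<tf.int>` / `<tf.float>` rather than fixed dtypes. A minimal sketch of that id-to-embedding step, with a toy lookup table standing in for `model.backbone.token_embedding` (the shapes and padding id below are made up for illustration):

```python
import tensorflow as tf

# Toy stand-in for model.backbone.token_embedding; vocab_size and emb_dim
# are invented for this sketch.
vocab_size, emb_dim = 8, 4
table = tf.random.normal([vocab_size, emb_dim])

# <tf.int>[batch_size, num_tokens]; 0 is treated as the padding id here.
token_ids = tf.constant([[5, 2, 7, 0], [3, 1, 0, 0]])
# <tf.bool>[batch_size, num_tokens]
padding_mask = tf.not_equal(token_ids, 0)

# <tf.float>[batch_size, num_tokens, emb_dim]
embs = tf.nn.embedding_lookup(table, token_ids)
print(embs.shape, padding_mask.dtype)  # (2, 4, 4) <dtype: 'bool'>
```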
@@ -123,13 +123,13 @@ def embed_texts(self, texts: Sequence[str]):
   def embed_and_mean_pool(self, texts: Sequence[str]):
     """Return a single vector for each text."""
     embs, mask = self.embed_texts(texts)
-    # <tf.float32>[batch_size, num_tokens, 1]
-    mask = tf.expand_dims(tf.cast(mask, dtype=tf.float32), axis=2)
-    # <tf.float32>[batch_size, 1, emb_dim]
+    # <tf.float>[batch_size, num_tokens, 1]
+    mask = tf.expand_dims(tf.cast(mask, dtype=embs.dtype), axis=2)
+    # <tf.float>[batch_size, 1, emb_dim]
     pooled_embs = tf.reduce_sum(
         mask * embs, axis=1, keepdims=True
     ) / tf.reduce_sum(mask, axis=1, keepdims=True)
-    # <tf.float32>[batch_size, emb_dim]
+    # <tf.float>[batch_size, emb_dim]
     return tf.squeeze(pooled_embs, axis=1)
 
   def predict_minibatch(
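Casting the padding mask to `embs.dtype` instead of hard-coding `tf.float32` keeps the pooling arithmetic in the backbone's own precision. A self-contained sketch of the same masked mean pooling, run in bfloat16 to show why the cast matters (shapes are invented):

```python
import tensorflow as tf

# Standalone version of the masked mean pooling in embed_and_mean_pool.
embs = tf.cast(tf.random.normal([2, 4, 3]), tf.bfloat16)
mask = tf.constant([[True, True, False, False], [True, True, True, False]])

# <bfloat16>[batch_size, num_tokens, 1]
mask_f = tf.expand_dims(tf.cast(mask, dtype=embs.dtype), axis=2)
# <bfloat16>[batch_size, emb_dim]: sum of kept embeddings / number kept.
pooled = tf.reduce_sum(mask_f * embs, axis=1) / tf.reduce_sum(mask_f, axis=1)
print(pooled.shape, pooled.dtype)  # (2, 3) <dtype: 'bfloat16'>
```

If the mask were cast to `tf.float32` here, the `mask_f * embs` multiply would fail with a dtype mismatch, since TensorFlow does not promote dtypes automatically.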
@@ -203,7 +203,7 @@ def __init__(self, *args, **kw):
 
   def _pred(self, input_ids, padding_mask, target_masks):
     """Predict a batch of tokenized text."""
-    # <tf.float32>[batch_size, num_tokens]; ignore the last one in each row.
+    # <tf.int>[batch_size, num_tokens]; ignore the last one in each row.
     target_ids = tf.roll(input_ids, shift=-1, axis=1)
 
     ##
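The comment fix above matches what `tf.roll` actually produces: integer target ids. Rolling by -1 along the token axis lines `target_ids[i]` up with `input_ids[i + 1]`, and the wrapped-around last column is the one the comment says to ignore. For example:

```python
import tensorflow as tf

input_ids = tf.constant([[11, 12, 13, 14],
                         [21, 22, 23, 24]])
# Next-token targets: each position holds the following input token.
target_ids = tf.roll(input_ids, shift=-1, axis=1)
print(target_ids.numpy())
# [[12 13 14 11]
#  [22 23 24 21]]   <- last column wrapped around; ignored downstream.
```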
@@ -226,13 +226,13 @@ def _pred(self, input_ids, padding_mask, target_masks):
         axis=0,
     )
 
-    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32)
+    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.bool)
     # Shift masks back so they align with target_ids.
     loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1)
 
     embeddings = None
 
-    with tf.GradientTape(watch_accessed_variables=True) as tape:
+    with tf.GradientTape(watch_accessed_variables=False) as tape:
 
       def layer_intercept_fn(x, i):
         if i == -1:
@@ -241,21 +241,21 @@ def layer_intercept_fn(x, i):
           tape.watch(embeddings)
         return x
 
-      # <tf.float32>[batch_size, num_tokens]
+      # <tf.float>[batch_size, num_tokens]
       per_token_loss = self.model.score(
           token_ids=input_ids,
           padding_mask=padding_mask,
           scoring_mode="loss",
           layer_intercept_fn=layer_intercept_fn,
           target_ids=target_ids,
       )
-      masked_loss = per_token_loss * loss_mask
+      masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)
 
-    # <tf.float32>[batch_size, num_tokens, hdim]
+    # <tf.float>[batch_size, num_tokens, hdim]
     grads = tape.gradient(masked_loss, embeddings)
-    # <tf.float32>[batch_size, num_tokens]
+    # <tf.float>[batch_size, num_tokens]
     grad_l2 = tf.norm(grads, axis=2)
-    # <tf.float32>[batch_size, num_tokens]
+    # <tf.float>[batch_size, num_tokens]
     grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2)
 
     batched_outputs = {
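Taken together, these changes keep the salience path dtype-agnostic and cheaper to trace: the target mask stays boolean until it is cast to `per_token_loss.dtype`, and with `watch_accessed_variables=False` the tape records only the embeddings explicitly watched in `layer_intercept_fn`. A rough, self-contained sketch of the same pattern, with a toy quadratic per-token loss standing in for `model.score(..., scoring_mode="loss")`:

```python
import tensorflow as tf

# Toy version of the salience computation: watch only the intercepted
# embeddings, mask the per-token loss, then reduce the gradients per token.
embeddings = tf.random.normal([2, 4, 3])   # <float>[batch, tokens, hdim]
loss_mask = tf.constant([[False, True, True, False],
                         [True, True, False, False]])

with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(embeddings)
  # <float>[batch, tokens]; placeholder for the model's per-token loss.
  per_token_loss = tf.reduce_sum(tf.square(embeddings), axis=2)
  masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)

# <float>[batch, tokens, hdim]
grads = tape.gradient(masked_loss, embeddings)
# <float>[batch, tokens]
grad_l2 = tf.norm(grads, axis=2)
# <float>[batch, tokens]
grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2)
print(grad_l2.shape, grad_dot_input.shape)  # (2, 4) (2, 4)
```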