@@ -185,8 +185,11 @@ func (jc *jetstreamConsumer) As(i any) bool {
185185 }
186186
187187 if p , ok := i .(* jetstream.JetStream ); ok {
188- * p = jc .connector .Connection ().(* jetstreamConnection ).jetStream
189- return true
188+ if jsConn , ok := jc .connector .Connection ().(* jetstreamConnection ); ok {
189+ * p = jsConn .jetStream
190+ return true
191+ }
192+ return false
190193 }
191194
192195 return false
@@ -235,29 +238,25 @@ func (jc *jetstreamConsumer) setActiveBatch(batch jetstream.MessageBatch) {
235238}
236239
237240func (jc * jetstreamConsumer ) setupActiveBatch (ctx context.Context , batchCount int , batchTimeout time.Duration ) (jetstream.MessageBatch , error ) {
238- // Fast path: check if we already have an active batch without write lock
241+ // Fast path: check if we already have an active batch with a read lock.
239242 if batch := jc .getActiveBatch (); batch != nil {
240243 return batch , nil
241244 }
242245
243- // Check for context cancellation before expensive operations
244- if err := ctx .Err (); err != nil {
245- return nil , errorutil .Wrap (err , "context canceled while setting up batch" )
246- }
247-
248- // Acquire write lock for the entire batch creation process to prevent
249- // multiple goroutines from calling Fetch() concurrently (which would
250- // cause duplicate batches and lost messages)
246+ // Acquire a write lock to create the batch.
251247 jc .mu .Lock ()
252248 defer jc .mu .Unlock ()
253249
254- // Double-check after acquiring write lock
250+ // Double-check after acquiring the write lock in case another goroutine created it.
255251 if jc .activeBatch != nil {
256252 return jc .activeBatch , nil
257253 }
258254
259- // Perform fetch while holding lock - this blocks other receivers but
260- // prevents the race condition of multiple concurrent Fetch() calls
255+ // Check for context cancellation before the blocking call.
256+ if err := ctx .Err (); err != nil {
257+ return nil , errorutil .Wrap (err , "context canceled while setting up batch" )
258+ }
259+
261260 batch , err := jc .consumer .Fetch (batchCount , jetstream .FetchMaxWait (batchTimeout ))
262261 if err != nil {
263262 if errors .Is (err , nats .ErrConnectionClosed ) || errors .Is (err , nats .ErrConnectionDraining ) {
@@ -267,7 +266,7 @@ func (jc *jetstreamConsumer) setupActiveBatch(ctx context.Context, batchCount in
267266 }
268267
269268 jc .activeBatch = batch
270- return batch , nil
269+ return jc . activeBatch , nil
271270}
272271
273272func (jc * jetstreamConsumer ) clearActiveBatch () {
@@ -285,13 +284,12 @@ func (jc *jetstreamConsumer) pullMessages(ctx context.Context, batchCount int, b
285284 for {
286285 select {
287286 case <- ctx .Done ():
288- // If we already have messages, return them instead of error
289- if len (messages ) > 0 {
290- return messages , nil
291- }
287+ // Timeout is not an error - return messages collected so far
292288 if errors .Is (ctx .Err (), context .DeadlineExceeded ) {
293- return nil , errorutil . Wrap ( ctx . Err (), "timeout while waiting for messages" )
289+ return messages , nil
294290 }
291+ // Return messages along with cancellation error to signal unhealthy source
292+ // per gocloud.dev/pubsub/driver ReceiveBatch semantics
295293 return messages , errorutil .Wrap (ctx .Err (), "context canceled while processing messages" )
296294
297295 case msg , ok := <- activeBatch .Messages ():
0 commit comments