Skip to content

Commit 8d9bc00

Browse files
committed
review changes feedback
1 parent ac389bd commit 8d9bc00

File tree

2 files changed

+26
-26
lines changed

2 files changed

+26
-26
lines changed

connections/jetstream.go

Lines changed: 18 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -185,8 +185,11 @@ func (jc *jetstreamConsumer) As(i any) bool {
185185
}
186186

187187
if p, ok := i.(*jetstream.JetStream); ok {
188-
*p = jc.connector.Connection().(*jetstreamConnection).jetStream
189-
return true
188+
if jsConn, ok := jc.connector.Connection().(*jetstreamConnection); ok {
189+
*p = jsConn.jetStream
190+
return true
191+
}
192+
return false
190193
}
191194

192195
return false
@@ -235,29 +238,25 @@ func (jc *jetstreamConsumer) setActiveBatch(batch jetstream.MessageBatch) {
235238
}
236239

237240
func (jc *jetstreamConsumer) setupActiveBatch(ctx context.Context, batchCount int, batchTimeout time.Duration) (jetstream.MessageBatch, error) {
238-
// Fast path: check if we already have an active batch without write lock
241+
// Fast path: check if we already have an active batch with a read lock.
239242
if batch := jc.getActiveBatch(); batch != nil {
240243
return batch, nil
241244
}
242245

243-
// Check for context cancellation before expensive operations
244-
if err := ctx.Err(); err != nil {
245-
return nil, errorutil.Wrap(err, "context canceled while setting up batch")
246-
}
247-
248-
// Acquire write lock for the entire batch creation process to prevent
249-
// multiple goroutines from calling Fetch() concurrently (which would
250-
// cause duplicate batches and lost messages)
246+
// Acquire a write lock to create the batch.
251247
jc.mu.Lock()
252248
defer jc.mu.Unlock()
253249

254-
// Double-check after acquiring write lock
250+
// Double-check after acquiring the write lock in case another goroutine created it.
255251
if jc.activeBatch != nil {
256252
return jc.activeBatch, nil
257253
}
258254

259-
// Perform fetch while holding lock - this blocks other receivers but
260-
// prevents the race condition of multiple concurrent Fetch() calls
255+
// Check for context cancellation before the blocking call.
256+
if err := ctx.Err(); err != nil {
257+
return nil, errorutil.Wrap(err, "context canceled while setting up batch")
258+
}
259+
261260
batch, err := jc.consumer.Fetch(batchCount, jetstream.FetchMaxWait(batchTimeout))
262261
if err != nil {
263262
if errors.Is(err, nats.ErrConnectionClosed) || errors.Is(err, nats.ErrConnectionDraining) {
@@ -267,7 +266,7 @@ func (jc *jetstreamConsumer) setupActiveBatch(ctx context.Context, batchCount in
267266
}
268267

269268
jc.activeBatch = batch
270-
return batch, nil
269+
return jc.activeBatch, nil
271270
}
272271

273272
func (jc *jetstreamConsumer) clearActiveBatch() {
@@ -285,13 +284,12 @@ func (jc *jetstreamConsumer) pullMessages(ctx context.Context, batchCount int, b
285284
for {
286285
select {
287286
case <-ctx.Done():
288-
// If we already have messages, return them instead of error
289-
if len(messages) > 0 {
290-
return messages, nil
291-
}
287+
// Timeout is not an error - return messages collected so far
292288
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
293-
return nil, errorutil.Wrap(ctx.Err(), "timeout while waiting for messages")
289+
return messages, nil
294290
}
291+
// Return messages along with cancellation error to signal unhealthy source
292+
// per gocloud.dev/pubsub/driver ReceiveBatch semantics
295293
return messages, errorutil.Wrap(ctx.Err(), "context canceled while processing messages")
296294

297295
case msg, ok := <-activeBatch.Messages():

connections/plain.go

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -193,8 +193,11 @@ func (q *natsConsumer) As(i any) bool {
193193
}
194194

195195
if p, ok := i.(**nats.Conn); ok {
196-
*p = q.connector.Connection().(*plainConnection).natsConnection
197-
return true
196+
if plainConn, ok := q.connector.Connection().(*plainConnection); ok {
197+
*p = plainConn.natsConnection
198+
return true
199+
}
200+
return false
198201
}
199202

200203
if p, ok := i.(*Connector); ok {
@@ -291,10 +294,9 @@ func (q *natsConsumer) ReceiveMessages(ctx context.Context, batchCount int) ([]*
291294
if errors.Is(err, nats.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
292295
return messages, nil
293296
}
294-
if len(messages) > 0 {
295-
return messages, nil
296-
}
297-
return nil, errorutil.Wrap(err, "error receiving message")
297+
// Return any messages we have along with the error to signal unhealthy source
298+
// per gocloud.dev/pubsub/driver ReceiveBatch semantics
299+
return messages, errorutil.Wrap(err, "error receiving message")
298300
}
299301

300302
var driverMsg *driver.Message

0 commit comments

Comments
 (0)