Skip to content

Commit cc3372d

Browse files
TropicalDog17, melekes, and mergify[bot]
authored
perf(internal/bits): Additional speedup to bitArray.PickRandom (cometbft#4675)
Closes cometbft#2849

Comparing benchmarks before and after this refactor, we observed roughly a 30% speedup.

Before:

```
goos: linux
goarch: amd64
pkg: github.com/cometbft/cometbft/internal/bits
cpu: 13th Gen Intel(R) Core(TM) i5-13400
BenchmarkPickRandomBitArray-16    15392144    65.31 ns/op    0 B/op    0 allocs/op
PASS
```

After:

```
goos: linux
goarch: amd64
pkg: github.com/cometbft/cometbft/internal/bits
cpu: 13th Gen Intel(R) Core(TM) i5-13400
BenchmarkPickRandomBitArray-16    24890334    44.16 ns/op    0 B/op    0 allocs/op
PASS
```

---

#### PR checklist

- [x] Tests written/updated
- [ ] Changelog entry added in `.changelog` (we use [unclog](https://github.com/informalsystems/unclog) to manage our changelog)
- [ ] Updated relevant documentation (`docs/` or `spec/`) and code comments

---------

Co-authored-by: Anton Kaliaev <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 21024b7 commit cc3372d

File tree

3 files changed

+64
-35
lines changed

3 files changed

+64
-35
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- Additional speedup to bitArray.PickRandom
2+
([\#2849](https://github.com/cometbft/cometbft/issues/2849))

internal/bits/bit_array.go

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ import (
1515

1616
// BitArray is a thread-safe implementation of a bit array.
1717
type BitArray struct {
18-
mtx sync.Mutex
19-
Bits int `json:"bits"` // NOTE: persisted via reflect, must be exported
20-
Elems []uint64 `json:"elems"` // NOTE: persisted via reflect, must be exported
18+
mtx sync.Mutex
19+
TrueBitCount int `json:"true_bits_count"` // Number of bits set to true
20+
Bits int `json:"bits"` // NOTE: persisted via reflect, must be exported
21+
Elems []uint64 `json:"elems"` // NOTE: persisted via reflect, must be exported
2122
}
2223

2324
// NewBitArray returns a new bit array.
@@ -27,8 +28,9 @@ func NewBitArray(bits int) *BitArray {
2728
return nil
2829
}
2930
return &BitArray{
30-
Bits: bits,
31-
Elems: make([]uint64, (bits+63)/64),
31+
Bits: bits,
32+
Elems: make([]uint64, (bits+63)/64),
33+
TrueBitCount: 0,
3234
}
3335
}
3436

@@ -40,13 +42,15 @@ func NewBitArrayFromFn(bits int, fn func(int) bool) *BitArray {
4042
return nil
4143
}
4244
bA := &BitArray{
43-
Bits: bits,
44-
Elems: make([]uint64, (bits+63)/64),
45+
Bits: bits,
46+
Elems: make([]uint64, (bits+63)/64),
47+
TrueBitCount: 0,
4548
}
4649
for i := 0; i < bits; i++ {
4750
v := fn(i)
4851
if v {
4952
bA.Elems[i/64] |= (uint64(1) << uint(i%64))
53+
bA.TrueBitCount++
5054
}
5155
}
5256
return bA
@@ -93,9 +97,18 @@ func (bA *BitArray) setIndex(i int, v bool) bool {
9397
if i >= bA.Bits {
9498
return false
9599
}
100+
// Check current bit value
101+
oldValue := bA.getIndex(i)
102+
96103
if v {
104+
if !oldValue {
105+
bA.TrueBitCount++
106+
}
97107
bA.Elems[i/64] |= (uint64(1) << uint(i%64))
98108
} else {
109+
if oldValue {
110+
bA.TrueBitCount--
111+
}
99112
bA.Elems[i/64] &= ^(uint64(1) << uint(i%64))
100113
}
101114
return true
@@ -115,17 +128,27 @@ func (bA *BitArray) copy() *BitArray {
115128
c := make([]uint64, len(bA.Elems))
116129
copy(c, bA.Elems)
117130
return &BitArray{
118-
Bits: bA.Bits,
119-
Elems: c,
131+
Bits: bA.Bits,
132+
Elems: c,
133+
TrueBitCount: bA.TrueBitCount,
120134
}
121135
}
122136

123137
func (bA *BitArray) copyBits(bits int) *BitArray {
124138
c := make([]uint64, (bits+63)/64)
125139
copy(c, bA.Elems)
140+
141+
// Calculate true bit count for the new size
142+
newTrueBitCount := 0
143+
for i := 0; i < bits; i++ {
144+
if c[i/64]&(uint64(1)<<uint(i%64)) > 0 {
145+
newTrueBitCount++
146+
}
147+
}
126148
return &BitArray{
127-
Bits: bits,
128-
Elems: c,
149+
Bits: bits,
150+
Elems: c,
151+
TrueBitCount: newTrueBitCount,
129152
}
130153
}
131154

@@ -193,6 +216,9 @@ func (bA *BitArray) not() *BitArray {
193216
for i := 0; i < len(c.Elems); i++ {
194217
c.Elems[i] = ^c.Elems[i]
195218
}
219+
220+
// Flip count is simply total bits minus current true bits
221+
c.TrueBitCount = c.Bits - c.TrueBitCount
196222
return c
197223
}
198224

@@ -268,36 +294,18 @@ func (bA *BitArray) PickRandom(r *rand.Rand) (int, bool) {
268294
}
269295

270296
bA.mtx.Lock()
271-
numTrueIndices := bA.getNumTrueIndices()
272-
if numTrueIndices == 0 { // no bits set to true
297+
if bA.TrueBitCount == 0 { // no bits set to true
273298
bA.mtx.Unlock()
274299
return 0, false
275300
}
276-
index := bA.getNthTrueIndex(r.Intn(numTrueIndices))
301+
index := bA.getNthTrueIndex(r.Intn(bA.TrueBitCount))
277302
bA.mtx.Unlock()
278303
if index == -1 {
279304
return 0, false
280305
}
281306
return index, true
282307
}
283308

284-
func (bA *BitArray) getNumTrueIndices() int {
285-
count := 0
286-
numElems := len(bA.Elems)
287-
// handle all elements except the last one
288-
for i := 0; i < numElems-1; i++ {
289-
count += bits.OnesCount64(bA.Elems[i])
290-
}
291-
// handle last element
292-
numFinalBits := bA.Bits - (numElems-1)*64
293-
for i := 0; i < numFinalBits; i++ {
294-
if (bA.Elems[numElems-1] & (uint64(1) << uint64(i))) > 0 {
295-
count++
296-
}
297-
}
298-
return count
299-
}
300-
301309
// getNthTrueIndex returns the index of the nth true bit in the bit array.
302310
// n is 0 indexed. (e.g. for bitarray x__x, getNthTrueIndex(0) returns 0).
303311
// If there is no such value, it returns -1.
@@ -442,6 +450,7 @@ func (bA *BitArray) UnmarshalJSON(bz []byte) error {
442450
// into a pointer with pre-allocated BitArray.
443451
bA.Bits = 0
444452
bA.Elems = nil
453+
bA.TrueBitCount = 0
445454
return nil
446455
}
447456

@@ -459,6 +468,7 @@ func (bA *BitArray) UnmarshalJSON(bz []byte) error {
459468
// Treat it as if we encountered the case: b == "null"
460469
bA.Bits = 0
461470
bA.Elems = nil
471+
bA.TrueBitCount = 0
462472
return nil
463473
}
464474

@@ -468,9 +478,18 @@ func (bA *BitArray) UnmarshalJSON(bz []byte) error {
468478
}
469479
}
470480

481+
trueCount := 0
482+
for i := 0; i < numBits; i++ {
483+
if bits[i] == 'x' {
484+
bA2.SetIndex(i, true)
485+
trueCount++
486+
}
487+
}
488+
471489
// Instead of *bA = *bA2
472490
bA.Bits = bA2.Bits
473491
bA.Elems = make([]uint64, len(bA2.Elems))
492+
bA.TrueBitCount = trueCount
474493
copy(bA.Elems, bA2.Elems)
475494
return nil
476495
}
@@ -498,5 +517,13 @@ func (bA *BitArray) FromProto(protoBitArray *cmtprotobits.BitArray) {
498517
bA.Bits = int(protoBitArray.Bits)
499518
if len(protoBitArray.Elems) > 0 {
500519
bA.Elems = protoBitArray.Elems
520+
521+
// Recalculate TrueBitCount
522+
bA.TrueBitCount = 0
523+
for i := 0; i < bA.Bits; i++ {
524+
if bA.getIndex(i) {
525+
bA.TrueBitCount++
526+
}
527+
}
501528
}
502529
}

internal/bits/bit_array_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ func TestOr(t *testing.T) {
7676
t.Error("Wrong bit from bA3", i, bA1.GetIndex(i), bA2.GetIndex(i), bA3.GetIndex(i))
7777
}
7878
}
79-
if bA3.getNumTrueIndices() == 0 {
79+
if bA3.TrueBitCount == 0 {
8080
t.Error("Expected at least one true bit. " +
8181
"This has a false positive rate that is less than 1 in 2^80 (cryptographically improbable).")
8282
}
@@ -170,10 +170,10 @@ func TestGetNumTrueIndices(t *testing.T) {
170170
var bitArr *BitArray
171171
err := json.Unmarshal([]byte(`"`+tc.Input+`"`), &bitArr)
172172
require.NoError(t, err)
173-
result := bitArr.getNumTrueIndices()
173+
result := bitArr.TrueBitCount
174174
require.Equal(t, tc.ExpectedResult, result, "for input %s, expected %d, got %d", tc.Input, tc.ExpectedResult, result)
175-
result = bitArr.Not().getNumTrueIndices()
176-
require.Equal(t, bitArr.Bits-result, bitArr.getNumTrueIndices())
175+
result = bitArr.Not().TrueBitCount
176+
require.Equal(t, bitArr.Bits-result, bitArr.TrueBitCount)
177177
}
178178
}
179179

0 commit comments

Comments
 (0)