Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions internal/database/sqlcommon/sqlcommon.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package sqlcommon
import (
"context"
"database/sql"
"fmt"

sq "github.com/Masterminds/squirrel"
"github.com/golang-migrate/migrate/v4"
Expand Down Expand Up @@ -177,15 +178,18 @@ func (s *SQLCommon) query(ctx context.Context, q sq.SelectBuilder) (*sql.Rows, *
return s.queryTx(ctx, nil, q)
}

func (s *SQLCommon) countQuery(ctx context.Context, tx *txWrapper, tableName string, fop sq.Sqlizer) (count int64, err error) {
func (s *SQLCommon) countQuery(ctx context.Context, tx *txWrapper, tableName string, fop sq.Sqlizer, countExpr string) (count int64, err error) {
count = -1
l := log.L(ctx)
if tx == nil {
// If there is a transaction in the context, we should use it to provide consistency
// in the read operations (read after insert for example).
tx = getTXFromContext(ctx)
}
q := sq.Select("COUNT(*)").From(tableName).Where(fop)
if countExpr == "" {
countExpr = "*"
}
q := sq.Select(fmt.Sprintf("COUNT(%s)", countExpr)).From(tableName).Where(fop)
sqlQuery, args, err := q.PlaceholderFormat(s.provider.PlaceholderFormat()).ToSql()
if err != nil {
return count, i18n.WrapError(ctx, err, i18n.MsgDBQueryBuildFailed)
Expand Down Expand Up @@ -214,15 +218,14 @@ func (s *SQLCommon) countQuery(ctx context.Context, tx *txWrapper, tableName str

func (s *SQLCommon) queryRes(ctx context.Context, tx *txWrapper, tableName string, fop sq.Sqlizer, fi *database.FilterInfo) *database.FilterResult {
fr := &database.FilterResult{}
if !fi.Count {
return fr
}
count, err := s.countQuery(ctx, tx, tableName, fop)
if err != nil {
// Log, but continue
log.L(ctx).Warnf("Unable to return count for query: %s", err)
if fi.Count {
count, err := s.countQuery(ctx, tx, tableName, fop, fi.CountExpr)
if err != nil {
// Log, but continue
log.L(ctx).Warnf("Unable to return count for query: %s", err)
}
fr.TotalCount = &count // could be -1 if the count extract fails - we still return the result
}
fr.TotalCount = &count // could be -1 if the count extract fails - we still return the result
return fr
}

Expand Down
18 changes: 13 additions & 5 deletions internal/database/sqlcommon/sqlcommon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,27 +272,35 @@ func TestTXConcurrency(t *testing.T) {

// TestCountQueryBadSQL passes an INSERT builder as the WHERE predicate, which
// cannot be rendered as a condition, so building the COUNT query fails with FF10113.
func TestCountQueryBadSQL(t *testing.T) {
	s, _ := newMockProvider().init()
	_, err := s.countQuery(context.Background(), nil, "", sq.Insert("wrong"), "")
	assert.Regexp(t, "FF10113", err)
}

// TestCountQueryQueryFailed checks that an empty count expression renders as
// COUNT(*) and that a database error from the query is wrapped as FF10115.
func TestCountQueryQueryFailed(t *testing.T) {
	s, mdb := newMockProvider().init()
	mdb.ExpectQuery("^SELECT COUNT\\(\\*\\)").WillReturnError(fmt.Errorf("pop"))
	_, err := s.countQuery(context.Background(), nil, "table1", sq.Eq{"col1": "val1"}, "")
	assert.Regexp(t, "FF10115.*pop", err)
}

// TestCountQueryScanFailTx runs the count inside a transaction obtained from
// beginOrUseTx and returns a non-numeric value, so scanning the result fails with FF10121.
func TestCountQueryScanFailTx(t *testing.T) {
	s, mdb := newMockProvider().init()
	mdb.ExpectBegin()
	mdb.ExpectQuery("^SELECT COUNT\\(\\*\\)").WillReturnRows(sqlmock.NewRows([]string{"col1"}).AddRow("not a number"))
	ctx, tx, _, err := s.beginOrUseTx(context.Background())
	assert.NoError(t, err)
	_, err = s.countQuery(ctx, tx, "table1", sq.Eq{"col1": "val1"}, "")
	assert.Regexp(t, "FF10121", err)
}

// TestCountQueryWithExpr checks that a non-empty count expression is placed
// inside COUNT(...) verbatim, and that the scanned row value is returned as the count.
func TestCountQueryWithExpr(t *testing.T) {
	s, mdb := newMockProvider().init()
	mdb.ExpectQuery("^SELECT COUNT\\(DISTINCT key\\)").WillReturnRows(sqlmock.NewRows([]string{"col1"}).AddRow(10))
	count, err := s.countQuery(context.Background(), nil, "table1", sq.Eq{"col1": "val1"}, "DISTINCT key")
	assert.NoError(t, err)
	// Previously the result was discarded; pin that the mocked value is actually scanned
	assert.Equal(t, int64(10), count)
	assert.NoError(t, mdb.ExpectationsWereMet())
}

func TestQueryResSwallowError(t *testing.T) {
s, _ := newMockProvider().init()
res := s.queryRes(context.Background(), nil, "", sq.Insert("wrong"), &database.FilterInfo{
Expand Down
7 changes: 5 additions & 2 deletions internal/database/sqlcommon/tokenbalance_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ func (s *SQLCommon) GetTokenAccounts(ctx context.Context, filter database.Filter
if err != nil {
return nil, nil, err
}
fi.CountExpr = "DISTINCT key"

rows, tx, err := s.query(ctx, query)
if err != nil {
Expand All @@ -220,11 +221,13 @@ func (s *SQLCommon) GetTokenAccounts(ctx context.Context, filter database.Filter

func (s *SQLCommon) GetTokenAccountPools(ctx context.Context, key string, filter database.Filter) ([]*fftypes.TokenAccountPool, *database.FilterResult, error) {
query, fop, fi, err := s.filterSelect(ctx, "",
sq.Select("pool_id").Distinct().From("tokenbalance").Where(sq.Eq{"key": key}),
filter, tokenBalanceFilterFieldMap, []interface{}{"seq"})
sq.Select("pool_id").Distinct().From("tokenbalance"),
filter, tokenBalanceFilterFieldMap, []interface{}{"seq"},
sq.Eq{"key": key})
if err != nil {
return nil, nil, err
}
fi.CountExpr = "DISTINCT pool_id"

rows, tx, err := s.query(ctx, query)
if err != nil {
Expand Down
19 changes: 10 additions & 9 deletions pkg/database/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,16 @@ type SortField struct {
// FilterInfo is the structure returned by Finalize to the plugin, to serialize this filter
// into the underlying database mechanism's filter language
type FilterInfo struct {
	Sort  []*SortField
	Skip  uint64
	Limit uint64
	// Count requests that the plugin also compute a total count for the query,
	// returned via FilterResult.TotalCount.
	Count bool
	// CountExpr optionally overrides the expression used inside COUNT(...) when
	// Count is set (e.g. "DISTINCT key"); empty means COUNT(*). It is set by the
	// plugin for queries whose rows are not one-per-table-row, and is spliced
	// into SQL text, so it must never come from user input.
	CountExpr string
	Field     string
	Op        FilterOp
	Values    []FieldSerialization
	Value     FieldSerialization
	Children  []*FilterInfo
}

// FilterResult is has additional info if requested on the query - currently only the total count
Expand Down