Skip to content

Commit 7881f62

Browse files
VinozzZ and robbkidd authored
fix: use sendKey for rule matching if it's configured (#1735)
## Which problem is this PR solving?

Refinery uses the API key included in an incoming request to determine the target environment and dataset. However, when a SendKey is configured, it can override the original API key in the request. In that case, Refinery should resolve the environment and dataset from the SendKey, not from the original incoming API key. Otherwise, sampling rules may be matched using the wrong key.

This PR ensures that when a SendKey is applied, both environment/dataset resolution and sampling-rule matching are based on the SendKey.

## Short description of the changes

- Replace the request's API key header value with the configured SendKey, when applicable, as soon as the request is received (in the router's API-key middleware) instead of just before traces are sent (in the collector).
- Because the replaced key is what downstream handlers now see, `AccessKeyConfig.IsAccepted` also accepts the configured SendKey when `AcceptOnlyListedKeys` is enabled.
- Add integration tests to validate SendKey-based resolution behavior.

---------

Co-authored-by: Robb Kidd <[email protected]>
1 parent df4ca4c commit 7881f62

File tree

13 files changed

+503
-73
lines changed

app/app_test.go

Lines changed: 422 additions & 36 deletions
Large diffs are not rendered by default.

collect/collect.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,18 +1351,6 @@ func (i *InMemCollector) sendTraces() {
13511351
i.Metrics.Histogram("collector_outgoing_queue", float64(len(i.outgoingTraces)))
13521352
_, span := otelutil.StartSpanMulti(context.Background(), i.Tracer, "sendTrace", map[string]interface{}{"num_spans": t.DescendantCount(), "outgoingTraces_size": len(i.outgoingTraces)})
13531353

1354-
// if we have a key replacement rule, we should
1355-
// replace the key with the new key
1356-
keycfg := i.Config.GetAccessKeyConfig()
1357-
overwriteWith, err := keycfg.GetReplaceKey(t.APIKey)
1358-
if err != nil {
1359-
i.Logger.Warn().Logf("error replacing key: %s", err.Error())
1360-
continue
1361-
}
1362-
if overwriteWith != t.APIKey {
1363-
t.APIKey = overwriteWith
1364-
}
1365-
13661354
for _, sp := range t.GetSpans() {
13671355
if sp.IsDecisionSpan() {
13681356
continue

collect/collect_test.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,6 @@ func TestAddRootSpan(t *testing.T) {
132132
IncomingQueueSize: 5,
133133
PeerQueueSize: 5,
134134
},
135-
GetAccessKeyConfigVal: config.AccessKeyConfig{
136-
SendKey: "another-key",
137-
SendKeyMode: "all",
138-
},
139135
}
140136
transmission := &transmit.MockTransmission{}
141137
transmission.Start()
@@ -170,7 +166,6 @@ func TestAddRootSpan(t *testing.T) {
170166
events := transmission.GetBlock(1)
171167
require.Equal(t, 1, len(events), "adding a root span should send the span")
172168
assert.Equal(t, "aoeu", events[0].Dataset, "sending a root span should immediately send that span via transmission")
173-
assert.Equal(t, "another-key", events[0].APIKey, "api key should be replaced with the send key")
174169

175170
assert.Nil(t, coll.getFromCache(traceID1), "after sending the span, it should be removed from the cache")
176171

@@ -190,7 +185,6 @@ func TestAddRootSpan(t *testing.T) {
190185
events = transmission.GetBlock(1)
191186
require.Equal(t, 1, len(events), "adding another root span should send the span")
192187
assert.Equal(t, "aoeu", events[0].Dataset, "sending a root span should immediately send that span via transmission")
193-
assert.Equal(t, "another-key", events[0].APIKey, "api key should be replaced with the send key")
194188

195189
assert.Nil(t, coll.getFromCache(traceID1), "after sending the span, it should be removed from the cache")
196190

config/file_config.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ type AccessKeyConfig struct {
9595
AcceptOnlyListedKeys bool `yaml:"AcceptOnlyListedKeys"`
9696
}
9797

98-
// IsAccepted checks if the given key is in the list of accepted keys.
99-
// if the key is not in the list, it returns an error with the key truncated to 8 characters for logging.
98+
// IsAccepted checks if the given key is in the list of received keys or a configured SendKey.
99+
// if not, it returns an error with the key truncated to 8 characters for logging.
100100
func (a *AccessKeyConfig) IsAccepted(key string) error {
101101
if a.AcceptOnlyListedKeys {
102-
if slices.Contains(a.ReceiveKeys, key) {
102+
if (len(a.SendKey) > 0 && key == a.SendKey) || slices.Contains(a.ReceiveKeys, key) {
103103
return nil
104104
}
105105

@@ -150,7 +150,10 @@ func (a *AccessKeyConfig) GetReplaceKey(apiKey string) (string, error) {
150150
}
151151
}
152152
}
153-
apiKey = overwriteWith
153+
154+
if overwriteWith != "" {
155+
apiKey = overwriteWith
156+
}
154157
}
155158

156159
if apiKey == "" {

config/file_config_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ func TestAccessKeyConfig_GetReplaceKey(t *testing.T) {
6767
}
6868
got, err := a.GetReplaceKey(tt.apiKey)
6969
if (err != nil) != tt.wantErr {
70-
t.Errorf("AccessKeyConfig.CheckAndMaybeReplaceKey() error = %v, wantErr %v", err, tt.wantErr)
70+
t.Errorf("AccessKeyConfig.GetReplaceKey() error = %v, wantErr %v", err, tt.wantErr)
7171
return
7272
}
7373
if got != tt.want {
74-
t.Errorf("AccessKeyConfig.CheckAndMaybeReplaceKey() = '%v', want '%v'", got, tt.want)
74+
t.Errorf("AccessKeyConfig.GetReplaceKey() = '%v', want '%v'", got, tt.want)
7575
}
7676
})
7777
}
@@ -93,7 +93,9 @@ func TestAccessKeyConfig_IsAccepted(t *testing.T) {
9393
{"no keys", fields{}, "key1", nil},
9494
{"known key", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true}, "key1", nil},
9595
{"unknown key", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true}, "key2", errors.New("api key key2... not found in list of authorized keys")},
96-
{"accept missing key", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true}, "", errors.New("api key ... not found in list of authorized keys")},
96+
{"reject missing key with sendkey configured", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true, SendKey: "key2"}, "", errors.New("api key ... not found in list of authorized keys")},
97+
{"reject missing key without sendkey configured", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true}, "", errors.New("api key ... not found in list of authorized keys")},
98+
{"accept sendkey", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true, SendKey: "key2"}, "key2", nil},
9799
}
98100
for _, tt := range tests {
99101
t.Run(tt.name, func(t *testing.T) {

config/mock.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ type MockConfig struct {
6262
CfgHash string
6363
RulesHash string
6464

65+
// Samplers allows per-dataset/environment sampler configuration
66+
// Map key is the dataset name or environment, value is the sampler choice
67+
Samplers map[string]*V2SamplerChoice
68+
6569
Mux sync.RWMutex
6670
}
6771

@@ -252,11 +256,27 @@ func (m *MockConfig) GetOTelTracingConfig() OTelTracingConfig {
252256
return m.GetOTelTracingConfigVal
253257
}
254258

255-
// TODO: allow per-dataset mock values
259+
// GetSamplerConfigForDestName returns the sampler config for the given dataset/environment.
260+
// If Samplers map is populated, it will look up the dataset-specific config.
261+
// Falls back to GetSamplerTypeVal if Samplers is not set (backwards compatible).
256262
func (m *MockConfig) GetSamplerConfigForDestName(dataset string) (interface{}, string) {
257263
m.Mux.RLock()
258264
defer m.Mux.RUnlock()
259265

266+
// If Samplers map is configured, use it (mimics fileConfig behavior)
267+
if m.Samplers != nil {
268+
// Try to find the specific dataset/environment
269+
if sampler, ok := m.Samplers[dataset]; ok {
270+
return sampler.Sampler()
271+
}
272+
273+
// Fall back to __default__
274+
if sampler, ok := m.Samplers["__default__"]; ok {
275+
return sampler.Sampler()
276+
}
277+
}
278+
279+
// Fall back to legacy behavior for backwards compatibility
260280
return m.GetSamplerTypeVal, m.GetSamplerTypeName
261281
}
262282

@@ -468,6 +488,31 @@ func (f *MockConfig) DetermineSamplerKey(apiKey, env, dataset string) string {
468488
}
469489

470490
func (f *MockConfig) GetSamplingKeyFieldsForDestName(samplerKey string) []string {
491+
f.Mux.RLock()
492+
defer f.Mux.RUnlock()
493+
494+
// If Samplers map is configured, use it to get the correct sampler
495+
if f.Samplers != nil {
496+
// Try specific dataset/environment first
497+
if sampler, ok := f.Samplers[samplerKey]; ok {
498+
if cfg, _ := sampler.Sampler(); cfg != nil {
499+
if fielder, ok := cfg.(GetSamplingFielder); ok {
500+
return fielder.GetSamplingFields()
501+
}
502+
}
503+
}
504+
505+
// Fall back to __default__
506+
if sampler, ok := f.Samplers["__default__"]; ok {
507+
if cfg, _ := sampler.Sampler(); cfg != nil {
508+
if fielder, ok := cfg.(GetSamplingFielder); ok {
509+
return fielder.GetSamplingFields()
510+
}
511+
}
512+
}
513+
}
514+
515+
// Fall back to legacy behavior
471516
switch sampler := f.GetSamplerTypeVal.(type) {
472517
case *DeterministicSamplerConfig:
473518
return sampler.GetSamplingFields()
@@ -482,5 +527,4 @@ func (f *MockConfig) GetSamplingKeyFieldsForDestName(samplerKey string) []string
482527
default:
483528
return nil
484529
}
485-
486530
}

route/middleware.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ func (r *Router) apiKeyProcessor(next http.Handler) http.Handler {
5050
return
5151
}
5252

53+
replacement, err := keycfg.GetReplaceKey(apiKey)
54+
if err != nil {
55+
r.handlerReturnWithError(w, ErrAuthInvalid, err)
56+
return
57+
}
58+
if replacement != apiKey {
59+
req.Header.Set(types.APIKeyHeader, replacement)
60+
}
61+
5362
next.ServeHTTP(w, req)
5463
})
5564
}

route/otlp_logs.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ func (r *Router) postOTLPLogs(w http.ResponseWriter, req *http.Request) {
1818

1919
ri := huskyotlp.GetRequestInfoFromHttpHeaders(req.Header)
2020
apicfg := r.Config.GetAccessKeyConfig()
21+
if err := apicfg.IsAccepted(ri.ApiKey); err != nil {
22+
r.handleOTLPFailureResponse(w, req, huskyotlp.OTLPError{Message: err.Error(), HTTPStatusCode: http.StatusUnauthorized})
23+
return
24+
}
2125
keyToUse, _ := apicfg.GetReplaceKey(ri.ApiKey)
2226

2327
if err := ri.ValidateLogsHeaders(); err != nil {

route/otlp_logs_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ func TestLogsOTLPHandler(t *testing.T) {
128128
for _, tC := range testCases {
129129
t.Run(tC.name, func(t *testing.T) {
130130
muxxer := mux.NewRouter()
131-
muxxer.Use(router.apiKeyProcessor)
132131
router.AddOTLPMuxxer(muxxer)
133132
server := httptest.NewServer(muxxer)
134133
defer server.Close()

route/otlp_trace.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ func (r *Router) postOTLPTrace(w http.ResponseWriter, req *http.Request) {
2626

2727
ri := huskyotlp.GetRequestInfoFromHttpHeaders(req.Header)
2828
apicfg := r.Config.GetAccessKeyConfig()
29+
if err := apicfg.IsAccepted(ri.ApiKey); err != nil {
30+
r.handleOTLPFailureResponse(w, req, huskyotlp.OTLPError{Message: err.Error(), HTTPStatusCode: http.StatusUnauthorized})
31+
return
32+
}
2933
keyToUse, _ := apicfg.GetReplaceKey(ri.ApiKey)
3034

3135
if err := ri.ValidateTracesHeaders(); err != nil {
@@ -49,11 +53,11 @@ func (r *Router) postOTLPTrace(w http.ResponseWriter, req *http.Request) {
4953
var err error
5054
switch ri.ContentType {
5155
case "application/json":
52-
r.Metrics.Increment(r.metricsNames.routerOtlpTraceHttpJson)
53-
err = r.processOTLPRequestWithMsgp(ctx, w, req, ri, keyToUse)
56+
r.Metrics.Increment(r.metricsNames.routerOtlpTraceHttpJson)
57+
err = r.processOTLPRequestWithMsgp(ctx, w, req, ri, keyToUse)
5458
case "application/x-protobuf", "application/protobuf":
5559
r.Metrics.Increment(r.metricsNames.routerOtlpTraceHttpProto)
56-
err = r.processOTLPRequestWithMsgp(ctx, w, req, ri, keyToUse)
60+
err = r.processOTLPRequestWithMsgp(ctx, w, req, ri, keyToUse)
5761
default:
5862
err = errors.New("unsupported content type")
5963
}

0 commit comments

Comments
 (0)