diff --git a/exporter/signalfxexporter/dpclient.go b/exporter/signalfxexporter/dpclient.go index cdb7338c1bf8..90ea212e7fae 100644 --- a/exporter/signalfxexporter/dpclient.go +++ b/exporter/signalfxexporter/dpclient.go @@ -15,6 +15,7 @@ import ( "sync" sfxpb "github.com/signalfx/com_signalfx_metrics_protobuf/model" + "go.opentelemetry.io/collector/client" "go.opentelemetry.io/collector/consumer/consumererror" "go.opentelemetry.io/collector/pdata/pmetric" "go.opentelemetry.io/collector/pdata/pmetric/pmetricotlp" @@ -88,7 +89,7 @@ func (s *sfxDPClient) pushMetricsData( } // All metrics in the pmetric.Metrics will have the same access token because of the BatchPerResourceMetrics. - metricToken := s.retrieveAccessToken(rms.At(0)) + metricToken := s.retrieveAccessToken(ctx, rms.At(0)) // export SFx format sfxDataPoints := s.converter.MetricsToSignalFxV2(md) @@ -194,12 +195,18 @@ func (s *sfxDPClient) encodeBody(dps []*sfxpb.DataPoint) (bodyReader io.Reader, return s.getReader(body) } -func (s *sfxDPClient) retrieveAccessToken(md pmetric.ResourceMetrics) string { +func (s *sfxDPClient) retrieveAccessToken(ctx context.Context, md pmetric.ResourceMetrics) string { if !s.accessTokenPassthrough { // Nothing to do if token is pass through not configured or resource is nil. return "" } + cl := client.FromContext(ctx) + ss := cl.Metadata.Get(splunk.SFxAccessTokenHeader) + if len(ss) > 0 { + return ss[0] + } + attrs := md.Resource().Attributes() if accessToken, ok := attrs.Get(splunk.SFxAccessTokenLabel); ok { return accessToken.Str() diff --git a/exporter/signalfxexporter/eventclient.go b/exporter/signalfxexporter/eventclient.go index 8bb12082cfa1..c6471e602d5b 100644 --- a/exporter/signalfxexporter/eventclient.go +++ b/exporter/signalfxexporter/eventclient.go @@ -11,6 +11,7 @@ import ( "strings" sfxpb "github.com/signalfx/com_signalfx_metrics_protobuf/model" + "go.opentelemetry.io/collector/client" "go.opentelemetry.io/collector/consumer/consumererror" "go.opentelemetry.io/collector/pdata/pcommon" "go.opentelemetry.io/collector/pdata/plog" @@ -33,7 +34,7 @@ func (s *sfxEventClient) pushLogsData(ctx context.Context, ld plog.Logs) (int, e return 0, nil } - accessToken := s.retrieveAccessToken(rls.At(0)) + accessToken := s.retrieveAccessToken(ctx, rls.At(0)) var sfxEvents []*sfxpb.Event numDroppedLogRecords := 0 @@ -104,7 +105,18 @@ func (s *sfxEventClient) encodeBody(events []*sfxpb.Event) (bodyReader io.Reader return s.getReader(body) } -func (s *sfxEventClient) retrieveAccessToken(rl plog.ResourceLogs) string { +func (s *sfxEventClient) retrieveAccessToken(ctx context.Context, rl plog.ResourceLogs) string { + if !s.accessTokenPassthrough { + // Nothing to do if token is pass through not configured or resource is nil. + return "" + } + + cl := client.FromContext(ctx) + ss := cl.Metadata.Get(splunk.SFxAccessTokenHeader) + if len(ss) > 0 { + return ss[0] + } + attrs := rl.Resource().Attributes() if accessToken, ok := attrs.Get(splunk.SFxAccessTokenLabel); ok && accessToken.Type() == pcommon.ValueTypeStr { return accessToken.Str() diff --git a/exporter/signalfxexporter/exporter_test.go b/exporter/signalfxexporter/exporter_test.go index 5da4d7109eb8..8dd6d4cccf1e 100644 --- a/exporter/signalfxexporter/exporter_test.go +++ b/exporter/signalfxexporter/exporter_test.go @@ -23,6 +23,7 @@ import ( sfxpb "github.com/signalfx/com_signalfx_metrics_protobuf/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/collector/client" "go.opentelemetry.io/collector/component/componenttest" "go.opentelemetry.io/collector/config/confighttp" "go.opentelemetry.io/collector/config/configopaque" @@ -567,6 +568,138 @@ func TestConsumeMetricsWithAccessTokenPassthrough(t *testing.T) { } } +func TestConsumeMetricsAccessTokenPassthroughPriorityToContext(t *testing.T) { + fromHeaders := "AccessTokenFromClientHeaders" + fromLabels := []string{"AccessTokenFromLabel0", "AccessTokenFromLabel1"} + fromContext := "AccessTokenFromContext" + + validMetricsWithToken := func(includeToken bool, token string, histogram bool) pmetric.Metrics { + out := pmetric.NewMetrics() + rm := out.ResourceMetrics().AppendEmpty() + + if includeToken { + rm.Resource().Attributes().PutStr("com.splunk.signalfx.access_token", token) + } + + ilm := rm.ScopeMetrics().AppendEmpty() + m := ilm.Metrics().AppendEmpty() + + if histogram { + buildHistogram(m, "test_histogram", pcommon.Timestamp(100000000), 1) + } else { + m.SetName("test_gauge") + + dp := m.SetEmptyGauge().DataPoints().AppendEmpty() + dp.Attributes().PutStr("k0", "v0") + dp.Attributes().PutStr("k1", "v1") + dp.SetDoubleValue(123) + } + + return out + } + + tests := []struct { + name string + accessTokenPassthrough bool + metrics pmetric.Metrics + additionalHeaders map[string]string + pushedTokens []string + sendOTLPHistograms bool + inContext bool + }{ + { + name: "passthrough access token and included in md", + accessTokenPassthrough: true, + inContext: true, + metrics: validMetricsWithToken(true, fromLabels[0], false), + pushedTokens: []string{fromContext}, + }, + { + name: "passthrough access token and not included in md", + accessTokenPassthrough: true, + inContext: true, + metrics: validMetricsWithToken(false, fromLabels[0], false), + pushedTokens: []string{fromContext}, + sendOTLPHistograms: false, + }, + { + name: "passthrough access token and included in md", + accessTokenPassthrough: true, + inContext: false, + metrics: validMetricsWithToken(true, fromLabels[0], false), + pushedTokens: []string{fromLabels[0]}, + }, + { + name: "passthrough access token and not included in md", + accessTokenPassthrough: true, + inContext: false, + metrics: validMetricsWithToken(false, fromLabels[0], false), + pushedTokens: []string{fromHeaders}, + sendOTLPHistograms: false, + }, + } + for _, tt := range tests { + receivedTokens := struct { + sync.Mutex + tokens []string + }{} + receivedTokens.tokens = []string{} + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, tt.name, r.Header.Get("test_header_")) + receivedTokens.Lock() + + token := r.Header.Get("x-sf-token") + receivedTokens.tokens = append(receivedTokens.tokens, token) + + receivedTokens.Unlock() + w.WriteHeader(http.StatusAccepted) + })) + defer server.Close() + + factory := NewFactory() + cfg := factory.CreateDefaultConfig().(*Config) + cfg.IngestURL = server.URL + cfg.APIURL = server.URL + cfg.ClientConfig.Headers = make(map[string]configopaque.String) + for k, v := range tt.additionalHeaders { + cfg.ClientConfig.Headers[k] = configopaque.String(v) + } + cfg.ClientConfig.Headers["test_header_"] = configopaque.String(tt.name) + cfg.AccessToken = configopaque.String(fromHeaders) + cfg.AccessTokenPassthrough = tt.accessTokenPassthrough + cfg.SendOTLPHistograms = tt.sendOTLPHistograms + sfxExp, err := NewFactory().CreateMetrics(context.Background(), exportertest.NewNopSettings(), cfg) + require.NoError(t, err) + ctx := context.Background() + if tt.inContext { + ctx = client.NewContext( + ctx, + client.Info{Metadata: client.NewMetadata( + map[string][]string{splunk.SFxAccessTokenHeader: {fromContext}}, + )}, + ) + } + require.NoError(t, sfxExp.Start(ctx, componenttest.NewNopHost())) + defer func() { + require.NoError(t, sfxExp.Shutdown(context.Background())) + }() + + err = sfxExp.ConsumeMetrics(ctx, tt.metrics) + + assert.NoError(t, err) + require.Eventually(t, func() bool { + receivedTokens.Lock() + defer receivedTokens.Unlock() + return len(tt.pushedTokens) == len(receivedTokens.tokens) + }, 1*time.Second, 10*time.Millisecond) + sort.Strings(tt.pushedTokens) + sort.Strings(receivedTokens.tokens) + assert.Equal(t, tt.pushedTokens, receivedTokens.tokens) + }) + } +} + func TestNewEventExporter(t *testing.T) { got, err := newEventExporter(nil, exportertest.NewNopSettings()) assert.EqualError(t, err, "nil config") @@ -812,6 +945,102 @@ func TestConsumeLogsDataWithAccessTokenPassthrough(t *testing.T) { } } +func TestConsumeLogsAccessTokenPassthrough(t *testing.T) { + fromHeaders := "AccessTokenFromClientHeaders" + fromLabels := "AccessTokenFromLabel" + fromContext := "AccessTokenFromContext" + + newLogData := func(includeToken bool) plog.Logs { + out := makeSampleResourceLogs() + makeSampleResourceLogs().ResourceLogs().At(0).CopyTo(out.ResourceLogs().AppendEmpty()) + + if includeToken { + out.ResourceLogs().At(0).Resource().Attributes().PutStr("com.splunk.signalfx.access_token", fromLabels) + out.ResourceLogs().At(1).Resource().Attributes().PutStr("com.splunk.signalfx.access_token", fromLabels) + } + return out + } + + tests := []struct { + name string + accessTokenPassthrough bool + includedInLogData bool + inContext bool + expectedToken string + }{ + { + name: "passthrough access token and not included in request context", + inContext: true, + accessTokenPassthrough: true, + includedInLogData: true, + expectedToken: fromContext, + }, + { + name: "passthrough access token and included in logs", + inContext: false, + accessTokenPassthrough: true, + includedInLogData: true, + expectedToken: fromLabels, + }, + { + name: "passthrough access token and not included in logs", + inContext: false, + accessTokenPassthrough: false, + includedInLogData: false, + expectedToken: fromHeaders, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + receivedTokens := struct { + sync.Mutex + tokens []string + }{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, tt.name, r.Header.Get("test_header_")) + receivedTokens.Lock() + receivedTokens.tokens = append(receivedTokens.tokens, r.Header.Get("x-sf-token")) + receivedTokens.Unlock() + w.WriteHeader(http.StatusAccepted) + })) + defer server.Close() + + factory := NewFactory() + cfg := factory.CreateDefaultConfig().(*Config) + cfg.IngestURL = server.URL + cfg.APIURL = server.URL + cfg.Headers = make(map[string]configopaque.String) + cfg.Headers["test_header_"] = configopaque.String(tt.name) + cfg.AccessToken = configopaque.String(fromHeaders) + cfg.AccessTokenPassthrough = tt.accessTokenPassthrough + sfxExp, err := NewFactory().CreateLogs(context.Background(), exportertest.NewNopSettings(), cfg) + require.NoError(t, err) + require.NoError(t, sfxExp.Start(context.Background(), componenttest.NewNopHost())) + defer func() { + require.NoError(t, sfxExp.Shutdown(context.Background())) + }() + + ctx := context.Background() + if tt.inContext { + ctx = client.NewContext( + ctx, + client.Info{Metadata: client.NewMetadata( + map[string][]string{splunk.SFxAccessTokenHeader: {"AccessTokenFromContext"}}, + )}, + ) + } + assert.NoError(t, sfxExp.ConsumeLogs(ctx, newLogData(tt.includedInLogData))) + + require.Eventually(t, func() bool { + receivedTokens.Lock() + defer receivedTokens.Unlock() + return len(receivedTokens.tokens) == 1 + }, 1*time.Second, 10*time.Millisecond) + assert.Equal(t, tt.expectedToken, receivedTokens.tokens[0]) + }) + } +} + func generateLargeDPBatch() pmetric.Metrics { md := pmetric.NewMetrics() md.ResourceMetrics().EnsureCapacity(6500)