Skip to content

Commit 397f228

Browse files
nprajileshalrex
and
alrex
authored
Converted TextMap propagator getter to a class and added keys method (open-telemetry#1196)
Co-authored-by: alrex <aboten@lightstep.com>
1 parent a3a75e3 commit 397f228

File tree

26 files changed

+198
-162
lines changed

26 files changed

+198
-162
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
("py:class", "ObjectProxy"),
103103
# TODO: Understand why sphinx is not able to find this local class
104104
("py:class", "opentelemetry.trace.propagation.textmap.TextMapPropagator",),
105+
("py:class", "opentelemetry.trace.propagation.textmap.DictGetter",),
105106
(
106107
"any",
107108
"opentelemetry.trace.propagation.textmap.TextMapPropagator.extract",

exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,23 @@ class DatadogFormat(TextMapPropagator):
3939

4040
def extract(
4141
self,
42-
get_from_carrier: Getter[TextMapPropagatorT],
42+
getter: Getter[TextMapPropagatorT],
4343
carrier: TextMapPropagatorT,
4444
context: typing.Optional[Context] = None,
4545
) -> Context:
4646
trace_id = extract_first_element(
47-
get_from_carrier(carrier, self.TRACE_ID_KEY)
47+
getter.get(carrier, self.TRACE_ID_KEY)
4848
)
4949

5050
span_id = extract_first_element(
51-
get_from_carrier(carrier, self.PARENT_ID_KEY)
51+
getter.get(carrier, self.PARENT_ID_KEY)
5252
)
5353

5454
sampled = extract_first_element(
55-
get_from_carrier(carrier, self.SAMPLING_PRIORITY_KEY)
55+
getter.get(carrier, self.SAMPLING_PRIORITY_KEY)
5656
)
5757

58-
origin = extract_first_element(
59-
get_from_carrier(carrier, self.ORIGIN_KEY)
60-
)
58+
origin = extract_first_element(getter.get(carrier, self.ORIGIN_KEY))
6159

6260
trace_flags = trace.TraceFlags()
6361
if sampled and int(sampled) in (

exporter/opentelemetry-exporter-datadog/tests/test_datadog_format.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,11 @@
1818
from opentelemetry.exporter.datadog import constants, propagator
1919
from opentelemetry.sdk import trace
2020
from opentelemetry.trace import get_current_span, set_span_in_context
21+
from opentelemetry.trace.propagation.textmap import DictGetter
2122

2223
FORMAT = propagator.DatadogFormat()
2324

24-
25-
def get_as_list(dict_object, key):
26-
value = dict_object.get(key)
27-
return [value] if value is not None else []
25+
carrier_getter = DictGetter()
2826

2927

3028
class TestDatadogFormat(unittest.TestCase):
@@ -45,7 +43,7 @@ def test_malformed_headers(self):
4543
malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x"
4644
context = get_current_span(
4745
FORMAT.extract(
48-
get_as_list,
46+
carrier_getter,
4947
{
5048
malformed_trace_id_key: self.serialized_trace_id,
5149
malformed_parent_id_key: self.serialized_parent_id,
@@ -63,7 +61,7 @@ def test_missing_trace_id(self):
6361
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
6462
}
6563

66-
ctx = FORMAT.extract(get_as_list, carrier)
64+
ctx = FORMAT.extract(carrier_getter, carrier)
6765
span_context = get_current_span(ctx).get_span_context()
6866
self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)
6967

@@ -73,15 +71,15 @@ def test_missing_parent_id(self):
7371
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
7472
}
7573

76-
ctx = FORMAT.extract(get_as_list, carrier)
74+
ctx = FORMAT.extract(carrier_getter, carrier)
7775
span_context = get_current_span(ctx).get_span_context()
7876
self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID)
7977

8078
def test_context_propagation(self):
8179
"""Test the propagation of Datadog headers."""
8280
parent_span_context = get_current_span(
8381
FORMAT.extract(
84-
get_as_list,
82+
carrier_getter,
8583
{
8684
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
8785
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
@@ -138,7 +136,7 @@ def test_sampling_priority_auto_reject(self):
138136
"""Test sampling priority rejected."""
139137
parent_span_context = get_current_span(
140138
FORMAT.extract(
141-
get_as_list,
139+
carrier_getter,
142140
{
143141
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
144142
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,

instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,31 @@
2929
from opentelemetry import context, propagators, trace
3030
from opentelemetry.instrumentation.asgi.version import __version__ # noqa
3131
from opentelemetry.instrumentation.utils import http_status_to_status_code
32+
from opentelemetry.trace.propagation.textmap import DictGetter
3233
from opentelemetry.trace.status import Status, StatusCode
3334

3435

35-
def get_header_from_scope(scope: dict, header_name: str) -> typing.List[str]:
36-
"""Retrieve a HTTP header value from the ASGI scope.
36+
class CarrierGetter(DictGetter):
37+
def get(self, carrier: dict, key: str) -> typing.List[str]:
38+
"""Getter implementation to retrieve a HTTP header value from the ASGI
39+
scope.
3740
38-
Returns:
39-
A list with a single string with the header value if it exists, else an empty list.
40-
"""
41-
headers = scope.get("headers")
42-
return [
43-
value.decode("utf8")
44-
for (key, value) in headers
45-
if key.decode("utf8") == header_name
46-
]
41+
Args:
42+
carrier: ASGI scope object
43+
key: header name in scope
44+
Returns:
45+
A list with a single string with the header value if it exists,
46+
else an empty list.
47+
"""
48+
headers = carrier.get("headers")
49+
return [
50+
_value.decode("utf8")
51+
for (_key, _value) in headers
52+
if _key.decode("utf8") == key
53+
]
54+
55+
56+
carrier_getter = CarrierGetter()
4757

4858

4959
def collect_request_attributes(scope):
@@ -72,10 +82,10 @@ def collect_request_attributes(scope):
7282
http_method = scope.get("method")
7383
if http_method:
7484
result["http.method"] = http_method
75-
http_host_value = ",".join(get_header_from_scope(scope, "host"))
85+
http_host_value = ",".join(carrier_getter.get(scope, "host"))
7686
if http_host_value:
7787
result["http.server_name"] = http_host_value
78-
http_user_agent = get_header_from_scope(scope, "user-agent")
88+
http_user_agent = carrier_getter.get(scope, "user-agent")
7989
if len(http_user_agent) > 0:
8090
result["http.user_agent"] = http_user_agent[0]
8191

@@ -154,9 +164,7 @@ async def __call__(self, scope, receive, send):
154164
if scope["type"] not in ("http", "websocket"):
155165
return await self.app(scope, receive, send)
156166

157-
token = context.attach(
158-
propagators.extract(get_header_from_scope, scope)
159-
)
167+
token = context.attach(propagators.extract(carrier_getter, scope))
160168
span_name, additional_attributes = self.span_details_callback(scope)
161169

162170
try:

instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def add(x, y):
6767
from opentelemetry.instrumentation.celery import utils
6868
from opentelemetry.instrumentation.celery.version import __version__
6969
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
70-
from opentelemetry.trace.propagation import get_current_span
70+
from opentelemetry.trace.propagation.textmap import DictGetter
7171
from opentelemetry.trace.status import Status, StatusCode
7272

7373
logger = logging.getLogger(__name__)
@@ -84,6 +84,20 @@ def add(x, y):
8484
_MESSAGE_ID_ATTRIBUTE_NAME = "messaging.message_id"
8585

8686

87+
class CarrierGetter(DictGetter):
88+
def get(self, carrier, key):
89+
value = getattr(carrier, key, [])
90+
if isinstance(value, str) or not isinstance(value, Iterable):
91+
value = (value,)
92+
return value
93+
94+
def keys(self, carrier):
95+
return []
96+
97+
98+
carrier_getter = CarrierGetter()
99+
100+
87101
class CeleryInstrumentor(BaseInstrumentor):
88102
def _instrument(self, **kwargs):
89103
tracer_provider = kwargs.get("tracer_provider")
@@ -118,7 +132,7 @@ def _trace_prerun(self, *args, **kwargs):
118132
return
119133

120134
request = task.request
121-
tracectx = propagators.extract(carrier_extractor, request) or None
135+
tracectx = propagators.extract(carrier_getter, request) or None
122136

123137
logger.debug("prerun signal start task_id=%s", task_id)
124138

@@ -246,10 +260,3 @@ def _trace_retry(*args, **kwargs):
246260
# Use `str(reason)` instead of `reason.message` in case we get
247261
# something that isn't an `Exception`
248262
span.set_attribute(_TASK_RETRY_REASON_KEY, str(reason))
249-
250-
251-
def carrier_extractor(carrier, key):
252-
value = getattr(carrier, key, [])
253-
if isinstance(value, str) or not isinstance(value, Iterable):
254-
value = (value,)
255-
return value

instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from opentelemetry.instrumentation.utils import extract_attributes_from_object
2424
from opentelemetry.instrumentation.wsgi import (
2525
add_response_attributes,
26+
carrier_getter,
2627
collect_request_attributes,
27-
get_header_from_environ,
2828
)
2929
from opentelemetry.propagators import extract
3030
from opentelemetry.trace import SpanKind, get_tracer
@@ -125,7 +125,7 @@ def process_request(self, request):
125125

126126
environ = request.META
127127

128-
token = attach(extract(get_header_from_environ, environ))
128+
token = attach(extract(carrier_getter, environ))
129129

130130
tracer = get_tracer(__name__, __version__)
131131

instrumentation/opentelemetry-instrumentation-falcon/src/opentelemetry/instrumentation/falcon/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def __call__(self, env, start_response):
115115
start_time = time_ns()
116116

117117
token = context.attach(
118-
propagators.extract(otel_wsgi.get_header_from_environ, env)
118+
propagators.extract(otel_wsgi.carrier_getter, env)
119119
)
120120
span = self._tracer.start_span(
121121
otel_wsgi.get_default_span_name(env),

instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def _before_request():
118118
if span_name is None:
119119
span_name = otel_wsgi.get_default_span_name(environ)
120120
token = context.attach(
121-
propagators.extract(otel_wsgi.get_header_from_environ, environ)
121+
propagators.extract(otel_wsgi.carrier_getter, environ)
122122
)
123123

124124
tracer = trace.get_tracer(__name__, __version__)

instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323

2424
import logging
2525
from contextlib import contextmanager
26-
from typing import List
2726

2827
import grpc
2928

3029
from opentelemetry import propagators, trace
3130
from opentelemetry.context import attach, detach
31+
from opentelemetry.trace.propagation.textmap import DictGetter
3232
from opentelemetry.trace.status import Status, StatusCode
3333

3434
logger = logging.getLogger(__name__)
@@ -163,18 +163,14 @@ class OpenTelemetryServerInterceptor(grpc.ServerInterceptor):
163163

164164
def __init__(self, tracer):
165165
self._tracer = tracer
166+
self._carrier_getter = DictGetter()
166167

167168
@contextmanager
168169
def _set_remote_context(self, servicer_context):
169170
metadata = servicer_context.invocation_metadata()
170171
if metadata:
171172
md_dict = {md.key: md.value for md in metadata}
172-
173-
def get_from_grpc_metadata(metadata, key) -> List[str]:
174-
return [md_dict[key]] if key in md_dict else []
175-
176-
# Update the context with the traceparent from the RPC metadata.
177-
ctx = propagators.extract(get_from_grpc_metadata, metadata)
173+
ctx = propagators.extract(self._carrier_getter, md_dict)
178174
token = attach(ctx)
179175
try:
180176
yield

instrumentation/opentelemetry-instrumentation-opentracing-shim/src/opentelemetry/instrumentation/opentracing_shim/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
get_current_span,
113113
set_span_in_context,
114114
)
115+
from opentelemetry.trace.propagation.textmap import DictGetter
115116
from opentelemetry.util.types import Attributes
116117

117118
ValueT = TypeVar("ValueT", int, float, bool, str)
@@ -527,6 +528,7 @@ def __init__(self, tracer: OtelTracer):
527528
Format.TEXT_MAP,
528529
Format.HTTP_HEADERS,
529530
)
531+
self._carrier_getter = DictGetter()
530532

531533
def unwrap(self):
532534
"""Returns the :class:`opentelemetry.trace.Tracer` object that is
@@ -710,12 +712,8 @@ def extract(self, format: object, carrier: object):
710712
if format not in self._supported_formats:
711713
raise UnsupportedFormatException
712714

713-
def get_as_list(dict_object, key):
714-
value = dict_object.get(key)
715-
return [value] if value is not None else []
716-
717715
propagator = propagators.get_global_textmap()
718-
ctx = propagator.extract(get_as_list, carrier)
716+
ctx = propagator.extract(self._carrier_getter, carrier)
719717
span = get_current_span(ctx)
720718
if span is not None:
721719
otel_context = span.get_span_context()

instrumentation/opentelemetry-instrumentation-pyramid/src/opentelemetry/instrumentation/pyramid/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _before_traversal(event):
7070
start_time = environ.get(_ENVIRON_STARTTIME_KEY)
7171

7272
token = context.attach(
73-
propagators.extract(otel_wsgi.get_header_from_environ, environ)
73+
propagators.extract(otel_wsgi.carrier_getter, environ)
7474
)
7575
tracer = trace.get_tracer(__name__, __version__)
7676

instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def get(self):
5454
http_status_to_status_code,
5555
unwrap,
5656
)
57+
from opentelemetry.trace.propagation.textmap import DictGetter
5758
from opentelemetry.trace.status import Status
5859
from opentelemetry.util import ExcludeList, time_ns
5960

@@ -84,6 +85,8 @@ def get_traced_request_attrs():
8485
_excluded_urls = get_excluded_urls()
8586
_traced_attrs = get_traced_request_attrs()
8687

88+
carrier_getter = DictGetter()
89+
8790

8891
class TornadoInstrumentor(BaseInstrumentor):
8992
patched_handlers = []
@@ -185,13 +188,6 @@ def _log_exception(tracer, func, handler, args, kwargs):
185188
return func(*args, **kwargs)
186189

187190

188-
def _get_header_from_request_headers(
189-
headers: dict, header_name: str
190-
) -> typing.List[str]:
191-
header = headers.get(header_name)
192-
return [header] if header else []
193-
194-
195191
def _get_attributes_from_request(request):
196192
attrs = {
197193
"component": "tornado",
@@ -218,9 +214,7 @@ def _get_operation_name(handler, request):
218214

219215
def _start_span(tracer, handler, start_time) -> _TraceContext:
220216
token = context.attach(
221-
propagators.extract(
222-
_get_header_from_request_headers, handler.request.headers,
223-
)
217+
propagators.extract(carrier_getter, handler.request.headers,)
224218
)
225219

226220
span = tracer.start_span(

0 commit comments

Comments
 (0)