Skip to content

Commit 1d39f7f

Browse files
authored
Fix TraceState to adhere to specs (open-telemetry#1502)
1 parent c750109 commit 1d39f7f

File tree

10 files changed

+376
-96
lines changed

10 files changed

+376
-96
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ env:
1010
# Otherwise, set variable to the commit of your branch on
1111
# opentelemetry-python-contrib which is compatible with these Core repo
1212
# changes.
13-
CONTRIB_REPO_SHA: 32cac7a9ff6c831aa0e9514bb38c430fce819141
13+
CONTRIB_REPO_SHA: 1e319dbaf21df7573f15f35773b8272579dd1030
1414

1515
jobs:
1616
build:

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6666
([#1535](https://github.com/open-telemetry/opentelemetry-python/pull/1535))
6767
- `opentelemetry-sdk` Remove rate property setter from TraceIdRatioBasedSampler
6868
([#1536](https://github.com/open-telemetry/opentelemetry-python/pull/1536))
69+
- Fix TraceState to adhere to specs
70+
([#1502](https://github.com/open-telemetry/opentelemetry-python/pull/1502))
6971

7072
### Removed
7173
- `opentelemetry-api` Remove ThreadLocalRuntimeContext since python3.4 is not supported.

exporter/opentelemetry-exporter-opencensus/tests/test_otcollector_trace_exporter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_translate_to_collector(self):
9494
span_id,
9595
is_remote=False,
9696
trace_flags=TraceFlags(TraceFlags.SAMPLED),
97-
trace_state=trace_api.TraceState({"testKey": "testValue"}),
97+
trace_state=trace_api.TraceState([("testkey", "testvalue")]),
9898
)
9999
parent_span_context = trace_api.SpanContext(
100100
trace_id, parent_id, is_remote=False
@@ -200,9 +200,9 @@ def test_translate_to_collector(self):
200200
)
201201
self.assertEqual(output_spans[0].status.message, "test description")
202202
self.assertEqual(len(output_spans[0].tracestate.entries), 1)
203-
self.assertEqual(output_spans[0].tracestate.entries[0].key, "testKey")
203+
self.assertEqual(output_spans[0].tracestate.entries[0].key, "testkey")
204204
self.assertEqual(
205-
output_spans[0].tracestate.entries[0].value, "testValue"
205+
output_spans[0].tracestate.entries[0].value, "testvalue"
206206
)
207207

208208
self.assertEqual(

opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py

Lines changed: 3 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,7 @@
1818
import opentelemetry.trace as trace
1919
from opentelemetry.context.context import Context
2020
from opentelemetry.trace.propagation import textmap
21-
22-
# Keys and values are strings of up to 256 printable US-ASCII characters.
23-
# Implementations should conform to the `W3C Trace Context - Tracestate`_
24-
# spec, which describes additional restrictions on valid field values.
25-
#
26-
# .. _W3C Trace Context - Tracestate:
27-
# https://www.w3.org/TR/trace-context/#tracestate-field
28-
29-
_KEY_WITHOUT_VENDOR_FORMAT = r"[a-z][_0-9a-z\-\*\/]{0,255}"
30-
_KEY_WITH_VENDOR_FORMAT = (
31-
r"[a-z0-9][_0-9a-z\-\*\/]{0,240}@[a-z][_0-9a-z\-\*\/]{0,13}"
32-
)
33-
34-
_KEY_FORMAT = _KEY_WITHOUT_VENDOR_FORMAT + "|" + _KEY_WITH_VENDOR_FORMAT
35-
_VALUE_FORMAT = (
36-
r"[\x20-\x2b\x2d-\x3c\x3e-\x7e]{0,255}[\x21-\x2b\x2d-\x3c\x3e-\x7e]"
37-
)
38-
39-
_DELIMITER_FORMAT = "[ \t]*,[ \t]*"
40-
_MEMBER_FORMAT = "({})(=)({})[ \t]*".format(_KEY_FORMAT, _VALUE_FORMAT)
41-
42-
_DELIMITER_FORMAT_RE = re.compile(_DELIMITER_FORMAT)
43-
_MEMBER_FORMAT_RE = re.compile(_MEMBER_FORMAT)
44-
45-
_TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS = 32
21+
from opentelemetry.trace.span import TraceState
4622

4723

4824
class TraceContextTextMapPropagator(textmap.TextMapPropagator):
@@ -94,7 +70,7 @@ def extract(
9470
if tracestate_headers is None:
9571
tracestate = None
9672
else:
97-
tracestate = _parse_tracestate(tracestate_headers)
73+
tracestate = TraceState.from_header(tracestate_headers)
9874

9975
span_context = trace.SpanContext(
10076
trace_id=int(trace_id, 16),
@@ -130,7 +106,7 @@ def inject(
130106
carrier, self._TRACEPARENT_HEADER_NAME, traceparent_string
131107
)
132108
if span_context.trace_state:
133-
tracestate_string = _format_tracestate(span_context.trace_state)
109+
tracestate_string = span_context.trace_state.to_header()
134110
set_in_carrier(
135111
carrier, self._TRACESTATE_HEADER_NAME, tracestate_string
136112
)
@@ -143,57 +119,3 @@ def fields(self) -> typing.Set[str]:
143119
`opentelemetry.trace.propagation.textmap.TextMapPropagator.fields`
144120
"""
145121
return {self._TRACEPARENT_HEADER_NAME, self._TRACESTATE_HEADER_NAME}
146-
147-
148-
def _parse_tracestate(header_list: typing.List[str]) -> trace.TraceState:
149-
"""Parse one or more w3c tracestate header into a TraceState.
150-
151-
Args:
152-
string: the value of the tracestate header.
153-
154-
Returns:
155-
A valid TraceState that contains values extracted from
156-
the tracestate header.
157-
158-
If the format of one headers is illegal, all values will
159-
be discarded and an empty tracestate will be returned.
160-
161-
If the number of keys is beyond the maximum, all values
162-
will be discarded and an empty tracestate will be returned.
163-
"""
164-
tracestate = trace.TraceState()
165-
value_count = 0
166-
for header in header_list:
167-
for member in re.split(_DELIMITER_FORMAT_RE, header):
168-
# empty members are valid, but no need to process further.
169-
if not member:
170-
continue
171-
match = _MEMBER_FORMAT_RE.fullmatch(member)
172-
if not match:
173-
# TODO: log this?
174-
return trace.TraceState()
175-
key, _eq, value = match.groups()
176-
if key in tracestate: # pylint:disable=E1135
177-
# duplicate keys are not legal in
178-
# the header, so we will remove
179-
return trace.TraceState()
180-
# typing.Dict's update is not recognized by pylint:
181-
# https://github.com/PyCQA/pylint/issues/2420
182-
tracestate[key] = value # pylint:disable=E1137
183-
value_count += 1
184-
if value_count > _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS:
185-
return trace.TraceState()
186-
return tracestate
187-
188-
189-
def _format_tracestate(tracestate: trace.TraceState) -> str:
190-
"""Parse a w3c tracestate header into a TraceState.
191-
192-
Args:
193-
tracestate: the tracestate header to write
194-
195-
Returns:
196-
A string that adheres to the w3c tracestate
197-
header format.
198-
"""
199-
return ",".join(key + "=" + value for key, value in tracestate.items())

opentelemetry-api/src/opentelemetry/trace/span.py

Lines changed: 185 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
import abc
22
import logging
3+
import re
34
import types as python_types
45
import typing
6+
from collections import OrderedDict
57

68
from opentelemetry.trace.status import Status
79
from opentelemetry.util import types
10+
from opentelemetry.util.tracestate import (
11+
_DELIMITER_PATTERN,
12+
_MEMBER_PATTERN,
13+
_TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS,
14+
_is_valid_pair,
15+
)
816

917
_logger = logging.getLogger(__name__)
1018

@@ -135,7 +143,7 @@ def sampled(self) -> bool:
135143
DEFAULT_TRACE_OPTIONS = TraceFlags.get_default()
136144

137145

138-
class TraceState(typing.Dict[str, str]):
146+
class TraceState(typing.Mapping[str, str]):
139147
"""A list of key-value pairs representing vendor-specific trace info.
140148
141149
Keys and values are strings of up to 256 printable US-ASCII characters.
@@ -146,10 +154,186 @@ class TraceState(typing.Dict[str, str]):
146154
https://www.w3.org/TR/trace-context/#tracestate-field
147155
"""
148156

157+
def __init__(
158+
self,
159+
entries: typing.Optional[
160+
typing.Sequence[typing.Tuple[str, str]]
161+
] = None,
162+
) -> None:
163+
self._dict = OrderedDict() # type: OrderedDict[str, str]
164+
if entries is None:
165+
return
166+
if len(entries) > _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS:
167+
_logger.warning(
168+
"There can't be more than %s key/value pairs.",
169+
_TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS,
170+
)
171+
return
172+
173+
for key, value in entries:
174+
if _is_valid_pair(key, value):
175+
if key in self._dict:
176+
_logger.warning("Duplicate key: %s found.", key)
177+
continue
178+
self._dict[key] = value
179+
else:
180+
_logger.warning(
181+
"Invalid key/value pair (%s, %s) found.", key, value
182+
)
183+
184+
def __getitem__(self, key: str) -> typing.Optional[str]: # type: ignore
185+
return self._dict.get(key)
186+
187+
def __iter__(self) -> typing.Iterator[str]:
188+
return iter(self._dict)
189+
190+
def __len__(self) -> int:
191+
return len(self._dict)
192+
193+
def __repr__(self) -> str:
194+
pairs = [
195+
"{key=%s, value=%s}" % (key, value)
196+
for key, value in self._dict.items()
197+
]
198+
return str(pairs)
199+
200+
def add(self, key: str, value: str) -> "TraceState":
201+
"""Adds a key-value pair to tracestate. The provided pair should
202+
adhere to w3c tracestate identifiers format.
203+
204+
Args:
205+
key: A valid tracestate key to add
206+
value: A valid tracestate value to add
207+
208+
Returns:
209+
A new TraceState with the modifications applied.
210+
211+
If the provided key-value pair is invalid or results in a tracestate
212+
that violates the tracecontext specification, it is discarded and
213+
the same tracestate will be returned.
214+
"""
215+
if not _is_valid_pair(key, value):
216+
_logger.warning(
217+
"Invalid key/value pair (%s, %s) found.", key, value
218+
)
219+
return self
220+
# There can be a maximum of 32 pairs
221+
if len(self) >= _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS:
222+
_logger.warning("There can't be more 32 key/value pairs.")
223+
return self
224+
# Duplicate entries are not allowed
225+
if key in self._dict:
226+
_logger.warning("The provided key %s already exists.", key)
227+
return self
228+
new_state = [(key, value)] + list(self._dict.items())
229+
return TraceState(new_state)
230+
231+
def update(self, key: str, value: str) -> "TraceState":
232+
"""Updates a key-value pair in tracestate. The provided pair should
233+
adhere to w3c tracestate identifiers format.
234+
235+
Args:
236+
key: A valid tracestate key to update
237+
value: A valid tracestate value to update for key
238+
239+
Returns:
240+
A new TraceState with the modifications applied.
241+
242+
If the provided key-value pair is invalid or results in a tracestate
243+
that violates the tracecontext specification, it is discarded and
244+
the same tracestate will be returned.
245+
"""
246+
if not _is_valid_pair(key, value):
247+
_logger.warning(
248+
"Invalid key/value pair (%s, %s) found.", key, value
249+
)
250+
return self
251+
prev_state = self._dict.copy()
252+
prev_state[key] = value
253+
prev_state.move_to_end(key, last=False)
254+
new_state = list(prev_state.items())
255+
return TraceState(new_state)
256+
257+
def delete(self, key: str) -> "TraceState":
258+
"""Deletes a key-value from tracestate.
259+
260+
Args:
261+
key: The tracestate key whose key-value pair should be removed
262+
263+
Returns:
264+
A new TraceState with the modifications applied.
265+
266+
If the provided key does not exist in the tracestate, a warning
267+
is logged, nothing is removed, and the same tracestate will be
268+
returned.
269+
"""
270+
if key not in self._dict:
271+
_logger.warning("The provided key %s doesn't exist.", key)
272+
return self
273+
prev_state = self._dict.copy()
274+
prev_state.pop(key)
275+
new_state = list(prev_state.items())
276+
return TraceState(new_state)
277+
278+
def to_header(self) -> str:
279+
"""Creates a w3c tracestate header from a TraceState.
280+
281+
Returns:
282+
A string that adheres to the w3c tracestate
283+
header format.
284+
"""
285+
return ",".join(key + "=" + value for key, value in self._dict.items())
286+
287+
@classmethod
288+
def from_header(cls, header_list: typing.List[str]) -> "TraceState":
289+
"""Parses one or more w3c tracestate headers into a TraceState.
290+
291+
Args:
292+
header_list: one or more w3c tracestate headers.
293+
294+
Returns:
295+
A valid TraceState that contains values extracted from
296+
the tracestate header.
297+
298+
If the format of any header is illegal, all values will
299+
be discarded and an empty tracestate will be returned.
300+
301+
If the number of keys is beyond the maximum, all values
302+
will be discarded and an empty tracestate will be returned.
303+
"""
304+
pairs = OrderedDict()
305+
for header in header_list:
306+
for member in re.split(_DELIMITER_PATTERN, header):
307+
# empty members are valid, but no need to process further.
308+
if not member:
309+
continue
310+
match = _MEMBER_PATTERN.fullmatch(member)
311+
if not match:
312+
_logger.warning(
313+
"Member doesn't match the w3c identifiers format %s",
314+
member,
315+
)
316+
return cls()
317+
key, _eq, value = match.groups()
318+
# duplicate keys are not legal in the header
319+
if key in pairs:
320+
return cls()
321+
pairs[key] = value
322+
return cls(list(pairs.items()))
323+
149324
@classmethod
150325
def get_default(cls) -> "TraceState":
151326
return cls()
152327

328+
def keys(self) -> typing.KeysView[str]:
329+
return self._dict.keys()
330+
331+
def items(self) -> typing.ItemsView[str, str]:
332+
return self._dict.items()
333+
334+
def values(self) -> typing.ValuesView[str]:
335+
return self._dict.values()
336+
153337

154338
DEFAULT_TRACE_STATE = TraceState.get_default()
155339

0 commit comments

Comments
 (0)