diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/tracing.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/tracing.py index 1e4dd36511bd..fcb7a42a2cbb 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/tracing.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/tracing.py @@ -8,6 +8,7 @@ from enum import Enum import logging from typing import ( + Any, Dict, Iterable, Iterator, @@ -22,6 +23,20 @@ from azure.core import CaseInsensitiveEnumMeta from azure.core.settings import settings from azure.core.tracing import SpanKind, Link +from azure.servicebus._version import VERSION + +try: + from azure.core.instrumentation import get_tracer + TRACER = get_tracer( + library_name="azure-servicebus", + library_version=VERSION, + schema_url="https://opentelemetry.io/schemas/1.23.1", + attributes={ + "az.namespace": "Microsoft.ServiceBus", + }, + ) +except ImportError: + TRACER = None if TYPE_CHECKING: try: @@ -62,6 +77,7 @@ _LOGGER = logging.getLogger(__name__) + class TraceAttributes: TRACE_NAMESPACE_ATTRIBUTE = "az.namespace" TRACE_NAMESPACE = "Microsoft.ServiceBus" @@ -74,9 +90,6 @@ class TraceAttributes: TRACE_MESSAGING_OPERATION_ATTRIBUTE = "messaging.operation" TRACE_MESSAGING_BATCH_COUNT_ATTRIBUTE = "messaging.batch.message_count" - LEGACY_TRACE_MESSAGE_BUS_DESTINATION_ATTRIBUTE = "message_bus.destination" - LEGACY_TRACE_PEER_ADDRESS_ATTRIBUTE = "peer.address" - class TraceOperationTypes(str, Enum, metaclass=CaseInsensitiveEnumMeta): PUBLISH = "publish" @@ -85,8 +98,13 @@ class TraceOperationTypes(str, Enum, metaclass=CaseInsensitiveEnumMeta): def is_tracing_enabled(): - span_impl_type = settings.tracing_implementation() - return span_impl_type is not None + if TRACER is None: + # The version of azure-core installed does not support native tracing. Just check + # for the plugin. + span_impl_type = settings.tracing_implementation() + return span_impl_type is not None + # Otherwise, can just check the tracing setting. + return settings.tracing_enabled() @contextmanager @@ -96,6 +114,7 @@ def send_trace_context_manager( links: Optional[List[Link]] = None, ) -> Iterator[None]: """Tracing for sending messages. + :param sender: The sender that is sending the message. :type sender: ~azure.servicebus.ServiceBusSender or ~azure.servicebus.aio.ServiceBusSenderAsync :param span_name: The name of the tracing span. @@ -106,11 +125,19 @@ def send_trace_context_manager( :rtype: iterator """ span_impl_type: Optional[Type[AbstractSpan]] = settings.tracing_implementation() + links = links or [] if span_impl_type is not None: - links = links or [] with span_impl_type(name=span_name, kind=SpanKind.CLIENT, links=links) as span: - add_span_attributes(span, TraceOperationTypes.PUBLISH, sender, message_count=len(links)) + add_plugin_span_attributes(span, TraceOperationTypes.PUBLISH, sender, message_count=len(links)) + yield + elif TRACER is not None: + if settings.tracing_enabled(): + with TRACER.start_as_current_span(span_name, kind=SpanKind.CLIENT, links=links) as span: + attributes = get_span_attributes(TraceOperationTypes.PUBLISH, sender, message_count=len(links)) + span.set_attributes(attributes) + yield + else: yield else: yield @@ -124,6 +151,7 @@ def receive_trace_context_manager( start_time: Optional[int] = None, ) -> Iterator[None]: """Tracing for receiving messages. + :param receiver: The receiver that is receiving the message. :type receiver: ~azure.servicebus.ServiceBusReceiver or ~azure.servicebus.aio.ServiceBusReceiverAsync :param span_name: The name of the tracing span. @@ -135,11 +163,30 @@ def receive_trace_context_manager( :return: An iterator that yields the tracing span. :rtype: iterator """ + links = links or [] span_impl_type: Optional[Type[AbstractSpan]] = settings.tracing_implementation() if span_impl_type is not None: - links = links or [] with span_impl_type(name=span_name, kind=SpanKind.CLIENT, links=links, start_time=start_time) as span: - add_span_attributes(span, TraceOperationTypes.RECEIVE, receiver, message_count=len(links)) + add_plugin_span_attributes(span, TraceOperationTypes.RECEIVE, receiver, message_count=len(links)) + yield + elif TRACER is not None: + if settings.tracing_enabled(): + # Depending on the azure-core version, start_as_current_span may or may not support start_time as a + # keyword argument. Handle both cases. + try: + with TRACER.start_as_current_span( # type: ignore[call-arg] # pylint: disable=unexpected-keyword-arg + span_name, kind=SpanKind.CLIENT, start_time=start_time, links=links + ) as span: + attributes = get_span_attributes(TraceOperationTypes.RECEIVE, receiver, message_count=len(links)) + span.set_attributes(attributes) + yield + except TypeError: + # If start_time is not supported, just call without it. + with TRACER.start_as_current_span(span_name, kind=SpanKind.CLIENT, links=links) as span: + attributes = get_span_attributes(TraceOperationTypes.RECEIVE, receiver, message_count=len(links)) + span.set_attributes(attributes) + yield + else: yield else: yield @@ -150,6 +197,7 @@ def settle_trace_context_manager( receiver: Union[ServiceBusReceiver, ServiceBusReceiverAsync], operation: str, links: Optional[List[Link]] = None ): """Tracing for settling messages. + :param receiver: The receiver that is settling the message. :type receiver: ~azure.servicebus.ServiceBusReceiver or ~azure.servicebus.aio.ServiceBusReceiver :param operation: The operation that is being performed on the message. @@ -163,10 +211,27 @@ def settle_trace_context_manager( if span_impl_type is not None: links = links or [] with span_impl_type(name=f"ServiceBus.{operation}", kind=SpanKind.CLIENT, links=links) as span: - add_span_attributes(span, TraceOperationTypes.SETTLE, receiver) + add_plugin_span_attributes(span, TraceOperationTypes.SETTLE, receiver) yield - else: - yield + elif TRACER is not None: + if settings.tracing_enabled(): + with TRACER.start_as_current_span(f"ServiceBus.{operation}", kind=SpanKind.CLIENT, links=links) as span: + attributes = get_span_attributes(TraceOperationTypes.SETTLE, receiver) + span.set_attributes(attributes) + yield + else: + yield + + +def _update_message_with_trace_context(message, amqp_transport, context): + if "traceparent" in context: + message = amqp_transport.update_message_app_properties( + message, TRACE_DIAGNOSTIC_ID_PROPERTY, context["traceparent"] + ) + message = amqp_transport.update_message_app_properties(message, TRACE_PARENT_PROPERTY, context["traceparent"]) + if "tracestate" in context: + message = amqp_transport.update_message_app_properties(message, TRACE_STATE_PROPERTY, context["tracestate"]) + return message def trace_message( @@ -191,20 +256,8 @@ def trace_message( span_impl_type: Optional[Type[AbstractSpan]] = settings.tracing_implementation() if span_impl_type is not None: with span_impl_type(name=SPAN_NAME_MESSAGE, kind=SpanKind.PRODUCER) as message_span: - headers = message_span.to_header() - - if "traceparent" in headers: - message = amqp_transport.update_message_app_properties( - message, TRACE_DIAGNOSTIC_ID_PROPERTY, headers["traceparent"] - ) - message = amqp_transport.update_message_app_properties( - message, TRACE_PARENT_PROPERTY, headers["traceparent"] - ) - - if "tracestate" in headers: - message = amqp_transport.update_message_app_properties( - message, TRACE_STATE_PROPERTY, headers["tracestate"] - ) + context = message_span.to_header() + message = _update_message_with_trace_context(message, amqp_transport, context) message_span.add_attribute(TraceAttributes.TRACE_NAMESPACE_ATTRIBUTE, TraceAttributes.TRACE_NAMESPACE) message_span.add_attribute( @@ -215,6 +268,17 @@ def trace_message( for key, value in additional_attributes.items(): if value is not None: message_span.add_attribute(key, value) + elif TRACER is not None: + if settings.tracing_enabled(): + with TRACER.start_as_current_span(SPAN_NAME_MESSAGE, kind=SpanKind.PRODUCER) as message_span: + trace_context = TRACER.get_trace_context() + message = _update_message_with_trace_context(message, amqp_transport, trace_context) + attributes = { + TraceAttributes.TRACE_NAMESPACE_ATTRIBUTE: TraceAttributes.TRACE_NAMESPACE, + TraceAttributes.TRACE_MESSAGING_SYSTEM_ATTRIBUTE: TraceAttributes.TRACE_MESSAGING_SYSTEM, + **(additional_attributes or {}), + } + message_span.set_attributes(attributes) except Exception as exp: # pylint:disable=broad-except _LOGGER.warning("trace_message had an exception %r", exp) @@ -226,11 +290,7 @@ def get_receive_links(messages: Union[ServiceBusReceivedMessage, Iterable[Servic if not is_tracing_enabled(): return [] - trace_messages = ( - messages - if isinstance(messages, Iterable) - else (messages,) - ) + trace_messages = messages if isinstance(messages, Iterable) else (messages,) links = [] try: @@ -277,7 +337,7 @@ def get_span_links_from_batch(batch: ServiceBusMessageBatch) -> List[Link]: return links -def get_span_link_from_message(message: Union[uamqp_Message, pyamqp_Message, ServiceBusMessage]) -> Optional[Link]: +def get_span_link_from_message(message: Any) -> Optional[Link]: """Create a span link from a message. This will extract the traceparent and tracestate from the message application properties and create span links @@ -309,7 +369,7 @@ def get_span_link_from_message(message: Union[uamqp_Message, pyamqp_Message, Ser return Link(headers) -def add_span_attributes( +def add_plugin_span_attributes( span: AbstractSpan, operation_type: TraceOperationTypes, handler: Union[BaseHandler, BaseHandlerAsync], @@ -322,25 +382,35 @@ def add_span_attributes( ~azure.servicebus.aio._base_handler_async.BaseHandlerAsync handler: The handler that is performing the operation. :param int message_count: The number of messages being sent or received. """ + attributes = get_span_attributes(operation_type, handler, message_count) + for key, value in attributes.items(): + if value is not None: + span.add_attribute(key, value) - span.add_attribute(TraceAttributes.TRACE_NAMESPACE_ATTRIBUTE, TraceAttributes.TRACE_NAMESPACE) - span.add_attribute(TraceAttributes.TRACE_MESSAGING_SYSTEM_ATTRIBUTE, TraceAttributes.TRACE_MESSAGING_SYSTEM) - span.add_attribute(TraceAttributes.TRACE_MESSAGING_OPERATION_ATTRIBUTE, operation_type) - - if message_count > 1: - span.add_attribute(TraceAttributes.TRACE_MESSAGING_BATCH_COUNT_ATTRIBUTE, message_count) - if operation_type in (TraceOperationTypes.PUBLISH, TraceOperationTypes.RECEIVE): - # Maintain legacy attributes for backwards compatibility. - span.add_attribute( - TraceAttributes.LEGACY_TRACE_MESSAGE_BUS_DESTINATION_ATTRIBUTE, - handler._entity_name, # pylint: disable=protected-access - ) - span.add_attribute(TraceAttributes.LEGACY_TRACE_PEER_ADDRESS_ATTRIBUTE, handler.fully_qualified_namespace) +def get_span_attributes( + operation_type: TraceOperationTypes, + handler: Union[BaseHandler, BaseHandlerAsync], + message_count: int = 0, +) -> dict: + """Return a dict of attributes for a span based on the operation type. - elif operation_type == TraceOperationTypes.SETTLE: - span.add_attribute(TraceAttributes.TRACE_NET_PEER_NAME_ATTRIBUTE, handler.fully_qualified_namespace) - span.add_attribute( - TraceAttributes.TRACE_MESSAGING_DESTINATION_ATTRIBUTE, - handler._entity_name, # pylint: disable=protected-access - ) + :param TraceOperationTypes operation_type: The operation type. + :param ~azure.servicebus._base_handler.BaseHandler or + ~azure.servicebus.aio._base_handler_async.BaseHandlerAsync handler: The handler that is performing the operation. + :param int message_count: The number of messages being sent or received. + :return: Dictionary of span attributes. + :rtype: dict + """ + attributes: Dict[str, Any] = { + TraceAttributes.TRACE_NAMESPACE_ATTRIBUTE: TraceAttributes.TRACE_NAMESPACE, + TraceAttributes.TRACE_MESSAGING_SYSTEM_ATTRIBUTE: TraceAttributes.TRACE_MESSAGING_SYSTEM, + TraceAttributes.TRACE_MESSAGING_OPERATION_ATTRIBUTE: operation_type, + } + if message_count > 1: + attributes[TraceAttributes.TRACE_MESSAGING_BATCH_COUNT_ATTRIBUTE] = message_count + attributes[TraceAttributes.TRACE_NET_PEER_NAME_ATTRIBUTE] = handler.fully_qualified_namespace + attributes[TraceAttributes.TRACE_MESSAGING_DESTINATION_ATTRIBUTE] = ( + handler._entity_name # pylint: disable=protected-access + ) + return attributes diff --git a/sdk/servicebus/azure-servicebus/dev_requirements.txt b/sdk/servicebus/azure-servicebus/dev_requirements.txt index da4be168e3aa..179263db8e11 100644 --- a/sdk/servicebus/azure-servicebus/dev_requirements.txt +++ b/sdk/servicebus/azure-servicebus/dev_requirements.txt @@ -1,7 +1,8 @@ --e ../../core/azure-core +../../core/azure-core azure-identity~=1.17.0 -e ../../../tools/azure-sdk-tools azure-mgmt-servicebus~=8.0.0 aiohttp>=3.0 websocket-client -azure-mgmt-resource<=16.0.0 \ No newline at end of file +azure-mgmt-resource<=16.0.0 +opentelemetry-sdk~=1.26 diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_tracing_live_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_tracing_live_async.py new file mode 100644 index 000000000000..ccddd2526164 --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_tracing_live_async.py @@ -0,0 +1,166 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest + +from azure.servicebus import ServiceBusMessage +from azure.servicebus.aio import ServiceBusClient + +from test_tracing_live import ServiceBusTracingTestBase + + +class TestServiceBusTracingAsync(ServiceBusTracingTestBase): + + @pytest.mark.asyncio + @pytest.mark.live_test_only + async def test_servicebus_client_tracing_queue_async(self, config, tracing_helper): + + connection_string = config["servicebus_connection_string"] + queue_name = config["servicebus_queue_name"] + client = ServiceBusClient.from_connection_string(connection_string) + + with tracing_helper.tracer.start_as_current_span(name="root"): + async with client.get_queue_sender(queue_name) as sender: + + # Sending a single message + await sender.send_messages(ServiceBusMessage("Test foo message")) + + # Sending a batch of messages + message_batch = await sender.create_message_batch() + message_batch.add_message(ServiceBusMessage("First batch foo message")) + message_batch.add_message(ServiceBusMessage("Second batch foo message")) + await sender.send_messages(message_batch) + + + finished_spans = tracing_helper.exporter.get_finished_spans() + server_address = sender.fully_qualified_namespace + + # We expect 5 spans to have finished: 2 send spans, and 3 message spans. + assert len(finished_spans) == 5 + + # Verify the spans from the first send. + self._verify_message(span=finished_spans[0], dest=queue_name, server_address=server_address) + self._verify_send(span=finished_spans[1], dest=queue_name, server_address=server_address, message_count=1) + + # Verify span links from single send. + link = finished_spans[1].links[0] + assert link.context.span_id == finished_spans[0].context.span_id + assert link.context.trace_id == finished_spans[0].context.trace_id + + # Verify the spans from the second send. + self._verify_message(span=finished_spans[2], dest=queue_name, server_address=server_address) + self._verify_message(span=finished_spans[3], dest=queue_name, server_address=server_address) + self._verify_send(span=finished_spans[4], dest=queue_name, server_address=server_address, message_count=2) + + # Verify span links from batch send. + assert len(finished_spans[4].links) == 2 + link = finished_spans[4].links[0] + assert link.context.span_id == finished_spans[2].context.span_id + assert link.context.trace_id == finished_spans[2].context.trace_id + + link = finished_spans[4].links[1] + assert link.context.span_id == finished_spans[3].context.span_id + assert link.context.trace_id == finished_spans[3].context.trace_id + + tracing_helper.exporter.clear() + + # Receive all the sent spans. + async with client.get_queue_receiver(queue_name=queue_name) as receiver: + received_msgs = await receiver.receive_messages(max_message_count=3, max_wait_time=10) + for msg in received_msgs: + assert "foo" in str(msg) + await receiver.complete_message(msg) + + receive_spans = tracing_helper.exporter.get_finished_spans() + + # We expect 4 spans to have finished: 1 receive span, and 3 settlement spans. + assert len(receive_spans) == 4 + self._verify_receive(span=receive_spans[0], dest=queue_name, server_address=server_address, message_count=3) + + # Verify span links from receive. + assert len(receive_spans[0].links) == 3 + assert receive_spans[0].links[0].context.span_id == finished_spans[0].context.span_id + assert receive_spans[0].links[1].context.span_id == finished_spans[2].context.span_id + assert receive_spans[0].links[2].context.span_id == finished_spans[3].context.span_id + + # Verify settlement spans. + self._verify_complete(span=receive_spans[1], dest=queue_name, server_address=server_address) + self._verify_complete(span=receive_spans[2], dest=queue_name, server_address=server_address) + self._verify_complete(span=receive_spans[3], dest=queue_name, server_address=server_address) + + @pytest.mark.asyncio + @pytest.mark.live_test_only + async def test_servicebus_client_tracing_topic_async(self, config, tracing_helper): + connection_string = config["servicebus_connection_string"] + topic_name = config["servicebus_topic_name"] + subscription_name = config["servicebus_subscription_name"] + client = ServiceBusClient.from_connection_string(connection_string) + + with tracing_helper.tracer.start_as_current_span(name="root"): + async with client.get_topic_sender(topic_name) as sender: + + # Sending a single message + await sender.send_messages(ServiceBusMessage("Test foo message")) + + # Sending a batch of messages + message_batch = await sender.create_message_batch() + message_batch.add_message(ServiceBusMessage("First batch foo message")) + message_batch.add_message(ServiceBusMessage("Second batch foo message")) + await sender.send_messages(message_batch) + + send_spans = tracing_helper.exporter.get_finished_spans() + server_address = sender.fully_qualified_namespace + + # We expect 5 spans to have finished: 2 send spans, and 3 message spans. + assert len(send_spans) == 5 + + # Verify the spans from the first send. + self._verify_message(span=send_spans[0], dest=topic_name, server_address=server_address) + self._verify_send(span=send_spans[1], dest=topic_name, server_address=server_address, message_count=1) + + # Verify span links from single send. + link = send_spans[1].links[0] + assert link.context.span_id == send_spans[0].context.span_id + assert link.context.trace_id == send_spans[0].context.trace_id + + # Verify the spans from the second send. + self._verify_message(span=send_spans[2], dest=topic_name, server_address=server_address) + self._verify_message(span=send_spans[3], dest=topic_name, server_address=server_address) + self._verify_send(span=send_spans[4], dest=topic_name, server_address=server_address, message_count=2) + + # Verify span links from batch send. + assert len(send_spans[4].links) == 2 + link = send_spans[4].links[0] + assert link.context.span_id == send_spans[2].context.span_id + assert link.context.trace_id == send_spans[2].context.trace_id + + link = send_spans[4].links[1] + assert link.context.span_id == send_spans[3].context.span_id + assert link.context.trace_id == send_spans[3].context.trace_id + + tracing_helper.exporter.clear() + + # Receive all the sent spans. + async with client.get_subscription_receiver(topic_name, subscription_name) as receiver: + received_msgs = await receiver.receive_messages(max_message_count=3, max_wait_time=10) + for msg in received_msgs: + assert "foo" in str(msg) + await receiver.complete_message(msg) + + receive_spans = tracing_helper.exporter.get_finished_spans() + + # We expect 4 spans to have finished: 1 receive span, and 3 settlement spans. + assert len(receive_spans) == 4 + self._verify_receive(span=receive_spans[0], dest=topic_name, server_address=server_address, message_count=3) + + assert len(receive_spans[0].links) == 3 + assert receive_spans[0].links[0].context.span_id == send_spans[0].context.span_id + assert receive_spans[0].links[1].context.span_id == send_spans[2].context.span_id + assert receive_spans[0].links[2].context.span_id == send_spans[3].context.span_id + + # Verify settlement spans. + self._verify_complete(span=receive_spans[1], dest=topic_name, server_address=server_address) + self._verify_complete(span=receive_spans[2], dest=topic_name, server_address=server_address) + self._verify_complete(span=receive_spans[3], dest=topic_name, server_address=server_address) diff --git a/sdk/servicebus/azure-servicebus/tests/conftest.py b/sdk/servicebus/azure-servicebus/tests/conftest.py index ece04cdf8919..98c454ee08fd 100644 --- a/sdk/servicebus/azure-servicebus/tests/conftest.py +++ b/sdk/servicebus/azure-servicebus/tests/conftest.py @@ -3,7 +3,16 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # ------------------------------------------------------------------------- +import os +from typing import Generator + import pytest +from azure.core.settings import settings +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from devtools_testutils.sanitizers import ( add_remove_header_sanitizer, add_general_regex_sanitizer, @@ -30,3 +39,36 @@ def pytest_configure(config): config.addinivalue_line("markers", "liveTest: mark test to be a live test only") config.addinivalue_line("markers", "live_test_only: mark test to be a live test only") config.addinivalue_line("markers", "playback_test_only: mark test to be a playback test only") + + +class TracingTestHelper: + def __init__(self, tracer, exporter): + self.tracer = tracer + self.exporter = exporter + + +@pytest.fixture(scope="session", autouse=True) +def enable_otel_tracing(): + provider = TracerProvider() + trace.set_tracer_provider(provider) + + +@pytest.fixture(scope="function") +def tracing_helper() -> Generator[TracingTestHelper, None, None]: + settings.tracing_enabled = True + settings.tracing_implementation = None + span_exporter = InMemorySpanExporter() + processor = SimpleSpanProcessor(span_exporter) + trace.get_tracer_provider().add_span_processor(processor) + yield TracingTestHelper(trace.get_tracer(__name__), span_exporter) + settings.tracing_enabled = None + + +@pytest.fixture(scope="session") +def config(): + return { + "servicebus_connection_string": os.environ.get("SERVICEBUS_CONNECTION_STR"), + "servicebus_queue_name": os.environ.get("SERVICEBUS_QUEUE_NAME"), + "servicebus_topic_name": os.environ.get("SERVICEBUS_TOPIC_NAME"), + "servicebus_subscription_name": os.environ.get("SERVICEBUS_SUBSCRIPTION_NAME"), + } diff --git a/sdk/servicebus/azure-servicebus/tests/test_tracing_live.py b/sdk/servicebus/azure-servicebus/tests/test_tracing_live.py new file mode 100644 index 000000000000..32dc96a5000e --- /dev/null +++ b/sdk/servicebus/azure-servicebus/tests/test_tracing_live.py @@ -0,0 +1,216 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest + +from azure.servicebus import ServiceBusClient, ServiceBusMessage +from opentelemetry.trace import SpanKind + +class ServiceBusTracingTestBase: + """Base class containing shared verification methods for ServiceBus tracing tests.""" + + def _verify_span_attributes(self, *, span): + # Ensure all attributes are set and have a value. + for attr in span.attributes: + assert span.attributes[attr] is not None and span.attributes[attr] != "" + + def _verify_message(self, *, span, dest, server_address): + assert span.name == "ServiceBus.message" + assert span.kind == SpanKind.PRODUCER + self._verify_span_attributes(span=span) + assert span.attributes["az.namespace"] == "Microsoft.ServiceBus" + assert span.attributes["messaging.system"] == "servicebus" + assert span.attributes["messaging.destination.name"] == dest + assert span.attributes["net.peer.name"] == server_address + + def _verify_send(self, *, span, dest, server_address, message_count): + assert span.name == "ServiceBus.send" + assert span.kind == SpanKind.CLIENT + self._verify_span_attributes(span=span) + assert span.attributes["az.namespace"] == "Microsoft.ServiceBus" + assert span.attributes["messaging.system"] == "servicebus" + assert span.attributes["messaging.destination.name"] == dest + assert span.attributes["messaging.operation"] == "publish" + assert span.attributes["net.peer.name"] == server_address + if message_count > 1: + assert span.attributes["messaging.batch.message_count"] == message_count + + def _verify_receive(self, *, span, dest, server_address, message_count): + assert span.name == "ServiceBus.receive" + assert span.kind == SpanKind.CLIENT + self._verify_span_attributes(span=span) + assert span.attributes["az.namespace"] == "Microsoft.ServiceBus" + assert span.attributes["messaging.system"] == "servicebus" + assert span.attributes["messaging.destination.name"] == dest + assert span.attributes["messaging.operation"] == "receive" + assert span.attributes["net.peer.name"] == server_address + for link in span.links: + assert "enqueuedTime" in link.attributes + if message_count > 1: + assert span.attributes["messaging.batch.message_count"] == message_count + + def _verify_complete(self, *, span, dest, server_address): + assert span.name == "ServiceBus.complete" + assert span.kind == SpanKind.CLIENT + self._verify_span_attributes(span=span) + assert span.attributes["az.namespace"] == "Microsoft.ServiceBus" + assert span.attributes["messaging.system"] == "servicebus" + assert span.attributes["messaging.operation"] == "settle" + assert span.attributes["net.peer.name"] == server_address + assert span.attributes["messaging.destination.name"] == dest + +class TestServiceBusTracing(ServiceBusTracingTestBase): + + @pytest.mark.live_test_only + def test_servicebus_client_tracing_queue(self, config, tracing_helper): + + connection_string = config["servicebus_connection_string"] + queue_name = config["servicebus_queue_name"] + client = ServiceBusClient.from_connection_string(connection_string) + + with tracing_helper.tracer.start_as_current_span(name="root"): + with client.get_queue_sender(queue_name) as sender: + + # Sending a single message + sender.send_messages(ServiceBusMessage("Test foo message")) + + # Sending a batch of messages + message_batch = sender.create_message_batch() + message_batch.add_message(ServiceBusMessage("First batch foo message")) + message_batch.add_message(ServiceBusMessage("Second batch foo message")) + sender.send_messages(message_batch) + + + finished_spans = tracing_helper.exporter.get_finished_spans() + server_address = sender.fully_qualified_namespace + + # We expect 5 spans to have finished: 2 send spans, and 3 message spans. + assert len(finished_spans) == 5 + + # Verify the spans from the first send. + self._verify_message(span=finished_spans[0], dest=queue_name, server_address=server_address) + self._verify_send(span=finished_spans[1], dest=queue_name, server_address=server_address, message_count=1) + + # Verify span links from single send. + link = finished_spans[1].links[0] + assert link.context.span_id == finished_spans[0].context.span_id + assert link.context.trace_id == finished_spans[0].context.trace_id + + # Verify the spans from the second send. + self._verify_message(span=finished_spans[2], dest=queue_name, server_address=server_address) + self._verify_message(span=finished_spans[3], dest=queue_name, server_address=server_address) + self._verify_send(span=finished_spans[4], dest=queue_name, server_address=server_address, message_count=2) + + # Verify span links from batch send. + assert len(finished_spans[4].links) == 2 + link = finished_spans[4].links[0] + assert link.context.span_id == finished_spans[2].context.span_id + assert link.context.trace_id == finished_spans[2].context.trace_id + + link = finished_spans[4].links[1] + assert link.context.span_id == finished_spans[3].context.span_id + assert link.context.trace_id == finished_spans[3].context.trace_id + + tracing_helper.exporter.clear() + + # Receive all the sent spans. + receiver = client.get_queue_receiver(queue_name=queue_name) + with receiver: + received_msgs = receiver.receive_messages(max_message_count=3, max_wait_time=10) + for msg in received_msgs: + assert "foo" in str(msg) + receiver.complete_message(msg) + + receive_spans = tracing_helper.exporter.get_finished_spans() + + # We expect 4 spans to have finished: 1 receive span, and 3 settlement spans. + assert len(receive_spans) == 4 + self._verify_receive(span=receive_spans[0], dest=queue_name, server_address=server_address, message_count=3) + + # Verify span links from receive. + assert len(receive_spans[0].links) == 3 + assert receive_spans[0].links[0].context.span_id == finished_spans[0].context.span_id + assert receive_spans[0].links[1].context.span_id == finished_spans[2].context.span_id + assert receive_spans[0].links[2].context.span_id == finished_spans[3].context.span_id + + # Verify settlement spans. + self._verify_complete(span=receive_spans[1], dest=queue_name, server_address=server_address) + self._verify_complete(span=receive_spans[2], dest=queue_name, server_address=server_address) + self._verify_complete(span=receive_spans[3], dest=queue_name, server_address=server_address) + + @pytest.mark.live_test_only + def test_servicebus_client_tracing_topic(self, config, tracing_helper): + connection_string = config["servicebus_connection_string"] + topic_name = config["servicebus_topic_name"] + subscription_name = config["servicebus_subscription_name"] + client = ServiceBusClient.from_connection_string(connection_string) + + with tracing_helper.tracer.start_as_current_span(name="root"): + with client.get_topic_sender(topic_name) as sender: + + # Sending a single message + sender.send_messages(ServiceBusMessage("Test foo message")) + + # Sending a batch of messages + message_batch = sender.create_message_batch() + message_batch.add_message(ServiceBusMessage("First batch foo message")) + message_batch.add_message(ServiceBusMessage("Second batch foo message")) + sender.send_messages(message_batch) + + send_spans = tracing_helper.exporter.get_finished_spans() + server_address = sender.fully_qualified_namespace + + # We expect 5 spans to have finished: 2 send spans, and 3 message spans. + assert len(send_spans) == 5 + + # Verify the spans from the first send. + self._verify_message(span=send_spans[0], dest=topic_name, server_address=server_address) + self._verify_send(span=send_spans[1], dest=topic_name, server_address=server_address, message_count=1) + + # Verify span links from single send. + link = send_spans[1].links[0] + assert link.context.span_id == send_spans[0].context.span_id + assert link.context.trace_id == send_spans[0].context.trace_id + + # Verify the spans from the second send. + self._verify_message(span=send_spans[2], dest=topic_name, server_address=server_address) + self._verify_message(span=send_spans[3], dest=topic_name, server_address=server_address) + self._verify_send(span=send_spans[4], dest=topic_name, server_address=server_address, message_count=2) + + # Verify span links from batch send. + assert len(send_spans[4].links) == 2 + link = send_spans[4].links[0] + assert link.context.span_id == send_spans[2].context.span_id + assert link.context.trace_id == send_spans[2].context.trace_id + + link = send_spans[4].links[1] + assert link.context.span_id == send_spans[3].context.span_id + assert link.context.trace_id == send_spans[3].context.trace_id + + tracing_helper.exporter.clear() + + # Receive all the sent spans. + receiver = client.get_subscription_receiver(topic_name, subscription_name) + with receiver: + received_msgs = receiver.receive_messages(max_message_count=3, max_wait_time=10) + for msg in received_msgs: + assert "foo" in str(msg) + receiver.complete_message(msg) + + receive_spans = tracing_helper.exporter.get_finished_spans() + + # We expect 4 spans to have finished: 1 receive span, and 3 settlement spans. + assert len(receive_spans) == 4 + self._verify_receive(span=receive_spans[0], dest=topic_name, server_address=server_address, message_count=3) + + assert len(receive_spans[0].links) == 3 + assert receive_spans[0].links[0].context.span_id == send_spans[0].context.span_id + assert receive_spans[0].links[1].context.span_id == send_spans[2].context.span_id + assert receive_spans[0].links[2].context.span_id == send_spans[3].context.span_id + + # Verify settlement spans. + self._verify_complete(span=receive_spans[1], dest=topic_name, server_address=server_address) + self._verify_complete(span=receive_spans[2], dest=topic_name, server_address=server_address) + self._verify_complete(span=receive_spans[3], dest=topic_name, server_address=server_address)