chore(replay): add a generic log sampling filter and sample replay ingest logs #83049

Merged: 6 commits, Jan 8, 2025
21 changes: 21 additions & 0 deletions src/sentry/logging/handlers.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import random
import re
from typing import Any

@@ -164,3 +165,23 @@ def emit(self, record, logger=None):
        key = metrics_badchars_re.sub("", key)
        key = ".".join(key.split(".")[:3])
        metrics.incr(key, skip_internal=False)


class SamplingFilter(logging.Filter):
    """
    A logging filter to sample logs with a fixed probability.

    p -- probability the log record is emitted. Float in range [0.0, 1.0].
    level -- sampling applies to log records with this level or lower; records above
        this level always pass through. Defaults to logging.INFO.
    """

    def __init__(self, p: float, level: int | None = None):
        super().__init__()
        assert 0.0 <= p <= 1.0
        self.sample_rate = p
        self.level = logging.INFO if level is None else level

    def filter(self, record: logging.LogRecord) -> bool:
        if record.levelno <= self.level:
            return random.random() < self.sample_rate
        return True
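
For orientation, a minimal sketch of how this filter is meant to be attached (it mirrors the consumer changes below; the logger name and rate are illustrative, and the import assumes a Sentry checkout on the path):

import logging

from sentry.logging.handlers import SamplingFilter

logger = logging.getLogger("my.module")  # illustrative name
logger.addFilter(SamplingFilter(0.01))   # keep roughly 1% of INFO-and-below records

logger.info("sampled: emitted about 1 time in 100")
logger.warning("not sampled: WARNING is above the default INFO threshold, so always emitted")
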
8 changes: 7 additions & 1 deletion src/sentry/replays/consumers/recording_buffered.py
@@ -56,13 +56,18 @@
from sentry_kafka_schemas.schema_types.ingest_replay_recordings_v1 import ReplayRecording

from sentry.conf.types.kafka_definition import Topic, get_topic_codec
from sentry.logging.handlers import SamplingFilter
from sentry.models.project import Project
from sentry.replays.lib.storage import (
    RecordingSegmentStorageMeta,
    make_recording_filename,
    storage_kv,
)
from sentry.replays.usecases.ingest import process_headers, track_initial_segment_event
from sentry.replays.usecases.ingest import (
    LOG_SAMPLE_RATE,
    process_headers,
    track_initial_segment_event,
)
from sentry.replays.usecases.ingest.dom_index import (
    ReplayActionsEvent,
    emit_replay_actions,
@@ -72,6 +77,7 @@
from sentry.utils import json, metrics

logger = logging.getLogger(__name__)
logger.addFilter(SamplingFilter(LOG_SAMPLE_RATE))

RECORDINGS_CODEC: Codec[ReplayRecording] = get_topic_codec(Topic.INGEST_REPLAYS_RECORDINGS)

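
One subtlety of stock Python logging worth noting (a sketch, not part of the diff): filters attached to a logger apply only to records logged through that exact logger, not to records from its child loggers. That is why this consumer module attaches its own SamplingFilter instead of relying on the one added to the "sentry.replays" logger below.

import logging

from sentry.logging.handlers import SamplingFilter

parent = logging.getLogger("sentry.replays")
child = logging.getLogger("sentry.replays.consumers.recording_buffered")

parent.addFilter(SamplingFilter(0.01))

# Propagation hands records to ancestor *handlers* only; the parent's filters
# are never consulted for records logged via `child`.
child.info("not sampled by the parent's filter (logger filters are not inherited)")
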
7 changes: 5 additions & 2 deletions src/sentry/replays/usecases/ingest/__init__.py
@@ -13,6 +13,7 @@

from sentry import options
from sentry.constants import DataCategory
from sentry.logging.handlers import SamplingFilter
from sentry.models.project import Project
from sentry.replays.lib.storage import (
    RecordingSegmentStorageMeta,
@@ -25,10 +26,12 @@
from sentry.utils import json, metrics
from sentry.utils.outcomes import Outcome, track_outcome

logger = logging.getLogger("sentry.replays")

CACHE_TIMEOUT = 3600
COMMIT_FREQUENCY_SEC = 1
LOG_SAMPLE_RATE = 0.01

logger = logging.getLogger("sentry.replays")
logger.addFilter(SamplingFilter(LOG_SAMPLE_RATE))


class ReplayRecordingSegment(TypedDict):
82 changes: 81 additions & 1 deletion tests/sentry/logging/test_handler.py
@@ -1,10 +1,17 @@
import logging
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any
from unittest import mock

import pytest

from sentry.logging.handlers import GKEStructLogHandler, JSONRenderer, StructLogHandler
from sentry.logging.handlers import (
    GKEStructLogHandler,
    JSONRenderer,
    SamplingFilter,
    StructLogHandler,
)


@pytest.fixture
@@ -26,6 +33,36 @@ def __str__(self) -> str:
    return SNAFU()


@contextmanager
def filter_context(
    logger: logging.Logger, filters: list[logging.Filter | Callable[[logging.LogRecord], bool]]
):
    """Manages adding and cleaning up log filters"""
    for f in filters:
        logger.addFilter(f)
    try:
        yield
    finally:
        for f in filters:
            logger.removeFilter(f)


@contextmanager
def level_context(level: int):
    curr_level = logging.getLogger().level
    logging.basicConfig(level=level)
    try:
        yield
    finally:
        logging.basicConfig(level=curr_level)


@pytest.fixture
def set_level_debug():
    with level_context(logging.DEBUG):
        yield


def make_logrecord(
    *,
    name: str = "name",
@@ -131,3 +168,46 @@ def test_gke_emit() -> None:
event="msg",
**{"logging.googleapis.com/labels": {"name": "name"}},
)


@mock.patch("random.random", lambda: 0.1)
def test_sampling_filter(caplog, set_level_debug):
    # random.random() is pinned to 0.1, so a sample rate of 0.2 passes and 0.05 drops.
    logger = logging.getLogger(__name__)
    with filter_context(logger, [SamplingFilter(0.2)]):
        logger.info("msg1")
        logger.info("message.2")

    with filter_context(logger, [SamplingFilter(0.05)]):
        logger.info("msg1")
        logger.info("message.2")

    captured_msgs = list(map(lambda r: r.msg, caplog.records))
    assert sorted(captured_msgs) == ["message.2", "msg1"]


@mock.patch("random.random", lambda: 0.1)
def test_sampling_filter_level(caplog, set_level_debug):
    logger = logging.getLogger(__name__)
    with filter_context(logger, [SamplingFilter(0.05, level=logging.WARNING)]):
        logger.debug("debug")
        logger.info("info")
        logger.warning("warning")
        logger.error("error")
        logger.critical("critical")

    captured_msgs = list(map(lambda r: r.msg, caplog.records))
    assert sorted(captured_msgs) == ["critical", "error"]


@mock.patch("random.random", lambda: 0.1)
def test_sampling_filter_level_default(caplog, set_level_debug):
    logger = logging.getLogger(__name__)
    with filter_context(logger, [SamplingFilter(0.05)]):
        logger.debug("debug")
        logger.info("info")
        logger.warning("warning")
        logger.error("error")
        logger.critical("critical")

    captured_msgs = list(map(lambda r: r.msg, caplog.records))
    assert sorted(captured_msgs) == ["critical", "error", "warning"]
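
The tests above pin random.random() at 0.1 so the outcomes are deterministic: sample rates above 0.1 always pass and rates of 0.1 or below always drop. As a complementary, non-deterministic sanity check of the 1% ingest rate, something like the following standalone sketch could be used (the counting handler, logger name, and iteration count are illustrative; the import assumes a Sentry checkout on the path):

import logging

from sentry.logging.handlers import SamplingFilter


class CountingHandler(logging.Handler):
    """Counts records that survive the logger's filters."""

    def __init__(self) -> None:
        super().__init__()
        self.count = 0

    def emit(self, record: logging.LogRecord) -> None:
        self.count += 1


logger = logging.getLogger("sampling.sanity_check")
logger.setLevel(logging.INFO)
logger.propagate = False
logger.addFilter(SamplingFilter(0.01))
counter = CountingHandler()
logger.addHandler(counter)

for _ in range(100_000):
    logger.info("replay ingest log")

print(counter.count)  # expect a value near 1_000 (1% of 100_000)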