Add type hints to logging/context.py (#6309)

* Add type hints to logging/context.py

Signed-off-by: neiljp (Neil Pilgrim) <github@kepier.clara.net>
This commit is contained in:
Neil Pilgrim 2020-03-07 09:57:26 -08:00 committed by GitHub
parent 1d66dce83e
commit 2bff4457d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 47 deletions

1
changelog.d/6309.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `logging/context.py`.

View file

@ -27,10 +27,15 @@ import inspect
import logging import logging
import threading import threading
import types import types
from typing import Any, List from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from typing_extensions import Literal
from twisted.internet import defer, threads from twisted.internet import defer, threads
if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
@ -91,7 +96,7 @@ class ContextResourceUsage(object):
"evt_db_fetch_count", "evt_db_fetch_count",
] ]
def __init__(self, copy_from=None): def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None:
"""Create a new ContextResourceUsage """Create a new ContextResourceUsage
Args: Args:
@ -101,27 +106,28 @@ class ContextResourceUsage(object):
if copy_from is None: if copy_from is None:
self.reset() self.reset()
else: else:
self.ru_utime = copy_from.ru_utime # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
self.ru_stime = copy_from.ru_stime self.ru_utime = copy_from.ru_utime # type: float
self.db_txn_count = copy_from.db_txn_count self.ru_stime = copy_from.ru_stime # type: float
self.db_txn_count = copy_from.db_txn_count # type: int
self.db_txn_duration_sec = copy_from.db_txn_duration_sec self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float
self.db_sched_duration_sec = copy_from.db_sched_duration_sec self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float
self.evt_db_fetch_count = copy_from.evt_db_fetch_count self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int
def copy(self): def copy(self) -> "ContextResourceUsage":
return ContextResourceUsage(copy_from=self) return ContextResourceUsage(copy_from=self)
def reset(self): def reset(self) -> None:
self.ru_stime = 0.0 self.ru_stime = 0.0
self.ru_utime = 0.0 self.ru_utime = 0.0
self.db_txn_count = 0 self.db_txn_count = 0
self.db_txn_duration_sec = 0 self.db_txn_duration_sec = 0.0
self.db_sched_duration_sec = 0 self.db_sched_duration_sec = 0.0
self.evt_db_fetch_count = 0 self.evt_db_fetch_count = 0
def __repr__(self): def __repr__(self) -> str:
return ( return (
"<ContextResourceUsage ru_stime='%r', ru_utime='%r', " "<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
"db_txn_count='%r', db_txn_duration_sec='%r', " "db_txn_count='%r', db_txn_duration_sec='%r', "
@ -135,7 +141,7 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count, self.evt_db_fetch_count,
) )
def __iadd__(self, other): def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
"""Add another ContextResourceUsage's stats to this one's. """Add another ContextResourceUsage's stats to this one's.
Args: Args:
@ -149,7 +155,7 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count += other.evt_db_fetch_count self.evt_db_fetch_count += other.evt_db_fetch_count
return self return self
def __isub__(self, other): def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
self.ru_utime -= other.ru_utime self.ru_utime -= other.ru_utime
self.ru_stime -= other.ru_stime self.ru_stime -= other.ru_stime
self.db_txn_count -= other.db_txn_count self.db_txn_count -= other.db_txn_count
@ -158,17 +164,20 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count -= other.evt_db_fetch_count self.evt_db_fetch_count -= other.evt_db_fetch_count
return self return self
def __add__(self, other): def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self) res = ContextResourceUsage(copy_from=self)
res += other res += other
return res return res
def __sub__(self, other): def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self) res = ContextResourceUsage(copy_from=self)
res -= other res -= other
return res return res
LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
class LoggingContext(object): class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a """Additional context for log formatting. Contexts are scoped within a
"with" block. "with" block.
@ -201,7 +210,14 @@ class LoggingContext(object):
class Sentinel(object): class Sentinel(object):
"""Sentinel to represent the root context""" """Sentinel to represent the root context"""
__slots__ = [] # type: List[Any] __slots__ = ["previous_context", "alive", "request", "scope"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.alive = None
self.request = None
self.scope = None
def __str__(self): def __str__(self):
return "sentinel" return "sentinel"
@ -235,7 +251,7 @@ class LoggingContext(object):
sentinel = Sentinel() sentinel = Sentinel()
def __init__(self, name=None, parent_context=None, request=None): def __init__(self, name=None, parent_context=None, request=None) -> None:
self.previous_context = LoggingContext.current_context() self.previous_context = LoggingContext.current_context()
self.name = name self.name = name
@ -250,7 +266,7 @@ class LoggingContext(object):
self.request = None self.request = None
self.tag = "" self.tag = ""
self.alive = True self.alive = True
self.scope = None self.scope = None # type: Optional[_LogContextScope]
self.parent_context = parent_context self.parent_context = parent_context
@ -261,13 +277,13 @@ class LoggingContext(object):
# the request param overrides the request from the parent context # the request param overrides the request from the parent context
self.request = request self.request = request
def __str__(self): def __str__(self) -> str:
if self.request: if self.request:
return str(self.request) return str(self.request)
return "%s@%x" % (self.name, id(self)) return "%s@%x" % (self.name, id(self))
@classmethod @classmethod
def current_context(cls): def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage """Get the current logging context from thread local storage
Returns: Returns:
@ -276,7 +292,9 @@ class LoggingContext(object):
return getattr(cls.thread_local, "current_context", cls.sentinel) return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod @classmethod
def set_current_context(cls, context): def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage """Set the current logging context in thread local storage
Args: Args:
context(LoggingContext): The context to activate. context(LoggingContext): The context to activate.
@ -291,7 +309,7 @@ class LoggingContext(object):
context.start() context.start()
return current return current
def __enter__(self): def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage""" """Enters this logging context into thread local storage"""
old_context = self.set_current_context(self) old_context = self.set_current_context(self)
if self.previous_context != old_context: if self.previous_context != old_context:
@ -304,7 +322,7 @@ class LoggingContext(object):
return self return self
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback) -> None:
"""Restore the logging context in thread local storage to the state it """Restore the logging context in thread local storage to the state it
was before this context was entered. was before this context was entered.
Returns: Returns:
@ -318,7 +336,6 @@ class LoggingContext(object):
logger.warning( logger.warning(
"Expected logging context %s but found %s", self, current "Expected logging context %s but found %s", self, current
) )
self.previous_context = None
self.alive = False self.alive = False
# if we have a parent, pass our CPU usage stats on # if we have a parent, pass our CPU usage stats on
@ -330,7 +347,7 @@ class LoggingContext(object):
# reset them in case we get entered again # reset them in case we get entered again
self._resource_usage.reset() self._resource_usage.reset()
def copy_to(self, record): def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or """Copy logging fields from this context to a log record or
another LoggingContext another LoggingContext
""" """
@ -341,14 +358,14 @@ class LoggingContext(object):
# we also track the current scope: # we also track the current scope:
record.scope = self.scope record.scope = self.scope
def copy_to_twisted_log_entry(self, record): def copy_to_twisted_log_entry(self, record) -> None:
""" """
Copy logging fields from this context to a Twisted log record. Copy logging fields from this context to a Twisted log record.
""" """
record["request"] = self.request record["request"] = self.request
record["scope"] = self.scope record["scope"] = self.scope
def start(self): def start(self) -> None:
if get_thread_id() != self.main_thread: if get_thread_id() != self.main_thread:
logger.warning("Started logcontext %s on different thread", self) logger.warning("Started logcontext %s on different thread", self)
return return
@ -358,7 +375,7 @@ class LoggingContext(object):
if not self.usage_start: if not self.usage_start:
self.usage_start = get_thread_resource_usage() self.usage_start = get_thread_resource_usage()
def stop(self): def stop(self) -> None:
if get_thread_id() != self.main_thread: if get_thread_id() != self.main_thread:
logger.warning("Stopped logcontext %s on different thread", self) logger.warning("Stopped logcontext %s on different thread", self)
return return
@ -378,7 +395,7 @@ class LoggingContext(object):
self.usage_start = None self.usage_start = None
def get_resource_usage(self): def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far. """Get resources used by this logcontext so far.
Returns: Returns:
@ -398,11 +415,13 @@ class LoggingContext(object):
return res return res
def _get_cputime(self): def _get_cputime(self) -> Tuple[float, float]:
"""Get the cpu usage time so far """Get the cpu usage time so far
Returns: Tuple[float, float]: seconds in user mode, seconds in system mode Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
""" """
assert self.usage_start is not None
current = get_thread_resource_usage() current = get_thread_resource_usage()
# Indicate to mypy that we know that self.usage_start is None. # Indicate to mypy that we know that self.usage_start is None.
@ -430,13 +449,13 @@ class LoggingContext(object):
return utime_delta, stime_delta return utime_delta, stime_delta
def add_database_transaction(self, duration_sec): def add_database_transaction(self, duration_sec: float) -> None:
if duration_sec < 0: if duration_sec < 0:
raise ValueError("DB txn time can only be non-negative") raise ValueError("DB txn time can only be non-negative")
self._resource_usage.db_txn_count += 1 self._resource_usage.db_txn_count += 1
self._resource_usage.db_txn_duration_sec += duration_sec self._resource_usage.db_txn_duration_sec += duration_sec
def add_database_scheduled(self, sched_sec): def add_database_scheduled(self, sched_sec: float) -> None:
"""Record a use of the database pool """Record a use of the database pool
Args: Args:
@ -447,7 +466,7 @@ class LoggingContext(object):
raise ValueError("DB scheduling time can only be non-negative") raise ValueError("DB scheduling time can only be non-negative")
self._resource_usage.db_sched_duration_sec += sched_sec self._resource_usage.db_sched_duration_sec += sched_sec
def record_event_fetch(self, event_count): def record_event_fetch(self, event_count: int) -> None:
"""Record a number of events being fetched from the db """Record a number of events being fetched from the db
Args: Args:
@ -464,10 +483,10 @@ class LoggingContextFilter(logging.Filter):
missing fields missing fields
""" """
def __init__(self, **defaults): def __init__(self, **defaults) -> None:
self.defaults = defaults self.defaults = defaults
def filter(self, record): def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record. """Add each fields from the logging contexts to the record.
Returns: Returns:
True to include the record in the log output. True to include the record in the log output.
@ -492,12 +511,13 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"] __slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=None): def __init__(self, new_context: Optional[LoggingContext] = None) -> None:
if new_context is None: if new_context is None:
new_context = LoggingContext.sentinel self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
self.new_context = new_context else:
self.new_context = new_context
def __enter__(self): def __enter__(self) -> None:
"""Captures the current logging context""" """Captures the current logging context"""
self.current_context = LoggingContext.set_current_context(self.new_context) self.current_context = LoggingContext.set_current_context(self.new_context)
@ -506,7 +526,7 @@ class PreserveLoggingContext(object):
if not self.current_context.alive: if not self.current_context.alive:
logger.debug("Entering dead context: %s", self.current_context) logger.debug("Entering dead context: %s", self.current_context)
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context""" """Restores the current logging context"""
context = LoggingContext.set_current_context(self.current_context) context = LoggingContext.set_current_context(self.current_context)
@ -525,7 +545,9 @@ class PreserveLoggingContext(object):
logger.debug("Restoring dead context: %s", self.current_context) logger.debug("Restoring dead context: %s", self.current_context)
def nested_logging_context(suffix, parent_context=None): def nested_logging_context(
suffix: str, parent_context: Optional[LoggingContext] = None
) -> LoggingContext:
"""Creates a new logging context as a child of another. """Creates a new logging context as a child of another.
The nested logging context will have a 'request' made up of the parent context's The nested logging context will have a 'request' made up of the parent context's
@ -546,10 +568,12 @@ def nested_logging_context(suffix, parent_context=None):
Returns: Returns:
LoggingContext: new logging context. LoggingContext: new logging context.
""" """
if parent_context is None: if parent_context is not None:
parent_context = LoggingContext.current_context() context = parent_context # type: LoggingContextOrSentinel
else:
context = LoggingContext.current_context()
return LoggingContext( return LoggingContext(
parent_context=parent_context, request=parent_context.request + "-" + suffix parent_context=context, request=str(context.request) + "-" + suffix
) )
@ -654,7 +678,10 @@ def make_deferred_yieldable(deferred):
return deferred return deferred
def _set_context_cb(result, context): ResultT = TypeVar("ResultT")
def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context""" """A callback function which just sets the logging context"""
LoggingContext.set_current_context(context) LoggingContext.set_current_context(context)
return result return result