import asyncio
import collections.abc
import enum
import inspect
import logging
import sys
from concurrent.futures import Executor
from typing import (
cast,
Any,
Awaitable,
Callable,
Coroutine,
Generic,
Iterable,
Iterator,
Literal,
NamedTuple,
Optional,
TypeVar,
Union,
TYPE_CHECKING
)
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
# define TypeAlias as a no-op for type checkers in older Python versions
if TYPE_CHECKING:
TypeAlias = type
else:
TypeAlias = None
from ..product._product import Product, DiffValueTuple
logger = logging.getLogger('freshpointsync.update')
"""Logger for the `freshpointsync.update` module."""
T = TypeVar('T')
class SafeAsyncTaskRunner:
"""A utility for running asynchronous and synchronous tasks in
a non-blocking manner with optional error handling and
the ability to await or cancel all running tasks.
"""
def __init__(self, executor: Optional[Executor] = None) -> None:
"""Initialize a `SafeAsyncTaskRunner` instance with
an optional executor.
Args:
executor (Optional[Executor]): A `concurrent.futures.Executor`
object to be used for running synchronous functions in
a non-blocking manner. If None, a default executor is used.
For more information, see the asyncio event loop's
`run_in_executor` documentation. Defaults to None.
"""
self.tasks: set[asyncio.Task] = set()
"""A set that stores all running or pending tasks
associated with calls to the `run_async` method.
"""
self.futures: set[asyncio.Future] = set()
"""A set that stores all running or pending futures
associated with calls to the `run_sync` method.
"""
self.executor = executor
"""An optional `concurrent.futures.Executor` object to be used
for running synchronous functions in the `run_sync` method.
"""
@staticmethod
def _log_task_or_future_done(
task_or_future: Union[asyncio.Task, asyncio.Future],
type_: Literal["Task", "Future"],
name: str
) -> None:
"""Log the result of a completed asyncio Task or Future.
Logs:
- A debug message if it completed successfully or
if a raised exception was caught or cancelled.
- An warning with details if an exception was raised.
Args:
task_or_future: The Task or Future object to log.
type_: Specifies whether the object is a "Task" or a "Future".
name: The name of the Task or Future for identification in logs.
"""
if task_or_future.cancelled():
logger.debug('%s "%s" was cancelled', type_, name)
elif task_or_future.exception() is None:
logger.debug('%s "%s" finished', type_, name)
else:
exc = task_or_future.exception()
exc_type, exc_desc = type(exc).__name__, str(exc)
if exc_desc:
logger.warning(
'%s "%s" raised an exception (%s: %s)',
type_, name, exc_type, exc_desc
)
else:
logger.warning(
'%s "%s" raised an exception (%s)',
type_, name, exc_type
)
@staticmethod
def _log_caught_exception(
exc: Exception,
type_: Literal["Task", "Future"],
name: str,
) -> None:
"""Log a warning for an exception caught from a Task or Future
in a safe runner.
Logs:
- A warning with the exception type and description, if available.
Args:
exc: The exception instance that was caught.
type_: Indicates whether the exception came from
a "Task" or a "Future".
name: The name of the Task or Future for identification in logs.
"""
exc_type, exc_desc = type(exc).__name__, str(exc)
if exc_desc:
logger.warning(
'%s "%s" failed (%s: %s)',
type_, name, exc_type, exc_desc
)
else:
logger.warning(
'%s "%s" failed (%s)',
type_, name, exc_type
)
@staticmethod
def _get_awaitable_name(awaitable: Awaitable) -> str:
"""Retrieve a human-readable name of an awaitable object.
Args:
awaitable (Awaitable): The awaitable object for which to
retrieve the name.
Returns:
str: A string representing the name of the awaitable.
Defaults to the `repr` of the awaitable if
a specific name cannot be determined.
"""
try:
if isinstance(awaitable, asyncio.Task):
return awaitable.get_name()
elif inspect.iscoroutine(awaitable):
return awaitable.cr_code.co_name
return repr(awaitable)
except Exception: # in case "inspect.iscoroutine" fails
return repr(awaitable)
async def _run_async_safe(self, awaitable: Awaitable[T]) -> Optional[T]:
"""Wrap an awaitable in a coroutine with added error handling that
catches and logs exceptions. Note that the `asyncio.CancelledError`
exceptions are re-raised to propagate cancellation.
Args:
awaitable (Awaitable[T]): The awaitable object to run.
Returns:
Optional[T]: The result of the awaitable if it completes\
successfully, `None` if an exception occurs.
"""
try:
return await awaitable
except asyncio.CancelledError:
raise # re-raise to ensure cancellation is propagated
except Exception as exc:
awaitable_name = self._get_awaitable_name(awaitable)
self._log_caught_exception(exc, "Task", awaitable_name)
return None
def run_async(
self,
func: Callable[..., Coroutine[Any, Any, T]],
*func_args: Any,
run_safe: bool = True,
done_callback: Optional[Callable[[asyncio.Task], Any]] = None
) -> asyncio.Task[Optional[T]]:
"""Schedule a coroutine function to be run,
optionally with error handling and a completion callback.
This method is specifically designed for running coroutine functions
that are asynchronous in nature. Providing a synchronous (non-async)
function to this method will fail at runtime.
Args:
func (Callable[..., Coroutine[Any, Any, T]]): The coroutine
function to be run.
*func_args (Any): The arguments to run the coroutine function with.
run_safe (bool): If True, the potential exceptions raised by
the coroutine are caught and logged, and the result is set to
None in case of an error. If False, exceptions are propagated
and must be handled by the caller. Defaults to True.
done_callback (Optional[Callable[[asyncio.Task], Any]]):
An optional callback to be called when the task completes.
Returns:
asyncio.Task[Optional[T]]: An asyncio task object representing
the scheduled coroutine. The task can be awaited to obtain
the result of the coroutine function call or cancelled.
"""
name = self._get_func_name(func)
logger.debug('Scheduling task "%s"', name)
if run_safe:
task = asyncio.create_task(self._run_async_safe(func(*func_args)))
else:
task = asyncio.create_task(func(*func_args))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
task.add_done_callback(
lambda t: self._log_task_or_future_done(t, "Task", name)
)
if done_callback:
task.add_done_callback(done_callback)
return task
@staticmethod
def _get_func_name(func: Callable[..., T]) -> str:
"""Retrieve a human-readable name of a function.
Args:
func (Callable[..., T]): The function for which
to retrieve the name.
Returns:
str: The name of the function.
"""
try:
return func.__name__
except AttributeError:
return repr(func)
def _run_sync_safe(self, func: Callable[..., T], *args) -> Optional[T]:
"""Call a synchronous function with added error handling that
catches and logs exceptions. Note that the `asyncio.CancelledError`
exceptions are re-raised to propagate cancellation.
Args:
func (Callable[..., T]): The synchronous function to run.
*args: Arguments to run the function with.
Returns:
Optional[T]: The result of the function if it completes\
successfully, `None` if an exception occurs.
"""
try:
return func(*args)
except asyncio.CancelledError:
raise # re-raise to ensure cancellation is propagated
except Exception as exc:
name = self._get_func_name(func)
self._log_caught_exception(exc, "Future", name)
return None
def run_sync(
self,
func: Callable[..., T],
*func_args,
run_safe: bool = True,
done_callback: Optional[Callable[[asyncio.Future], Any]] = None
) -> asyncio.Future[Optional[T]]:
"""Schedule a synchronous function to be run in a non-blocking manner,
optionally with error handling and a completion callback.
This method is specifically designed for synchronous functions that
block. It utilizes an executor, allowing for concurrent execution of
such functions without blocking the asyncio event loop. Providing
an asynchronous (async) function will fail at runtime.
Args:
func (Callable[..., T]): The synchronous function to be run.
*func_args (Any): The arguments to run the function with.
run_safe (bool): If True, the potential exceptions raised by
the synchronous function are caught and logged, and the result
is set to None in case of an error. If False, exceptions
are propagated and must be handled by the caller.
Defaults to True.
done_callback (Optional[Callable[[asyncio.Future], Any]]):
An optional callback to be called when the future completes.
Returns:
asyncio.Future[Optional[T]]: An asyncio future object representing\
the scheduled execution of the synchronous function.\
The future can be awaited to obtain the result of the function\
call. Note that cancellation of the future is not possible if\
the function is already running (for more information, see\
the `concurrent.futures` documentation on cancellation of\
the future objects).
"""
name = self._get_func_name(func)
logger.debug('Scheduling future "%s"', name)
loop = asyncio.get_running_loop()
future: asyncio.Future[Optional[T]]
if run_safe:
future = loop.run_in_executor(
self.executor, self._run_sync_safe, func, *func_args
)
else:
future = loop.run_in_executor(
self.executor, func, *func_args
)
future.add_done_callback(self.futures.discard)
future.add_done_callback(
lambda f: self._log_task_or_future_done(f, "Future", name)
)
if done_callback:
future.add_done_callback(done_callback)
self.futures.add(future)
return future
async def await_all(self) -> None:
"""Wait for all scheduled asynchronous and synchronous tasks
to complete.
This method gathers all active asyncio tasks and futures and awaits for
their completion. It is particularly useful for ensuring that
all background operations have finished before proceeding to another
stage of the application or gracefully shutting down the application.
Note that this method ensures that the tracking sets of tasks and
futures have been cleared after their completion, effectively resetting
the runner's state.
"""
logger.debug('Awaiting all tasks and futures')
tasks = set(self.tasks)
futures = set(self.futures)
await asyncio.gather(*tasks, *futures)
self.tasks.clear()
self.futures.clear()
async def cancel_all(self) -> None:
"""Attempt to cancel all active tasks and futures.
This method gathers and cancells all active asyncio tasks and attempts
to cancel all active futures created by running synchronous functions.
For the latter, the cancellation if only possible if the future has not
started running yet (for more information, see the `concurrent.futures`
documentation on cancellation of the future objects). The method is
particularly useful for cancelling all background operations before
proceeding to another stage of the application or gracefully shutting
down the application.
Note that this method ensures that the tracking set of tasks have been
cleared after their completion. The set of futures is not cleared, as
the cancellation of the futures is not guaranteed to be successful.
"""
logger.debug('Cancelling all tasks and futures')
# let the event loop run to allow for task cancellation
# (helps if "cancel_all" is called right after a task is created)
await asyncio.sleep(0)
tasks = set(self.tasks)
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
self.tasks.clear()
futures = set(self.futures)
for future in futures:
if future.cancel():
self.futures.remove(future)
[docs]
class ProductUpdateEvent(enum.Enum):
"""An enumeration of different types of product updates.
This enum is used to represent various kinds of changes or updates that can
occur to a product on the product page. Each member represents a specific
type of update event for easy identification and handling.
"""
PRODUCT_ADDED = "product_added"
"""Indicates that a new product has been listed on the product page."""
PRODUCT_UPDATED = "product_updated"
"""Indicates any update to a product's information and state."""
QUANTITY_UPDATED = "quantity_updated"
"""Indicates that the stock quantity of a product has been updated."""
PRICE_UPDATED = "price_updated"
"""Indicates that the pricing of a product has changed, including both
full price and current sale price."""
OTHER_UPDATED = "other_updated"
"""Indicates an update to the product's metadata,
for example, a change of the illustration picture URL.
"""
PRODUCT_REMOVED = "product_removed"
"""Indicates that a product has been removed from the product page."""
[docs]
class ProductUpdateContext(collections.abc.Mapping):
"""An update event context data wrapper for product update events with
a mapping-like access to the data along with direct access to specific
product state information via the `product_old` and `product_new` keys and
the event type via the `event` key.
This class is designed to encapsulate the update event context data
passed to event handlers during product update events.
"""
def __init__(self, __kwargs: dict[Any, Any]) -> None:
super().__init__()
self.__kwargs = __kwargs
def __str__(self) -> str:
"""Returns a string representation of the context data."""
return str(self.__kwargs)
def __repr__(self) -> str:
"""Returns a string representation of the context data object
initialization.
"""
return f'{self.__class__.__name__}({self.__kwargs})'
def __getitem__(self, key: Any) -> Any:
"""Returns the value for the given key from the internal context data.
Args:
key (Any): The key to look up in the context data.
Returns:
Any: The value associated with the specified key.
"""
return self.__kwargs[key]
def __iter__(self) -> Iterator[Any]:
"""Returns an iterator over the keys of the internal context data.
Returns:
Iterator[Any]: An iterator over the keys of the context data.
"""
return iter(self.__kwargs)
def __len__(self) -> int:
"""Returns the number of items in the internal context data.
Returns:
int: The number of items in the context data.
"""
return len(self.__kwargs)
def __get_product_attr(self, attr: str, default: T) -> T:
"""Retrieve a specific attribute from the product state data.
Args:
attr (str): The attribute to retrieve from the product state data.
default (T): The default value to return if the attribute
is not found.
Returns:
T: The value of the attribute if found, otherwise the default value.
"""
if attr in self.__kwargs:
return self.__kwargs[attr]
if self.product_new:
return getattr(self.product_new, attr)
elif self.product_old:
return getattr(self.product_old, attr)
return default
@property
def event(self) -> ProductUpdateEvent:
"""The type of product update event that occurred."""
return self.__kwargs['event']
@property
def product_id(self) -> int:
"""ID of the product being updated.
If not available, defaults to 0.
"""
return self.__get_product_attr('id_', 0)
@property
def product_name(self) -> str:
"""Name of the product being updated.
If not available, defaults to an empty string.
"""
return self.__get_product_attr('name', '')
@property
def product_name_lowercase_ascii(self) -> str:
"""Lowercase ASCII representation of the product name.
If not available, defaults to an empty string.
"""
return self.__get_product_attr('name_lowercase_ascii', '')
@property
def location_id(self) -> int:
"""ID of the product location.
If not available, defaults to 0.
"""
return self.__get_product_attr('location_id', 0)
@property
def location_name(self) -> str:
"""Name of the product location.
If not available, defaults to an empty string.
"""
return self.__get_product_attr('location_name', '')
@property
def location_name_lowercase_ascii(self) -> str:
"""Lowercase ASCII representation of the location name.
If not available, defaults to an empty string.
"""
return self.__get_product_attr('location_name_lowercase_ascii', '')
@property
def product_new(self) -> Optional[Product]:
"""New state of the product after the update.
If not available, defaults to None.
"""
return self.__kwargs['product_new']
@property
def product_old(self) -> Optional[Product]:
"""Previous state of the product before the update.
If not available, defaults to None.
"""
return self.__kwargs['product_old']
[docs]
def asdict(self) -> dict[str, Any]:
"""Convert the context data to a dictionary.
Returns:
dict[str, Any]: A dictionary representation of the context data.
"""
return self.__kwargs.copy()
[docs]
class HandlerValidator:
"""A utility class for validating handler functions or methods based on
their signature and whether they are asynchronous or synchronous.
"""
@classmethod
def _is_valid_is_async(cls, handler: Any) -> tuple[bool, bool]:
"""Determine if a given handler is valid based on its signature and
whether it is an asynchronous coroutine or a synchronous function
or method.
Args:
handler: The handler to be validated. This can be any callable
(a function, a method, or an object with a `__call__` method).
Returns:
tuple(bool, bool): A tuple of two boolean values, where\
the first value indicates whether the handler is valid (True)\
or not (False) and the second value indicates whether\
the handler is an asynchronous (True) or synchronous (False).
"""
if not callable(handler):
return False, False
if inspect.iscoroutinefunction(handler):
sig = inspect.signature(handler)
return (len(sig.parameters) == 1), True
elif inspect.isfunction(handler) or inspect.ismethod(handler):
sig = inspect.signature(handler)
return (len(sig.parameters) == 1), False
elif hasattr(handler, '__call__'):
return cls._is_valid_is_async(handler.__call__)
return False, False
[docs]
@classmethod
def is_valid_async_handler(cls, handler: Any) -> bool:
"""Check if a handler is an asynchronous coroutine function with
a valid signature (accepts exactly one argument).
Args:
handler: The handler to be validated.
Returns:
bool: True if the handler is valid and asynchronous,\
otherwise, False.
"""
is_valid, is_async = cls._is_valid_is_async(handler)
return is_valid and is_async
[docs]
@classmethod
def is_valid_sync_handler(cls, handler: Any) -> bool:
"""Check if a handler is a synchronous function or method with a valid
signature (accepts exactly one argument).
Args:
handler: The handler to be validated.
Returns:
bool: True if the handler is valid and synchronous,\
otherwise, False.
"""
is_valid, is_async = cls._is_valid_is_async(handler)
return is_valid and not is_async
[docs]
@classmethod
def is_valid_handler(cls, handler: Any) -> bool:
"""Check if a handler has a valid signature and is a valid function or
method, regardless of whether it is synchronous or asynchronous.
Args:
handler: The handler to be validated.
Returns:
bool: True if the handler is valid, otherwise, False.
"""
is_valid, _ = cls._is_valid_is_async(handler)
return is_valid
[docs]
def is_valid_handler(handler: Any) -> bool:
"""Validate a handler based on its signature and whether it is an
asynchronous coroutine or a synchronous function or method.
Args:
handler (Any): The handler to be validated.
Returns:
bool: True if the handler is valid, otherwise, False.
"""
return HandlerValidator.is_valid_handler(handler)
SyncHandler: TypeAlias = Callable[
[ProductUpdateContext], None
]
"""Type alias for a synchronous handler function. Expects a single argument
of type `ProductUpdateContext` and returns `None`.
"""
AsyncHandler: TypeAlias = Callable[
[ProductUpdateContext], Coroutine[Any, Any, None]
]
"""Type alias for an asynchronous handler function. Expects a single argument
of type `ProductUpdateContext` and returns a coroutine that resolves to `None`.
"""
Handler: TypeAlias = Union[SyncHandler, AsyncHandler]
"""Type alias for a handler function or method. Can be either a synchronous
callable that accepts a single argument of type `ProductUpdateContext` and
returns `None`, or an asynchronous callable that accepts a single argument
of type `ProductUpdateContext` and returns a coroutine that resolves to `None`.
"""
class HandlerExecParamsTuple(NamedTuple):
"""A named tuple for storing handler execution parameters."""
call_safe: bool
"""A flag indicating whether the handler should be executed in a "safe"
manner, meaning any exceptions raised by the handler will be caught and
handled.
"""
done_callback: Optional[Callable[[asyncio.Future], Any]]
"""An optional callback function to be called when the handler completes
its execution. Depending of the type of the handler, the callback
receives an `asyncio.Task` or `asyncio.Future` object as its argument,
which represents the eventual result of the handler's execution.
"""
HandlerType = TypeVar('HandlerType', SyncHandler, AsyncHandler)
"""A `TypeVar` variable for a handler function or method."""
class HandlerDataTuple(NamedTuple, Generic[HandlerType]):
"""A named tuple for storing handler data."""
handler: HandlerType
"""The handler function to be executed after a certain event occurs."""
exec_params: HandlerExecParamsTuple
"""The hanlder function execution parameters."""
SyncHandlerList: TypeAlias = list[HandlerDataTuple[SyncHandler]]
"""A list of `HandlerDataTuple` records of synchronous handlers."""
AsyncHandlerList: TypeAlias = list[HandlerDataTuple[AsyncHandler]]
"""A list of `HandlerDataTuple` records of asynchronous handlers."""
[docs]
class ProductUpdateEventPublisher:
"""A publisher and a subscription manager for product update events.
This class maintains a registry of event handlers for different types of
product update events. It allows for subscribing or unsubscribing handlers
to specific events and facilitates broadcasting events to all relevant
handlers.
"""
def __init__(self) -> None:
self.context: dict[Any, Any] = {} # global persistent context
self.runner = SafeAsyncTaskRunner()
self.sync_subscribers: dict[ProductUpdateEvent, SyncHandlerList] = {}
self.async_subscribers: dict[ProductUpdateEvent, AsyncHandlerList] = {}
@staticmethod
def _is_in_subscribers_list(
handler: Handler, subscribers: list[HandlerDataTuple]
) -> bool:
"""Check if a handler is already present in the list of subscribers.
This method iterates over the subscribers list to check if the provided
handler matches any handler in the list based on object identity.
Args:
handler: The handler function or method to check.
subscribers: A list of `HandlerDataTuple` objects representing
the current subscribers.
Returns:
bool: True if the handler is found in the list of subscribers,\
otherwise, False.
"""
return any(existing.handler == handler for existing in subscribers)
@staticmethod
def _remove_from_subscribers_list(
handler: Handler, subscribers: list[HandlerDataTuple]
) -> None:
"""
Remove a handler from the list of subscribers.
This method iterates over the subscribers list to find and remove
the provided handler based on object identity. If the handler
is not found in the list, no action is taken.
Args:
handler: The handler function or method to remove.
subscribers: A list of `HandlerDataTuple` objects representing
the current subscribers.
"""
for i, existing in enumerate(subscribers):
if existing.handler == handler:
del subscribers[i]
break
def _subscribe_sync(
self,
event: ProductUpdateEvent,
handler: Handler,
call_safe: bool = True,
handler_done_callback: Optional[Callable[[asyncio.Future], Any]] = None
) -> None:
"""Subscribe a synchronous handler to a specific product update event.
This internal method adds a synchronous handler to the event's
subscribers list, ensuring that the handler and its
execution parameters are stored correctly.
Args:
event: The product update event to subscribe the handler to.
handler: The synchronous handler to be subscribed.
call_safe: Indicates whether the handler should be called in
a "safe" manner, with exceptions caught and handled.
handler_done_callback: An optional callback function to be called
when the handler completes its execution.
"""
subscribers = self.sync_subscribers.setdefault(event, [])
if not self._is_in_subscribers_list(handler, subscribers):
params = HandlerExecParamsTuple(call_safe, handler_done_callback)
handler = cast(SyncHandler, handler)
subscribers.append(HandlerDataTuple(handler, params))
def _subscribe_async(
self,
event: ProductUpdateEvent,
handler: Handler,
call_safe: bool = True,
handler_done_callback: Optional[Callable[[asyncio.Future], Any]] = None
) -> None:
"""Subscribe an asynchronous handler to a specific product update
event.
This internal method adds an asynchronous handler to the event's
subscribers list, ensuring that the handler and its
execution parameters are stored correctly.
Args:
event: The product update event to subscribe the handler to.
handler: The asynchronous handler to be subscribed.
call_safe: Indicates whether the handler should be called in
a "safe" manner, with exceptions caught and handled.
handler_done_callback: An optional callback function to be called
when the handler completes its execution.
"""
subscribers = self.async_subscribers.setdefault(event, [])
if not self._is_in_subscribers_list(handler, subscribers):
params = HandlerExecParamsTuple(call_safe, handler_done_callback)
handler = cast(AsyncHandler, handler)
subscribers.append(HandlerDataTuple(handler, params))
@staticmethod
def _validate_event(
event: Union[ProductUpdateEvent, Iterable[ProductUpdateEvent], None]
) -> list[ProductUpdateEvent]:
"""Validate the provided event(s) and return a list of event objects.
Args:
event (Union[ProductUpdateEvent, Iterable[ProductUpdateEvent],\
None]): The event or events to validate.
Raises:
TypeError: If the event is not a valid type.
Returns:
list[ProductUpdateEvent]: A list of validated event objects.
"""
if event is None:
return list(ProductUpdateEvent)
if isinstance(event, ProductUpdateEvent):
return [event]
if isinstance(event, Iterable) and not isinstance(event, str):
return [ProductUpdateEvent(e) for e in event]
raise TypeError("Event must be a ProductUpdateEvent or an iterable.")
[docs]
def subscribe(
self,
handler: Handler,
event: Union[
ProductUpdateEvent, Iterable[ProductUpdateEvent], None
] = None,
call_safe: bool = True,
handler_done_callback: Optional[Callable[[asyncio.Future], Any]] = None
) -> None:
"""Subscribe a handler to specific product update event(s). The handler
will be invoked when the event is posted, with the event context
passed as an argument.
The handler can be an asynchronous function, method, or any callable
object that accepts exactly one argument (a `ProductUpdateContext`
object) and returns `None` or a coroutine that resolves to `None`.
Args:
handler (Handler): The function or callable to invoke for
the event(s).
event (Union[ProductUpdateEvent, Iterable[ProductUpdateEvent],\
None], optional): The type of product update event(s) to
subscribe to. If None, the handler will be subscribed to
all events.
call_safe (bool, optional): If True, exceptions raised by
the handler are caught and logged. If False, exceptions are
propagated and must be handled by the caller. Defaults to True.
handler_done_callback (Optional[Callable[[asyncio.Future], Any]]):
Optional function to be called when the handler completes
execution. Depending on the type of the handler, the callback
receives an `asyncio.Task` or `asyncio.Future` object as its
argument, which represents the return value of the callback
execution. Defaults to None.
Raises:
TypeError: If the handler does not have a valid signature.
"""
events = self._validate_event(event)
is_valid, is_async = HandlerValidator._is_valid_is_async(handler)
if not is_valid:
raise TypeError("Handler signature is invalid.")
callback = handler_done_callback
if is_async:
for e in events:
self._subscribe_async(e, handler, call_safe, callback)
else:
for e in events:
self._subscribe_sync(e, handler, call_safe, callback)
[docs]
def unsubscribe(
self,
handler: Optional[Handler] = None,
event: Union[
ProductUpdateEvent, Iterable[ProductUpdateEvent], None
] = None
) -> None:
"""Unsubscribe a handler from specific product update event(s),
or all handlers if no specific handler is provided. The unsubscribed
handler will no longer be invoked when the event is posted.
Args:
handler (Handler): The handler to be unsubscribed from the
event(s). if None, all handlers for the event are unsubscribed.
event (Union[ProductUpdateEvent, Iterable[ProductUpdateEvent],\
None], optional): The type of product update event(s)
to unsubscribe from. If None, the handler(s) will be subscribed
from all events.
"""
events = self._validate_event(event)
for e in events:
async_subs = self.async_subscribers.pop(e, None)
sync_subs = self.sync_subscribers.pop(e, None)
if handler is None:
continue
if async_subs:
if self._is_in_subscribers_list(handler, async_subs):
self._remove_from_subscribers_list(handler, async_subs)
if sync_subs:
if self._is_in_subscribers_list(handler, sync_subs):
self._remove_from_subscribers_list(handler, sync_subs)
[docs]
def is_subscribed(
self,
handler: Optional[Handler] = None,
event: Union[
ProductUpdateEvent, Iterable[ProductUpdateEvent], None
] = None
) -> bool:
"""Check if there are any subscribers for the given event(s).
Args:
handler (Optional[Handler], optional): The handler to check for
subscription. If None, all handlers are checked.
event (Union[ProductUpdateEvent, Iterable[ProductUpdateEvent],\
None], optional): The type of product update event(s) to check for
subscribers. If None, all events are checked.
Returns:
bool: True if there are subscribers for the event, False otherwise.
"""
events = self._validate_event(event)
for e in events:
if e in self.async_subscribers or e in self.sync_subscribers:
if handler is None:
return True
async_subs = self.async_subscribers.get(e, [])
sync_subs = self.sync_subscribers.get(e, [])
if self._is_in_subscribers_list(handler, async_subs):
return True
if self._is_in_subscribers_list(handler, sync_subs):
return True
return False
[docs]
def get_context(
self,
event: ProductUpdateEvent,
product_new: Optional[Product],
product_old: Optional[Product],
**kwargs: Any
) -> ProductUpdateContext:
"""Construct a `ProductUpdateContext` instance from the provided
product and additional parameters, merging the global context data with
the specific product information and any additional keyword arguments
(the latter takes precedence).
Args:
product_new (Optional[Product]): The updated product instance,
or None if not applicable.
product_old (Optional[Product]): The original product instance
before the update, or None if not applicable.
**kwargs (Any): Additional context parameters to be included.
Returns:
ProductUpdateContext: A context object populated with the global\
context data and the provided product information and
parameters.
"""
context_kwargs = {**self.context, **kwargs}
context_kwargs['event'] = event
context_kwargs['product_new'] = product_new
context_kwargs['product_old'] = product_old
return ProductUpdateContext(context_kwargs)
[docs]
def post(
self,
event: ProductUpdateEvent,
product_new: Optional[Product],
product_old: Optional[Product],
**kwargs: Any
) -> None:
"""Notify registered handlers about a specific product update event.
This method constructs a `ProductUpdateContext` with new and old
product details, along with any additional keyword arguments. It then
broadcasts the event to all relevant handlers, asynchronous and
synchronous alike, registered for the event type.
Args:
event (ProductUpdateEvent): The type of product update event
to be broadcasted.
product_new (Optional[Product]): The updated product instance,
or None if there's no new product.
product_old (Optional[Product]): The original product instance
before update, or None if unavailable.
**kwargs (Any): Additional keyword arguments to be included
in the event context. If any of the keyword arguments
overlaps with the context data, it takes precedence.
"""
event = ProductUpdateEvent(event)
context: Optional[ProductUpdateContext] = None
if event in self.async_subscribers:
context = self.get_context(
event, product_new, product_old, **kwargs
)
for async_sub in self.async_subscribers[event]:
self.runner.run_async(
async_sub.handler,
context,
run_safe=async_sub.exec_params.call_safe,
done_callback=async_sub.exec_params.done_callback
)
if event in self.sync_subscribers:
if context is None:
context = self.get_context(
event, product_new, product_old, **kwargs
)
for sync_sub in self.sync_subscribers[event]:
self.runner.run_sync(
sync_sub.handler,
context,
run_safe=sync_sub.exec_params.call_safe,
done_callback=sync_sub.exec_params.done_callback
)
[docs]
class ProductCacheUpdater:
"""Product cache manager. Updates the cache and notifies subscribers
about product updates.
This class is responsible for maintaining a cache of product data and
using an event publisher to notify interested parties of changes to
product information. It provides methods to create, delete, and update
products in the cache and to trigger appropriate events for each action.
"""
def __init__(
self,
products: dict[int, Product],
publisher: ProductUpdateEventPublisher
) -> None:
"""Initialize a `ProductCacheUpdater` instance.
Args:
products (dict[int, Product]): A dictionary of products, where
the keys are product IDs and the values are `Product` objects.
publisher (ProductUpdateEventPublisher): An instance of
the `ProductUpdateEventPublisher` class to call subscribed
handlers when product updates occur.
"""
self.products = products
self._publisher = publisher
[docs]
def create_product(self, product: Product, **kwargs: Any) -> None:
"""Add a new product to the cache and post a `PRODUCT_ADDED` event.
Note that this method assumes that the product is not present in the
cache already.
Args:
product (Product): the product to add to the cache.
**kwargs (Any): Additional keyword arguments to include
in the update event context.
"""
self.products[product.id_] = product
self._publisher.post(
event=ProductUpdateEvent.PRODUCT_ADDED,
product_new=product,
product_old=None,
**kwargs
)
[docs]
def delete_product(self, product: Product, **kwargs: Any) -> None:
"""Remove a product from the cache and posts a `PRODUCT_REMOVED` event.
Note that this method assumes that the product is present in the cache.
Args:
product (Product): the product to remove from the cache.
**kwargs (Any): Additional keyword arguments to include
in the update event context.
"""
del self.products[product.id_]
self._publisher.post(
event=ProductUpdateEvent.PRODUCT_REMOVED,
product_new=None,
product_old=product,
**kwargs
)
[docs]
def update_product(
self,
product: Product,
product_cached: Product,
diff: dict[str, DiffValueTuple],
**kwargs: Any
) -> None:
"""Update a product in the cache and post events reflecting
the specific changes occured.
This method determines what aspects of the product have changed
(e.g., stock level, price, etc.) and posts the appropriate events
for each type of update.
Note that this method assumes that the product is different from
the cached product.
Args:
product (Product): The new product state
after the update.
product_cached (Product): The original product state
before the update.
**kwargs (Any): Additional keyword arguments to include
in the update event context.
"""
self.products[product.id_] = product
events: set[ProductUpdateEvent] = {ProductUpdateEvent.PRODUCT_UPDATED}
if 'quantity' in diff:
events.add(ProductUpdateEvent.QUANTITY_UPDATED)
del diff['quantity']
if 'price_full' in diff:
events.add(ProductUpdateEvent.PRICE_UPDATED)
del diff['price_full']
if 'price_curr' in diff:
events.add(ProductUpdateEvent.PRICE_UPDATED)
del diff['price_curr']
if diff: # some other changes are present in the diff
events.add(ProductUpdateEvent.OTHER_UPDATED)
for event in events:
self._publisher.post(
event=event,
product_new=product,
product_old=product_cached,
**kwargs
)
[docs]
async def update(
self,
products: Iterable[Product],
await_handlers: bool = False,
**kwargs: Any
) -> None:
"""Processe a batch of product updates, adding, deleting, or updating
products as necessary.
This method iterates through a given iterable of products, updating
the cache to reflect new, removed, or changed products. It optionally
waits for all event handlers to complete if `await_handlers` is True.
Args:
products (Iterable[Product]): products to process.
await_handlers: If True, waits for all triggered event handlers
to complete.
**kwargs (Any): Additional keyword arguments to include
in the update event context.
"""
products_mapping = {p.id_: p for p in products}
# check if any product is not listed anymore
for id_number, product in self.products.copy().items():
if id_number not in products_mapping:
self.delete_product(product, **kwargs)
# update the products (or create if not previously listed)
for id_number, product in products_mapping.items():
product_cached = self.products.get(id_number)
if product_cached is None:
self.create_product(product, **kwargs)
else:
diff = product_cached.diff(product, exclude={'timestamp'})
if diff:
self.update_product(
product, product_cached, diff, **kwargs
)
# optionally await the triggered handlers
if await_handlers:
await self._publisher.runner.await_all()
[docs]
def update_silently(self, products: Iterable[Product]) -> None:
"""Update the product cache without triggering any events.
This method clears the cache dictionary and subsequently populates
it with the given products, effectively replacing the entire cache
contents while maintaining the dictionary object itself.
Args:
products (Iterable[Product]): products to populate the cache with.
"""
self.products.clear()
self.products.update({p.id_: p for p in products})