Merged
4 changes: 3 additions & 1 deletion src/aleph/api_entrypoint.py
@@ -42,7 +42,9 @@ async def configure_aiohttp_app(
     session_factory = make_session_factory(engine)
 
     node_cache = NodeCache(
-        redis_host=config.redis.host.value, redis_port=config.redis.port.value
+        redis_host=config.redis.host.value,
+        redis_port=config.redis.port.value,
+        message_count_cache_ttl=config.perf.message_count_cache_ttl.value,
     )
     # TODO: find a way to close the node cache when exiting the API process, not closing it causes
     # a warning.
4 changes: 3 additions & 1 deletion src/aleph/commands.py
@@ -66,7 +66,9 @@ def run_db_migrations(config: Config):
 
 async def init_node_cache(config: Config) -> NodeCache:
     node_cache = NodeCache(
-        redis_host=config.redis.host.value, redis_port=config.redis.port.value
+        redis_host=config.redis.host.value,
+        redis_port=config.redis.port.value,
+        message_count_cache_ttl=config.perf.message_count_cache_ttl.value,
     )
     return node_cache
 
4 changes: 4 additions & 0 deletions src/aleph/config.py
@@ -248,6 +248,10 @@ def get_defaults():
             # Sentry trace sample rate.
             "traces_sample_rate": None,
         },
+        "perf": {
+            # TTL of the cache in front of DB count queries on the messages table.
+            "message_count_cache_ttl": 300,
+        },
     }
 
 
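The new `perf.message_count_cache_ttl` default (300 s) bounds how stale a cached message count may get before the DB is queried again. A hedged sketch of how an operator could override it, assuming the `configmanager.Config` setup pyaleph uses elsewhere (the loading code is an assumption; only the key path and default come from this diff):

```python
from configmanager import Config

from aleph.config import get_defaults

# Build a config from the defaults above, then override the new knob.
config = Config(schema=get_defaults())
config.perf.message_count_cache_ttl.value = 600  # serve cached counts for 10 min

assert config.perf.message_count_cache_ttl.value == 600
```

A longer TTL means fewer slow COUNT queries on the messages table, at the cost of staler totals in paginated API responses.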
4 changes: 3 additions & 1 deletion src/aleph/jobs/fetch_pending_messages.py
@@ -171,7 +171,9 @@ async def fetch_messages_task(config: Config):
 
     async with (
         NodeCache(
-            redis_host=config.redis.host.value, redis_port=config.redis.port.value
+            redis_host=config.redis.host.value,
+            redis_port=config.redis.port.value,
+            message_count_cache_ttl=config.perf.message_count_cache_ttl.value,
         ) as node_cache,
         IpfsService.new(config) as ipfs_service,
     ):
4 changes: 3 additions & 1 deletion src/aleph/jobs/process_pending_messages.py
@@ -159,7 +159,9 @@ async def fetch_and_process_messages_task(config: Config):
 
     async with (
         NodeCache(
-            redis_host=config.redis.host.value, redis_port=config.redis.port.value
+            redis_host=config.redis.host.value,
+            redis_port=config.redis.port.value,
+            message_count_cache_ttl=config.perf.message_count_cache_ttl.value,
         ) as node_cache,
         IpfsService.new(config) as ipfs_service,
     ):
4 changes: 3 additions & 1 deletion src/aleph/jobs/process_pending_txs.py
@@ -133,7 +133,9 @@ async def handle_txs_task(config: Config):
 
     async with (
         NodeCache(
-            redis_host=config.redis.host.value, redis_port=config.redis.port.value
+            redis_host=config.redis.host.value,
+            redis_port=config.redis.port.value,
+            message_count_cache_ttl=config.perf.message_count_cache_ttl.value,
         ) as node_cache,
         IpfsService.new(config) as ipfs_service,
     ):
2 changes: 1 addition & 1 deletion src/aleph/schemas/api/accounts.py
@@ -5,9 +5,9 @@
 from aleph_message.models import Chain
 from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, field_validator
 
+from aleph.schemas.messages_query_params import DEFAULT_PAGE, LIST_FIELD_SEPARATOR
 from aleph.types.files import FileType
 from aleph.types.sort_order import SortOrder
-from aleph.web.controllers.utils import DEFAULT_PAGE, LIST_FIELD_SEPARATOR
 
 
 class GetAccountQueryParams(BaseModel):
208 changes: 208 additions & 0 deletions src/aleph/schemas/messages_query_params.py
@@ -0,0 +1,208 @@
from typing import List, Optional

from aleph_message.models import Chain, ItemHash, MessageType
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from aleph.types.message_status import MessageStatus
from aleph.types.sort_order import SortBy, SortOrder

DEFAULT_WS_HISTORY = 10
DEFAULT_MESSAGES_PER_PAGE = 20
DEFAULT_PAGE = 1
LIST_FIELD_SEPARATOR = ","


class BaseMessageQueryParams(BaseModel):
sort_by: SortBy = Field(
default=SortBy.TIME,
alias="sortBy",
description="Key to use to sort the messages. "
"'time' uses the message time field. "
"'tx-time' uses the first on-chain confirmation time.",
)
sort_order: SortOrder = Field(
default=SortOrder.DESCENDING,
alias="sortOrder",
description="Order in which messages should be listed: "
"-1 means most recent messages first, 1 means older messages first.",
)
message_type: Optional[MessageType] = Field(
default=None,
alias="msgType",
description="Message type. Deprecated: use msgTypes instead",
)
message_types: Optional[List[MessageType]] = Field(
default=None, alias="msgTypes", description="Accepted message types."
)
message_statuses: Optional[List[MessageStatus]] = Field(
default=[MessageStatus.PROCESSED, MessageStatus.REMOVING],
alias="msgStatuses",
description="Accepted values for the 'status' field.",
)
addresses: Optional[List[str]] = Field(
default=None, description="Accepted values for the 'sender' field."
)
refs: Optional[List[str]] = Field(
default=None, description="Accepted values for the 'content.ref' field."
)
content_hashes: Optional[List[ItemHash]] = Field(
default=None,
alias="contentHashes",
description="Accepted values for the 'content.item_hash' field.",
)
content_keys: Optional[List[ItemHash]] = Field(
default=None,
alias="contentKeys",
description="Accepted values for the 'content.keys' field.",
)
content_types: Optional[List[str]] = Field(
default=None,
alias="contentTypes",
description="Accepted values for the 'content.type' field.",
)
chains: Optional[List[Chain]] = Field(
default=None, description="Accepted values for the 'chain' field."
)
channels: Optional[List[str]] = Field(
default=None, description="Accepted values for the 'channel' field."
)
tags: Optional[List[str]] = Field(
default=None, description="Accepted values for the 'content.content.tag' field."
)
hashes: Optional[List[ItemHash]] = Field(
default=None, description="Accepted values for the 'item_hash' field."
)

start_date: float = Field(
default=0,
ge=0,
alias="startDate",
description="Start date timestamp. If specified, only messages with "
"a time field greater or equal to this value will be returned.",
)
end_date: float = Field(
default=0,
ge=0,
alias="endDate",
description="End date timestamp. If specified, only messages with "
"a time field lower than this value will be returned.",
)

start_block: int = Field(
default=0,
ge=0,
alias="startBlock",
description="Start block number. If specified, only messages with "
"a block number greater or equal to this value will be returned.",
)
end_block: int = Field(
default=0,
ge=0,
alias="endBlock",
description="End block number. If specified, only messages with "
"a block number lower than this value will be returned.",
)

@model_validator(mode="after")
def validate_field_dependencies(self):
start_date = self.start_date
end_date = self.end_date
if start_date and end_date and (end_date < start_date):
raise ValueError("end date cannot be lower than start date.")
start_block = self.start_block
end_block = self.end_block
if start_block and end_block and (end_block < start_block):
raise ValueError("end block cannot be lower than start block.")

return self

@field_validator(
"hashes",
"addresses",
"refs",
"content_hashes",
"content_keys",
"content_types",
"chains",
"channels",
"message_types",
"message_statuses",
"tags",
mode="before",
)
def split_str(cls, v):
if isinstance(v, str):
return v.split(LIST_FIELD_SEPARATOR)
return v

model_config = ConfigDict(populate_by_name=True)


class MessageQueryParams(BaseMessageQueryParams):
pagination: int = Field(
default=DEFAULT_MESSAGES_PER_PAGE,
ge=0,
description="Maximum number of messages to return. Specifying 0 removes this limit.",
)
page: int = Field(
default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
)


class WsMessageQueryParams(BaseMessageQueryParams):
history: Optional[int] = Field(
DEFAULT_WS_HISTORY,
ge=0,
lt=200,
description="Historical elements to send through the websocket.",
)


class MessageHashesQueryParams(BaseModel):
status: Optional[MessageStatus] = Field(
default=None,
description="Message status.",
)
page: int = Field(
default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
)
pagination: int = Field(
default=DEFAULT_MESSAGES_PER_PAGE,
ge=0,
description="Maximum number of messages to return. Specifying 0 removes this limit.",
)
start_date: float = Field(
default=0,
ge=0,
alias="startDate",
description="Start date timestamp. If specified, only messages with "
"a time field greater or equal to this value will be returned.",
)
end_date: float = Field(
default=0,
ge=0,
alias="endDate",
description="End date timestamp. If specified, only messages with "
"a time field lower than this value will be returned.",
)
sort_order: SortOrder = Field(
default=SortOrder.DESCENDING,
alias="sortOrder",
description="Order in which messages should be listed: "
"-1 means most recent messages first, 1 means older messages first.",
)
hash_only: bool = Field(
default=True,
description="By default, only hashes are returned. "
"Set this to false to include metadata alongside the hashes in the response.",
)

@model_validator(mode="after")
def validate_field_dependencies(self):
start_date = self.start_date
end_date = self.end_date
if start_date and end_date and (end_date < start_date):
raise ValueError("end date cannot be lower than start date.")
return self

model_config = ConfigDict(populate_by_name=True)
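Worth noting how these pydantic v2 models behave at the edges: camelCase aliases map onto snake_case fields, and `split_str` (mode="before") turns comma-separated query strings into lists before type coercion. A small hedged sketch (values are illustrative; imports are as declared in this file):

```python
from aleph.schemas.messages_query_params import DEFAULT_PAGE, MessageQueryParams
from aleph.types.sort_order import SortOrder

# Raw query-string values, exactly as an HTTP layer would hand them over.
params = MessageQueryParams.model_validate(
    {"msgTypes": "POST,STORE", "startDate": "1700000000", "pagination": 50}
)

print(params.message_types)  # two MessageType members, split on ","
assert params.page == DEFAULT_PAGE  # defaults apply when a key is absent
assert params.sort_order == SortOrder.DESCENDING

# Inconsistent ranges fail fast in validate_field_dependencies:
# MessageQueryParams.model_validate({"startDate": 10, "endDate": 5})  # ValidationError
```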
37 changes: 33 additions & 4 deletions src/aleph/services/cache/node_cache.py
@@ -1,7 +1,13 @@
-from typing import Any, List, Optional, Set
+from hashlib import sha256
+from typing import Any, Dict, List, Optional, Set
 
 import redis.asyncio as redis_asyncio
 
+import aleph.toolkit.json as aleph_json
+from aleph.db.accessors.messages import count_matching_messages
+from aleph.schemas.messages_query_params import MessageQueryParams
+from aleph.types.db_session import DbSession
+
 CacheKey = Any
 CacheValue = bytes
 
@@ -10,9 +16,10 @@ class NodeCache:
     API_SERVERS_KEY = "api_servers"
     PUBLIC_ADDRESSES_KEY = "public_addresses"
 
-    def __init__(self, redis_host: str, redis_port: int):
+    def __init__(self, redis_host: str, redis_port: int, message_count_cache_ttl):
         self.redis_host = redis_host
         self.redis_port = redis_port
+        self.message_cache_count_ttl = message_count_cache_ttl
 
         self._redis_client: Optional[redis_asyncio.Redis] = None
 
@@ -52,8 +59,8 @@ async def reset(self):
     async def get(self, key: CacheKey) -> Optional[CacheValue]:
         return await self.redis_client.get(key)
 
-    async def set(self, key: CacheKey, value: Any):
-        await self.redis_client.set(key, value)
+    async def set(self, key: CacheKey, value: Any, expiration: Optional[int] = None):
+        await self.redis_client.set(key, value, ex=expiration)
 
     async def incr(self, key: CacheKey):
         await self.redis_client.incr(key)
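The `set` signature change is backward compatible: `expiration=None` preserves the previous non-expiring writes, while an integer becomes a Redis `EX` TTL. A brief sketch (to run inside any coroutine with an entered cache; key and values are illustrative):

```python
async def demo(node_cache) -> None:
    # TTL'd write: Redis drops the key after 300 seconds (SET ... EX 300).
    await node_cache.set("message_count:deadbeef", 42, expiration=300)
    # Old-style write: no expiration, exactly as before this change.
    await node_cache.set("status", "ok")
```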
@@ -82,3 +89,25 @@ async def add_public_address(self, public_address: str) -> None:
     async def get_public_addresses(self) -> List[str]:
         addresses = await self.redis_client.smembers(self.PUBLIC_ADDRESSES_KEY)
         return [addr.decode() for addr in addresses]
+
+    @staticmethod
+    def _message_filter_id(filters: Dict[str, Any]):
+        filters_json = aleph_json.dumps(filters, sort_keys=True)
+        return sha256(filters_json).hexdigest()
+
+    async def count_messages(
+        self, session: DbSession, query_params: MessageQueryParams
+    ) -> int:
+        filters = query_params.model_dump(exclude_none=True)
+        cache_key = f"message_count:{self._message_filter_id(filters)}"
+
+        cached_result = await self.get(cache_key)
+        if cached_result is not None:
+            return int(cached_result.decode())
+
+        # Slow, can take a few seconds
+        n_matches = count_matching_messages(session, **filters)
+
+        await self.set(cache_key, n_matches, expiration=self.message_cache_count_ttl)
+
+        return n_matches
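`count_messages` is the heart of the change: identical filter sets hash to a single Redis key, so only the first request in each TTL window pays for the slow COUNT query. A minimal usage sketch, assuming an open `DbSession` and an entered `NodeCache` as in the jobs above (the wrapper function is illustrative, not part of the diff):

```python
from aleph.schemas.messages_query_params import MessageQueryParams


async def cached_message_count(session, node_cache) -> int:
    # No filters set: counts all messages matching the schema defaults.
    query_params = MessageQueryParams(pagination=0)
    # First call within the TTL window runs count_matching_messages (slow);
    # subsequent calls with the same filters return the Redis-cached integer.
    return await node_cache.count_messages(session, query_params)
```

Because the key is derived from `model_dump(exclude_none=True)`, two requests only share a cache entry if they share every filter value, paging fields included.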
10 changes: 6 additions & 4 deletions src/aleph/toolkit/json.py
@@ -18,7 +18,6 @@
 # serializer changes easier.
 SerializedJsonInput = Union[bytes, str]
 
-
 # Note: JSONDecodeError is a subclass of ValueError, but the JSON module sometimes throws
 # raw value errors, including on NaN because of our custom parse_constant.
 DecodeError = orjson.JSONDecodeError
@@ -55,8 +54,11 @@ def extended_json_encoder(obj: Any) -> Any:
     raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
 
 
-def dumps(obj: Any) -> bytes:
+def dumps(obj: Any, sort_keys: bool = True) -> bytes:
     try:
-        return orjson.dumps(obj)
+        opts = orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS if sort_keys else 0
+        return orjson.dumps(obj, option=opts)
     except TypeError:
-        return json.dumps(obj, default=extended_json_encoder).encode()
+        return json.dumps(
+            obj, default=extended_json_encoder, sort_keys=sort_keys
+        ).encode()
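Sorted keys are what make the cache-key hashing above deterministic: two dicts with the same contents must serialize to identical bytes before SHA-256 sees them. A quick sketch of the guarantee (nothing here beyond the functions in this diff):

```python
from hashlib import sha256

import aleph.toolkit.json as aleph_json

# Same content, different insertion order.
a = aleph_json.dumps({"channels": ["TEST"], "addresses": ["0xabc"]})
b = aleph_json.dumps({"addresses": ["0xabc"], "channels": ["TEST"]})

assert a == b  # OPT_SORT_KEYS normalizes the byte output
assert sha256(a).hexdigest() == sha256(b).hexdigest()
```

Note the default flipped to `sort_keys=True`, so every existing `dumps` call site now emits sorted keys too; the stdlib fallback mirrors this with `json.dumps(..., sort_keys=sort_keys)`.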
3 changes: 1 addition & 2 deletions src/aleph/web/controllers/aggregates.py
@@ -8,8 +8,7 @@
 
 from aleph.db.accessors.aggregates import get_aggregates_by_owner, refresh_aggregate
 from aleph.db.models import AggregateDb
-
-from .utils import LIST_FIELD_SEPARATOR
+from aleph.schemas.messages_query_params import LIST_FIELD_SEPARATOR
 
 LOGGER = logging.getLogger(__name__)
 