From 186b9c1e254113a7bf20cf0e48ca33b03a63ba5a Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 02:49:44 +0000 Subject: [PATCH 01/11] added message stat functions to user and guild classes --- killua/utils/classes/guild.py | 52 +++++++++++++++++++++++++++++++++++ killua/utils/classes/user.py | 18 +++++++++++- 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/killua/utils/classes/guild.py b/killua/utils/classes/guild.py index 599df6708..c2a389f4f 100644 --- a/killua/utils/classes/guild.py +++ b/killua/utils/classes/guild.py @@ -136,3 +136,55 @@ async def update_poll_votes(self, id: int, updated: dict) -> None: """Updates the votes of a poll""" self.polls[str(id)]["votes"] = updated await self._update_val(f"polls.{id}.votes", updated) + + async def get_top_senders(self, limit: int = 10) -> List[tuple[int, int]]: + """Fetch the top message senders in the guild. Returns a list of (user_id, message_count) tuples""" + guild_id = str(self.id) + pipeline = [ + {"$match": {f"message_stats.{guild_id}": {"$exists": True}}}, + { + "$project": { + "id": 1, + "message_count": f"$message_stats.{guild_id}", + } + }, + {"$sort": {"message_count": -1}}, + {"$limit": limit}, + ] + + cursor = await DB.teams.aggregate(pipeline) + result = await cursor.to_list(length=limit) + return [(doc["id"], doc["message_count"]) for doc in result] + + async def get_total_messages(self) -> int: + """Gets the total number of messages sent in this guild""" + guild_id = str(self.id) + pipeline = [ + {"$match": {f"message_stats.{guild_id}": {"$exists": True}}}, + { + "$group": { + "_id": None, + "total_messages": {"$sum": f"$message_stats.{guild_id}"}, + } + }, + ] + + cursor = await DB.teams.aggregate(pipeline) + result = await cursor.to_list(length=1) + if result: + return result[0]["total_messages"] + return 0 + + async def get_user_rank(self, user_id: int) -> Optional[int]: + """Gets the rank of a user in this guild based on message count""" + guild_id = str(self.id) + user = await User.new(user_id) + user_message_count = user.message_stats.get(guild_id, 0) + + if user_message_count == 0: + return None + + rank = await DB.teams.count_documents( + {f"message_stats.{guild_id}": {"$gt": user_message_count}} + ) + 1 + return rank + 1 \ No newline at end of file diff --git a/killua/utils/classes/user.py b/killua/utils/classes/user.py index 4decbb1f9..da9268491 100644 --- a/killua/utils/classes/user.py +++ b/killua/utils/classes/user.py @@ -2,7 +2,9 @@ from datetime import datetime, timedelta from typing import Any, ClassVar, Dict, List, Optional, Union, cast, Literal, Tuple -from dataclasses import dataclass +from dataclasses import dataclass\ + +import logging from killua.static.constants import ( DB, @@ -43,6 +45,7 @@ class User: email: Optional[str] email_notifications: Dict[Literal["news", "updates", "posts"], bool] cache: ClassVar[Dict[int, User]] = {} + message_stats: Dict[int, int] # guild_id, message_count async def set_email(self, email: str) -> None: """Sets the user's email address""" @@ -108,6 +111,7 @@ async def new(cls, user_id: int): "email_notifications", {"news": False, "updates": False, "posts": False}, ), + message_stats=data.get("message_stats", {}), ) cls.cache[user_id] = instance @@ -253,6 +257,7 @@ async def add_empty(cls, user_id: int, cards: bool = True) -> None: "updates": False, "posts": False, }, + "message_stats": {}, } ) @@ -842,3 +847,14 @@ async def register_login(self) -> bool: self.achievements.append("logged_into_website") await self._update_val("achievements", self.achievements) return True + + async def increment_message_count(self, guild_id: int, amount: int = 1) -> None: + """Increments the message count for this user in a specific guild""" + guild_id_str = str(guild_id) + current_count = self.message_stats.get(guild_id, 0) + self.message_stats[guild_id] = current_count + amount + await self._update_val(f"message_stats.{guild_id_str}", self.message_stats[guild_id]) + + async def get_message_count(self, guild_id: int) -> int: + """Gets the message count for this user in a specific guild""" + return self.message_stats.get(guild_id, 0) \ No newline at end of file From 19b35afb869ecd97844bd34045af54eb489c4a11 Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 02:49:57 +0000 Subject: [PATCH 02/11] added migration to set message count to 0 --- killua/migrate.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/killua/migrate.py b/killua/migrate.py index 981b7f13b..b09f53114 100644 --- a/killua/migrate.py +++ b/killua/migrate.py @@ -56,6 +56,13 @@ async def migrate(): logging.info("Migrated user achievements key to achievements successfully") + await DB.teams.update_many( + {"message_stats": {"$exists": False}}, + {"$set": {"message_stats": {}}} + ) + + logging.info("Added message_stats field to all users successfully") + await DB.const.update_one( {"_id": "migrate"}, {"$set": {"value": True}}, From a14c75d849d1c1cc24317e610f03ce6aa0cf86ff Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 02:50:17 +0000 Subject: [PATCH 03/11] added event listener to update message count on message send --- killua/cogs/events.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/killua/cogs/events.py b/killua/cogs/events.py index a7f23f153..855c10814 100644 --- a/killua/cogs/events.py +++ b/killua/cogs/events.py @@ -1140,5 +1140,25 @@ async def on_command_error(self, ctx: commands.Context, error): pass # This theoretically should be covered by all the cases above, # but handling it again here can't hurt + @commands.Cog.listener() + async def on_message(self,message: discord.Message): + # ignore bot messages + if message.author.bot: + return + + # do not track DMs + if not message.guild: + return + + # ignore system messages + if message.type != discord.MessageType.default: + return + + try: + user = await User.new(message.author.id) + await user.increment_message_count(message.guild.id) + logging.info(f"Incremented message count for user {message.author.id} in guild {message.guild.id}") + except Exception as e: + logging.error(f"Failed to increment message count for user {message.author.id} in guild {message.guild.id}: {e}") Cog = Events From 3dad3df4b7d468af65a833c77acc8177714b6415 Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 02:50:39 +0000 Subject: [PATCH 04/11] added message cog with stats command --- killua/cogs/message.py | 144 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 killua/cogs/message.py diff --git a/killua/cogs/message.py b/killua/cogs/message.py new file mode 100644 index 000000000..10ed54675 --- /dev/null +++ b/killua/cogs/message.py @@ -0,0 +1,144 @@ +import discord +from discord import app_commands +from discord.ext import commands +from typing import Optional + +from killua.bot import BaseBot +from killua.utils.classes.guild import Guild +from killua.utils.classes.user import User + +TRACKING_SINCE = "2025-12-21" # The date message tracking was added MUST BE UPDATED IF DEPLOYED + +class Message(commands.GroupCog, group_name="message"): + """Cog to handle stats commands""" + + def __init__(self, client: BaseBot): + self.client = client + + @app_commands.command(name="stats", description="Get message stats in this guild") + @app_commands.describe( + user="View stats for a specific user", + limit="Number of top users to show (default 10)", + ) + @app_commands.guild_only() + @app_commands.checks.cooldown(1, 5.0) + async def stats( + self, + interaction: discord.Interaction, + user: Optional[discord.Member] = None, + limit: Optional[int] = 10 + ): + """View message stats for a user or top users in the guild.""" + await interaction.response.defer() + + guild = await Guild.new(interaction.guild.id) + + if user: + await self._show_user_stats(interaction, guild, user) + else: + await self._show_leaderboard(interaction, guild, limit) + + + async def _show_user_stats( + self, + interaction: discord.Interaction, + guild: Guild, + member: discord.Member + ): + """Display stats for a specific user""" + user = await User.new(member.id) + message_count = await user.get_message_count(guild.id) + rank = await guild.get_user_rank(member.id) + total_messages = await guild.get_total_messages() + + embed = discord.Embed( + title="📊 Message Stats", + description=f"Stats for {member.mention} in **{interaction.guild.name}**", + color=discord.Color.blue() + ) + + if message_count == 0: + embed.add_field( + name="No Messages", + value=f"{member.mention} has not sent any messages in this guild.", + inline=False + ) + else: + percentage = (message_count / total_messages) * 100 if total_messages > 0 else 0 + + embed.add_field(name="Messages Sent", value=f"{message_count:,}", inline=True) + embed.add_field(name="Rank", value=f"#{rank}", inline=True) + embed.add_field(name="Percentage of Total Messages", value=f"{percentage:.2f}%", inline=True) + + embed.set_footer(text=f"Tracking since {TRACKING_SINCE} • Requested by {interaction.user.display_name}", icon_url=interaction.user.display_avatar.url) + await interaction.followup.send(embed=embed) + + async def _show_leaderboard( + self, + interaction: discord.Interaction, + guild: Guild, + limit: int + ): + """Display the message leaderboard for the guild""" + + if limit < 1 or limit > 25: + await interaction.followup.send("❌ Limit must be between 1 and 25.", ephemeral=True) + return + + top_senders = await guild.get_top_senders(limit) + total_messages = await guild.get_total_messages() + + embed = discord.Embed( + title="📊 Message Leaderboard", + description=f"Top {limit} message senders in **{interaction.guild.name}**", + color=discord.Color.blue() + ) + + if not top_senders: + embed.add_field( + name="No Data", + value="No message data available for this guild.", + inline=False + ) + else: + leaderboard = "" + for rank, (user_id, message_count) in enumerate(top_senders, start=1): + member = interaction.guild.get_member(user_id) + if member is None: + try: + member = await interaction.guild.fetch_member(user_id) + except (discord.NotFound, discord.HTTPException): + pass + member_name = member.display_name if member else f"User ID {user_id}" + percentage = (message_count / total_messages) * 100 if total_messages > 0 else 0 + + medal = f"#{rank}" + if rank == 1: + medal = "🥇 " + elif rank == 2: + medal = "🥈 " + elif rank == 3: + medal = "🥉 " + + leaderboard += f"**{medal}**- {member_name}: {message_count:,} messages ({percentage:.2f}%)\n" + + embed.add_field(name="Leaderboard", value=leaderboard, inline=False) + + embed.set_footer(text="Tracking since {TRACKING_SINCE} • Requested by " + interaction.user.display_name, icon_url=interaction.user.display_avatar.url) + await interaction.followup.send(embed=embed) + + @stats.error + async def stats_error( + self, + interaction: discord.Interaction, + error: app_commands.AppCommandError + ): + """Error handler for the stats command""" + if isinstance(error, app_commands.CommandOnCooldown): + await interaction.response.send_message( + f"⏳ This command is on cooldown. Please try again in {error.retry_after:.1f} seconds.", + ephemeral=True + ) + else: + raise error +Cog = Message \ No newline at end of file From a3469c34ceedd27bf6a1cc37e654513286799840 Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 02:51:04 +0000 Subject: [PATCH 05/11] added TestingMessage class and corresponding tests + updated tests list in __init__.py --- killua/tests/groups/__init__.py | 3 +- killua/tests/groups/message.py | 131 ++++++++++++++++++++++++++++++++ killua/tests/testing.py | 18 +++-- 3 files changed, 145 insertions(+), 7 deletions(-) create mode 100644 killua/tests/groups/message.py diff --git a/killua/tests/groups/__init__.py b/killua/tests/groups/__init__.py index 800932907..bede26b06 100644 --- a/killua/tests/groups/__init__.py +++ b/killua/tests/groups/__init__.py @@ -1,7 +1,8 @@ from .actions import TestingActions from .cards import TestingCards from .dev import TestingDev +from .message import TestingMessage -tests = [TestingActions, TestingCards, TestingDev] +tests = [TestingActions, TestingCards, TestingDev, TestingMessage] __all__ = ["tests"] diff --git a/killua/tests/groups/message.py b/killua/tests/groups/message.py new file mode 100644 index 000000000..a2089d0cd --- /dev/null +++ b/killua/tests/groups/message.py @@ -0,0 +1,131 @@ +from unittest.mock import AsyncMock, MagicMock, patch +from ..types import * +from ...utils.classes import * +from ..testing import Testing, test +from ...cogs.message import Message + +class TestingMessage(Testing): + def __init__(self): + super().__init__(cog=Message) + +class Stats(TestingMessage): + def __init__(self): + super().__init__() + + @test + async def leaderboard_renders(self) -> None: + # mock member + member = MagicMock() + member.id = 42 + member.mention = "<@42>" + member.display_name = "TestUser" + + # mock interaction + interaction = MagicMock() + interaction.guild.id = 123 + interaction.guild.name = "Test Guild" + interaction.guild.get_member = MagicMock(return_value=member) + interaction.guild.fetch_member = AsyncMock(side_effect=Exception()) + interaction.followup.send = AsyncMock() + interaction.response.defer = AsyncMock() + interaction.user.display_name = "Tester" + + # mock guild + mock_guild = MagicMock() + mock_guild.id = 123 + mock_guild.get_top_senders = AsyncMock(return_value=[(1, 100), (2, 80), (3, 60)]) + mock_guild.get_total_messages = AsyncMock(return_value=240) + + with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): + await self.cog.stats.callback(self.cog, interaction, user=None, limit=3) + + interaction.followup.send.assert_awaited() + sent_embed = interaction.followup.send.call_args.kwargs.get("embed") + assert sent_embed is not None, "Expected an embed to be sent" + + @test + async def user_with_no_messages(self) -> None: + # mock interaction + interaction = MagicMock() + interaction.guild.id = 123 + interaction.user.display_name = "Tester" + interaction.followup.send = AsyncMock() + interaction.response.defer = AsyncMock() + + # mock member with no messages + member = MagicMock() + member.id = 42 + member.mention = "<@42>" + member.display_name = "TestUser" + member.get_message_count = AsyncMock(return_value=0) + + # mock guild + mock_guild = MagicMock() + mock_guild.id = 123 + mock_guild.get_user_rank = AsyncMock(return_value=None) + mock_guild.get_total_messages = AsyncMock(return_value=0) + + with patch("killua.utils.classes.user.User.new", AsyncMock(return_value=member)), \ + patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): + await self.cog.stats.callback(self.cog, interaction, user=member, limit=10) + + interaction.followup.send.assert_awaited() + sent_embed = interaction.followup.send.call_args.kwargs.get("embed") + assert sent_embed is not None and "No Messages" in sent_embed.fields[0].name, "Expected 'No Messages' field in embed" + + @test + async def user_with_messages(self) -> None: + # mock interaction + interaction = MagicMock() + interaction.guild.id = 123 + interaction.user.display_name = "Tester" + interaction.followup.send = AsyncMock() + interaction.response.defer = AsyncMock() + + # mock member with messages + member = MagicMock() + member.id = 42 + member.mention = "<@42>" + member.display_name = "TestUser" + member.get_message_count = AsyncMock(return_value=50) + + # mock guild + mock_guild = MagicMock() + mock_guild.id = 123 + mock_guild.get_user_rank = AsyncMock(return_value=5) + mock_guild.get_total_messages = AsyncMock(return_value=200) + + with patch("killua.utils.classes.user.User.new", AsyncMock(return_value=member)), \ + patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): + await self.cog.stats.callback(self.cog, interaction, user=member, limit=10) + + interaction.followup.send.assert_awaited() + sent_embed = interaction.followup.send.call_args.kwargs.get("embed") + assert sent_embed is not None, "Expected an embed to be sent" + field_names = [field.name for field in sent_embed.fields] + assert "Messages Sent" in field_names, "Expected 'Messages Sent' field in embed" + assert "Rank" in field_names, "Expected 'Rank' field in embed" + assert "Percentage of Total Messages" in field_names, "Expected 'Percentage of Total Messages' field in embed" + + @test + async def leaderboard_no_messages(self) -> None: + # mock interaction + interaction = MagicMock() + interaction.guild.id = 123 + interaction.user.display_name = "Tester" + interaction.followup.send = AsyncMock() + interaction.response.defer = AsyncMock() + + # mock guild with no messages + mock_guild = MagicMock() + mock_guild.id = 123 + mock_guild.get_top_senders = AsyncMock(return_value=[]) + mock_guild.get_total_messages = AsyncMock(return_value=0) + + with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): + await self.cog.stats.callback(self.cog, interaction, user=None, limit=10) + + interaction.followup.send.assert_awaited() + sent_embed = interaction.followup.send.call_args.kwargs.get("embed") + assert sent_embed is not None, "Expected an embed to be sent" + assert "No message data available" in sent_embed.fields[0].value, "Expected 'No message data available' message in embed" \ No newline at end of file diff --git a/killua/tests/testing.py b/killua/tests/testing.py index 09bd75d2e..3a8f9ee7d 100644 --- a/killua/tests/testing.py +++ b/killua/tests/testing.py @@ -50,12 +50,18 @@ def __init__(self, cog: Cog): def all_tests(self) -> List[Testing]: """Automatically checks what functions are test based on their name and the overlap with the Cog method names""" cog_methods = [] - for cmd in [(command.name, command) for command in self.cog.get_commands()]: - if hasattr(cmd[1], "walk_commands") and cmd[1].walk_commands(): - for child in cmd[1].walk_commands(): - cog_methods.append((child.name, child)) - else: - cog_methods.append(cmd) + + # handle app commands and normal commands + if hasattr(self.cog, '__cog_app_commands__'): + for cmd in self.cog.__cog_app_commands__: + cog_methods.append((cmd.name, cmd)) + else: + for cmd in [(command.name, command) for command in self.cog.get_commands()]: + if hasattr(cmd[1], "walk_commands") and cmd[1].walk_commands(): + for child in cmd[1].walk_commands(): + cog_methods.append((child.name, child)) + else: + cog_methods.append(cmd) command_classes: List[Testing] = [] From 964e4fd844f1e426f1de37634b085d312f806d08 Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 03:10:52 +0000 Subject: [PATCH 06/11] removed log statement for message incrementation --- killua/cogs/events.py | 1 - 1 file changed, 1 deletion(-) diff --git a/killua/cogs/events.py b/killua/cogs/events.py index 855c10814..cdd9dcc1a 100644 --- a/killua/cogs/events.py +++ b/killua/cogs/events.py @@ -1157,7 +1157,6 @@ async def on_message(self,message: discord.Message): try: user = await User.new(message.author.id) await user.increment_message_count(message.guild.id) - logging.info(f"Incremented message count for user {message.author.id} in guild {message.guild.id}") except Exception as e: logging.error(f"Failed to increment message count for user {message.author.id} in guild {message.guild.id}: {e}") From 24e5215804309d267200f89311216ca78ca72ee7 Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 17:52:21 +0000 Subject: [PATCH 07/11] moved message count from user to guilds + default to opt-out of mesage tracking (at both user and guild level) --- killua/cogs/events.py | 9 ++++- killua/cogs/message.py | 3 +- killua/migrate.py | 6 +-- killua/tests/groups/message.py | 14 +++---- killua/utils/classes/guild.py | 74 +++++++++++++++------------------- killua/utils/classes/user.py | 21 ++++------ 6 files changed, 59 insertions(+), 68 deletions(-) diff --git a/killua/cogs/events.py b/killua/cogs/events.py index cdd9dcc1a..1e3b93f72 100644 --- a/killua/cogs/events.py +++ b/killua/cogs/events.py @@ -1155,8 +1155,15 @@ async def on_message(self,message: discord.Message): return try: + guild = await Guild.new(message.guild.id) + if not guild.message_tracking_enabled: + return # guild has opted out + user = await User.new(message.author.id) - await user.increment_message_count(message.guild.id) + if not user.message_tracking_enabled: + return # user has opted out + + await guild.increment_message_count(message.author.id) except Exception as e: logging.error(f"Failed to increment message count for user {message.author.id} in guild {message.guild.id}: {e}") diff --git a/killua/cogs/message.py b/killua/cogs/message.py index 10ed54675..487b481cc 100644 --- a/killua/cogs/message.py +++ b/killua/cogs/message.py @@ -46,8 +46,7 @@ async def _show_user_stats( member: discord.Member ): """Display stats for a specific user""" - user = await User.new(member.id) - message_count = await user.get_message_count(guild.id) + message_count = guild.get_message_count(member.id) rank = await guild.get_user_rank(member.id) total_messages = await guild.get_total_messages() diff --git a/killua/migrate.py b/killua/migrate.py index b09f53114..cbfea679d 100644 --- a/killua/migrate.py +++ b/killua/migrate.py @@ -56,12 +56,12 @@ async def migrate(): logging.info("Migrated user achievements key to achievements successfully") - await DB.teams.update_many( + # Add message_stats field to all guilds + result = await DB.guilds.update_many( {"message_stats": {"$exists": False}}, {"$set": {"message_stats": {}}} ) - - logging.info("Added message_stats field to all users successfully") + logging.info(f"Added message_stats field to {result.modified_count} guilds") await DB.const.update_one( {"_id": "migrate"}, diff --git a/killua/tests/groups/message.py b/killua/tests/groups/message.py index a2089d0cd..98dc48fb6 100644 --- a/killua/tests/groups/message.py +++ b/killua/tests/groups/message.py @@ -57,16 +57,15 @@ async def user_with_no_messages(self) -> None: member.id = 42 member.mention = "<@42>" member.display_name = "TestUser" - member.get_message_count = AsyncMock(return_value=0) - # mock guild + # mock guild with no messages for this user mock_guild = MagicMock() mock_guild.id = 123 + mock_guild.get_message_count = MagicMock(return_value=0) mock_guild.get_user_rank = AsyncMock(return_value=None) mock_guild.get_total_messages = AsyncMock(return_value=0) - with patch("killua.utils.classes.user.User.new", AsyncMock(return_value=member)), \ - patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): + with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): await self.cog.stats.callback(self.cog, interaction, user=member, limit=10) interaction.followup.send.assert_awaited() @@ -87,16 +86,15 @@ async def user_with_messages(self) -> None: member.id = 42 member.mention = "<@42>" member.display_name = "TestUser" - member.get_message_count = AsyncMock(return_value=50) - # mock guild + # mock guild with messages for this user mock_guild = MagicMock() mock_guild.id = 123 + mock_guild.get_message_count = MagicMock(return_value=50) mock_guild.get_user_rank = AsyncMock(return_value=5) mock_guild.get_total_messages = AsyncMock(return_value=200) - with patch("killua.utils.classes.user.User.new", AsyncMock(return_value=member)), \ - patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): + with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): await self.cog.stats.callback(self.cog, interaction, user=member, limit=10) interaction.followup.send.assert_awaited() diff --git a/killua/utils/classes/guild.py b/killua/utils/classes/guild.py index c2a389f4f..fe18a05fe 100644 --- a/killua/utils/classes/guild.py +++ b/killua/utils/classes/guild.py @@ -22,6 +22,8 @@ class Guild: polls: dict = field(default_factory=dict) tags: List[dict] = field(default_factory=list) added_on: Optional[datetime] = None + message_stats: Dict[int, int] = field(default_factory=dict) # user_id: message_count + message_tracking_enabled: bool = False cache: ClassVar[Dict[int, Guild]] = {} @classmethod @@ -67,6 +69,11 @@ async def new(cls, guild_id: int, member_count: Optional[int] = None) -> Guild: raw["approximate_member_count"] = await cls._member_count_helper( guild_id, raw.get("approximate_member_count", None), member_count ) + # Convert message_stats string keys to int keys for in-memory use + message_stats_raw = raw.get("message_stats", {}) + raw["message_stats"] = {int(k): v for k, v in message_stats_raw.items()} if message_stats_raw else {} + raw["message_tracking_enabled"] = raw.get("message_tracking_enabled", False) + guild = cls.from_dict(raw) cls.cache[guild_id] = guild @@ -80,7 +87,7 @@ def is_premium(self) -> bool: async def add_default(cls, guild_id: int, member_count: Optional[int]) -> None: """Adds a guild to the database""" await DB.guilds.insert_one( - {"id": guild_id, "points": 0, "items": "", "badges": [], "prefix": "k!", "approximate_member_count": member_count or 0, "added_on": datetime.now()} + {"id": guild_id, "points": 0, "items": "", "badges": [], "prefix": "k!", "approximate_member_count": member_count or 0, "added_on": datetime.now(), "message_stats": {}, "message_tracking_enabled": False} ) @classmethod @@ -137,54 +144,39 @@ async def update_poll_votes(self, id: int, updated: dict) -> None: self.polls[str(id)]["votes"] = updated await self._update_val(f"polls.{id}.votes", updated) + async def increment_message_count(self, user_id: int, amount: int = 1) -> None: + """Increments the message count for a user in this guild""" + user_id_str = str(user_id) + current_count = self.message_stats.get(user_id, 0) + self.message_stats[user_id] = current_count + amount + await self._update_val(f"message_stats.{user_id_str}", amount, "$inc") + + def get_message_count(self, user_id: int) -> int: + """Gets the message count for a user in this guild""" + return self.message_stats.get(user_id, 0) + async def get_top_senders(self, limit: int = 10) -> List[tuple[int, int]]: """Fetch the top message senders in the guild. Returns a list of (user_id, message_count) tuples""" - guild_id = str(self.id) - pipeline = [ - {"$match": {f"message_stats.{guild_id}": {"$exists": True}}}, - { - "$project": { - "id": 1, - "message_count": f"$message_stats.{guild_id}", - } - }, - {"$sort": {"message_count": -1}}, - {"$limit": limit}, - ] - - cursor = await DB.teams.aggregate(pipeline) - result = await cursor.to_list(length=limit) - return [(doc["id"], doc["message_count"]) for doc in result] + # Sort the in-memory message_stats and return top N + sorted_stats = sorted(self.message_stats.items(), key=lambda x: x[1], reverse=True) + return sorted_stats[:limit] async def get_total_messages(self) -> int: """Gets the total number of messages sent in this guild""" - guild_id = str(self.id) - pipeline = [ - {"$match": {f"message_stats.{guild_id}": {"$exists": True}}}, - { - "$group": { - "_id": None, - "total_messages": {"$sum": f"$message_stats.{guild_id}"}, - } - }, - ] - - cursor = await DB.teams.aggregate(pipeline) - result = await cursor.to_list(length=1) - if result: - return result[0]["total_messages"] - return 0 + return sum(self.message_stats.values()) async def get_user_rank(self, user_id: int) -> Optional[int]: """Gets the rank of a user in this guild based on message count""" - guild_id = str(self.id) - user = await User.new(user_id) - user_message_count = user.message_stats.get(guild_id, 0) + user_count = self.message_stats.get(user_id, 0) - if user_message_count == 0: + if user_count == 0: return None - rank = await DB.teams.count_documents( - {f"message_stats.{guild_id}": {"$gt": user_message_count}} - ) + 1 - return rank + 1 \ No newline at end of file + rank = sum(1 for count in self.message_stats.values() if count > user_count) + return rank + 1 + + async def toggle_message_tracking(self) -> bool: + """Toggles message tracking for the guild""" + self.message_tracking_enabled = not self.message_tracking_enabled + await self._update_val("message_tracking_enabled", self.message_tracking_enabled) + return self.message_tracking_enabled \ No newline at end of file diff --git a/killua/utils/classes/user.py b/killua/utils/classes/user.py index da9268491..92c1aaefd 100644 --- a/killua/utils/classes/user.py +++ b/killua/utils/classes/user.py @@ -44,8 +44,8 @@ class User: has_user_installed: bool email: Optional[str] email_notifications: Dict[Literal["news", "updates", "posts"], bool] + message_tracking_enabled: bool = False cache: ClassVar[Dict[int, User]] = {} - message_stats: Dict[int, int] # guild_id, message_count async def set_email(self, email: str) -> None: """Sets the user's email address""" @@ -111,7 +111,7 @@ async def new(cls, user_id: int): "email_notifications", {"news": False, "updates": False, "posts": False}, ), - message_stats=data.get("message_stats", {}), + message_tracking_enabled=data.get("message_tracking_enabled", False), ) cls.cache[user_id] = instance @@ -257,7 +257,7 @@ async def add_empty(cls, user_id: int, cards: bool = True) -> None: "updates": False, "posts": False, }, - "message_stats": {}, + "message_tracking_enabled": False, } ) @@ -848,13 +848,8 @@ async def register_login(self) -> bool: await self._update_val("achievements", self.achievements) return True - async def increment_message_count(self, guild_id: int, amount: int = 1) -> None: - """Increments the message count for this user in a specific guild""" - guild_id_str = str(guild_id) - current_count = self.message_stats.get(guild_id, 0) - self.message_stats[guild_id] = current_count + amount - await self._update_val(f"message_stats.{guild_id_str}", self.message_stats[guild_id]) - - async def get_message_count(self, guild_id: int) -> int: - """Gets the message count for this user in a specific guild""" - return self.message_stats.get(guild_id, 0) \ No newline at end of file + async def toggle_message_tracking(self) -> bool: + """Toggles message tracking for the user""" + self.message_tracking_enabled = not self.message_tracking_enabled + await self._update_val("message_tracking_enabled", self.message_tracking_enabled) + return self.message_tracking_enabled \ No newline at end of file From 6c818ae5c80296fae5edd4fdfd4d0ba00ab4f400 Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 18:18:14 +0000 Subject: [PATCH 08/11] split leaderboard and user stats command, added migration for enabling tracking, updated tests to match --- killua/cogs/message.py | 59 +++++++++++++++++++++++++--- killua/migrate.py | 12 ++++++ killua/tests/groups/message.py | 72 ++++++++++++++++++---------------- 3 files changed, 103 insertions(+), 40 deletions(-) diff --git a/killua/cogs/message.py b/killua/cogs/message.py index 487b481cc..aaeb28ffc 100644 --- a/killua/cogs/message.py +++ b/killua/cogs/message.py @@ -15,18 +15,16 @@ class Message(commands.GroupCog, group_name="message"): def __init__(self, client: BaseBot): self.client = client - @app_commands.command(name="stats", description="Get message stats in this guild") + @app_commands.command(name="stats", description="Get message stats for a user in this guild") @app_commands.describe( user="View stats for a specific user", - limit="Number of top users to show (default 10)", ) @app_commands.guild_only() @app_commands.checks.cooldown(1, 5.0) async def stats( self, interaction: discord.Interaction, - user: Optional[discord.Member] = None, - limit: Optional[int] = 10 + user: discord.Member = None, ): """View message stats for a user or top users in the guild.""" await interaction.response.defer() @@ -36,9 +34,58 @@ async def stats( if user: await self._show_user_stats(interaction, guild, user) else: - await self._show_leaderboard(interaction, guild, limit) + await self._show_user_stats(interaction, guild, interaction.user) + + @app_commands.command(name="leaderboard", description="Show the message leaderboard for this guild") + @app_commands.describe( + limit="Number of top users to display (max 25)", + ) + @app_commands.guild_only() + @app_commands.checks.cooldown(1, 10.0) + async def leaderboard( + self, + interaction: discord.Interaction, + limit: int = 10, + ): + """Display the message leaderboard for this guild.""" + await interaction.response.defer() + + guild = await Guild.new(interaction.guild.id) + await self._show_leaderboard(interaction, guild, limit) + + @app_commands.command(name="server_tracking", description="Toggle message tracking for this server") + @app_commands.guild_only() + @app_commands.default_permissions(manage_guild=True) + async def server_tracking( + self, + interaction: discord.Interaction, + ): + """Toggle message tracking for this server""" + await interaction.response.defer(ephemeral=True) + + guild = await Guild.new(interaction.guild.id) + new_status = await guild.toggle_message_tracking() + + status_text = "enabled" if new_status else "disabled" + status_emoji = "✅" if new_status else "❌" + await interaction.followup.send(f"{status_emoji} Message tracking has been {status_text} for this server.", ephemeral=True) + + @app_commands.command(name="user_tracking", description="Toggle message tracking for your account") + @app_commands.checks.cooldown(1, 10.0) + async def user_tracking( + self, + interaction: discord.Interaction, + ): + """Toggle message tracking for your account""" + await interaction.response.defer(ephemeral=True) + + user = await User.new(interaction.user.id) + new_status = await user.toggle_message_tracking() + + status_text = "enabled" if new_status else "disabled" + status_emoji = "✅" if new_status else "❌" + await interaction.followup.send(f"{status_emoji} Message tracking has been {status_text} for your account.", ephemeral=True) - async def _show_user_stats( self, interaction: discord.Interaction, diff --git a/killua/migrate.py b/killua/migrate.py index cbfea679d..5fcd11f50 100644 --- a/killua/migrate.py +++ b/killua/migrate.py @@ -63,6 +63,18 @@ async def migrate(): ) logging.info(f"Added message_stats field to {result.modified_count} guilds") + result = await DB.guilds.update_many( + {"message_tracking_enabled": {"$exists": False}}, + {"$set": {"message_tracking_enabled": False}} + ) + logging.info(f"Added message_tracking_enabled field to {result.modified_count} guilds") + + result = await DB.teams.update_many( + {"message_tracking_enabled": {"$exists": False}}, + {"$set": {"message_tracking_enabled": False}} + ) + logging.info(f"Added message_tracking_enabled field to {result.modified_count} users") + await DB.const.update_one( {"_id": "migrate"}, {"$set": {"value": True}}, diff --git a/killua/tests/groups/message.py b/killua/tests/groups/message.py index 98dc48fb6..e72c65c5b 100644 --- a/killua/tests/groups/message.py +++ b/killua/tests/groups/message.py @@ -11,37 +11,6 @@ def __init__(self): class Stats(TestingMessage): def __init__(self): super().__init__() - - @test - async def leaderboard_renders(self) -> None: - # mock member - member = MagicMock() - member.id = 42 - member.mention = "<@42>" - member.display_name = "TestUser" - - # mock interaction - interaction = MagicMock() - interaction.guild.id = 123 - interaction.guild.name = "Test Guild" - interaction.guild.get_member = MagicMock(return_value=member) - interaction.guild.fetch_member = AsyncMock(side_effect=Exception()) - interaction.followup.send = AsyncMock() - interaction.response.defer = AsyncMock() - interaction.user.display_name = "Tester" - - # mock guild - mock_guild = MagicMock() - mock_guild.id = 123 - mock_guild.get_top_senders = AsyncMock(return_value=[(1, 100), (2, 80), (3, 60)]) - mock_guild.get_total_messages = AsyncMock(return_value=240) - - with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): - await self.cog.stats.callback(self.cog, interaction, user=None, limit=3) - - interaction.followup.send.assert_awaited() - sent_embed = interaction.followup.send.call_args.kwargs.get("embed") - assert sent_embed is not None, "Expected an embed to be sent" @test async def user_with_no_messages(self) -> None: @@ -66,7 +35,7 @@ async def user_with_no_messages(self) -> None: mock_guild.get_total_messages = AsyncMock(return_value=0) with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): - await self.cog.stats.callback(self.cog, interaction, user=member, limit=10) + await self.cog.stats.callback(self.cog, interaction, user=member) interaction.followup.send.assert_awaited() sent_embed = interaction.followup.send.call_args.kwargs.get("embed") @@ -95,7 +64,7 @@ async def user_with_messages(self) -> None: mock_guild.get_total_messages = AsyncMock(return_value=200) with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): - await self.cog.stats.callback(self.cog, interaction, user=member, limit=10) + await self.cog.stats.callback(self.cog, interaction, user=member) interaction.followup.send.assert_awaited() sent_embed = interaction.followup.send.call_args.kwargs.get("embed") @@ -105,6 +74,41 @@ async def user_with_messages(self) -> None: assert "Rank" in field_names, "Expected 'Rank' field in embed" assert "Percentage of Total Messages" in field_names, "Expected 'Percentage of Total Messages' field in embed" +class Leaderboard(TestingMessage): + def __init__(self): + super().__init__() + + @test + async def leaderboard_renders(self) -> None: + # mock member + member = MagicMock() + member.id = 42 + member.mention = "<@42>" + member.display_name = "TestUser" + + # mock interaction + interaction = MagicMock() + interaction.guild.id = 123 + interaction.guild.name = "Test Guild" + interaction.guild.get_member = MagicMock(return_value=member) + interaction.guild.fetch_member = AsyncMock(side_effect=Exception()) + interaction.followup.send = AsyncMock() + interaction.response.defer = AsyncMock() + interaction.user.display_name = "Tester" + + # mock guild + mock_guild = MagicMock() + mock_guild.id = 123 + mock_guild.get_top_senders = AsyncMock(return_value=[(1, 100), (2, 80), (3, 60)]) + mock_guild.get_total_messages = AsyncMock(return_value=240) + + with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): + await self.cog.leaderboard.callback(self.cog, interaction, limit=3) + + interaction.followup.send.assert_awaited() + sent_embed = interaction.followup.send.call_args.kwargs.get("embed") + assert sent_embed is not None, "Expected an embed to be sent" + @test async def leaderboard_no_messages(self) -> None: # mock interaction @@ -121,7 +125,7 @@ async def leaderboard_no_messages(self) -> None: mock_guild.get_total_messages = AsyncMock(return_value=0) with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): - await self.cog.stats.callback(self.cog, interaction, user=None, limit=10) + await self.cog.leaderboard.callback(self.cog, interaction, limit=10) interaction.followup.send.assert_awaited() sent_embed = interaction.followup.send.call_args.kwargs.get("embed") From b760db6216ddb45e533f6975056d9c5b6dfd3782 Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 18:36:10 +0000 Subject: [PATCH 09/11] added opt-in messages to commands --- killua/cogs/message.py | 19 +++++++++++++++---- killua/utils/classes/guild.py | 4 ++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/killua/cogs/message.py b/killua/cogs/message.py index aaeb28ffc..dca5f8a9f 100644 --- a/killua/cogs/message.py +++ b/killua/cogs/message.py @@ -30,11 +30,19 @@ async def stats( await interaction.response.defer() guild = await Guild.new(interaction.guild.id) + if not guild.message_tracking_enabled: + await interaction.followup.send("❌ Message tracking is disabled for this server.", ephemeral=True) + return - if user: - await self._show_user_stats(interaction, guild, user) - else: - await self._show_user_stats(interaction, guild, interaction.user) + if not user: + user = interaction.user + + member = await User.new(user.id) + if not member.message_tracking_enabled: + await interaction.followup.send(f"❌ {user.mention} has disabled message tracking for their account.", ephemeral=True) + return + + await self._show_user_stats(interaction, guild, user) @app_commands.command(name="leaderboard", description="Show the message leaderboard for this guild") @app_commands.describe( @@ -51,6 +59,9 @@ async def leaderboard( await interaction.response.defer() guild = await Guild.new(interaction.guild.id) + if not guild.message_tracking_enabled: + await interaction.followup.send("❌ Message tracking is disabled for this server.", ephemeral=True) + return await self._show_leaderboard(interaction, guild, limit) @app_commands.command(name="server_tracking", description="Toggle message tracking for this server") diff --git a/killua/utils/classes/guild.py b/killua/utils/classes/guild.py index fe18a05fe..646371ea6 100644 --- a/killua/utils/classes/guild.py +++ b/killua/utils/classes/guild.py @@ -157,8 +157,8 @@ def get_message_count(self, user_id: int) -> int: async def get_top_senders(self, limit: int = 10) -> List[tuple[int, int]]: """Fetch the top message senders in the guild. Returns a list of (user_id, message_count) tuples""" - # Sort the in-memory message_stats and return top N - sorted_stats = sorted(self.message_stats.items(), key=lambda x: x[1], reverse=True) + tracked_stats = {k: v for k, v in self.message_stats.items() if (await User.new(k)).message_tracking_enabled} + sorted_stats = sorted(tracked_stats.items(), key=lambda x: x[1], reverse=True) return sorted_stats[:limit] async def get_total_messages(self) -> int: From f270d542ab358e9ad9a5edb358a25ad90eb0ecec Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 19:00:54 +0000 Subject: [PATCH 10/11] delete user and guild message data on opt-out + confirmation view for data deletion --- killua/cogs/message.py | 65 ++++++++++++++++++++++++++++++----- killua/utils/classes/guild.py | 12 +++++++ killua/utils/classes/user.py | 21 ++++++++++- killua/utils/views.py | 31 +++++++++++++++++ 4 files changed, 120 insertions(+), 9 deletions(-) create mode 100644 killua/utils/views.py diff --git a/killua/cogs/message.py b/killua/cogs/message.py index dca5f8a9f..4f8b7b3a9 100644 --- a/killua/cogs/message.py +++ b/killua/cogs/message.py @@ -6,6 +6,7 @@ from killua.bot import BaseBot from killua.utils.classes.guild import Guild from killua.utils.classes.user import User +from killua.utils.views import ConfirmView TRACKING_SINCE = "2025-12-21" # The date message tracking was added MUST BE UPDATED IF DEPLOYED @@ -75,11 +76,35 @@ async def server_tracking( await interaction.response.defer(ephemeral=True) guild = await Guild.new(interaction.guild.id) + + if guild.message_tracking_enabled: + embed = discord.Embed( + title="⚠️ Warning", + description="Disabling message tracking will remove all message counts from this server's stats and leaderboards. Are you sure you want to proceed?", + color=discord.Color.orange() + ) + view = ConfirmView(interaction.user.id) + await interaction.followup.send(embed=embed, view=view, ephemeral=True) + await view.wait() + + if not view.value: + return # cancelled + new_status = await guild.toggle_message_tracking() + if new_status: + embed = discord.Embed( + title="✅ Message Tracking Enabled", + description="Message tracking has been enabled for this server. Future messages from users who have enabled message tracking will be counted in stats and leaderboards.", + color=discord.Color.green() + ) + else: + embed = discord.Embed( + title="❌ Message Tracking Disabled", + description="Message tracking has been disabled for this server. All message counts have been removed from stats and leaderboards.", + color=discord.Color.red() + ) - status_text = "enabled" if new_status else "disabled" - status_emoji = "✅" if new_status else "❌" - await interaction.followup.send(f"{status_emoji} Message tracking has been {status_text} for this server.", ephemeral=True) + await interaction.followup.send(embed=embed, ephemeral=True) @app_commands.command(name="user_tracking", description="Toggle message tracking for your account") @app_commands.checks.cooldown(1, 10.0) @@ -88,14 +113,37 @@ async def user_tracking( interaction: discord.Interaction, ): """Toggle message tracking for your account""" - await interaction.response.defer(ephemeral=True) - user = await User.new(interaction.user.id) + + if user.message_tracking_enabled: + embed = discord.Embed( + title="⚠️ Warning", + description="Disabling message tracking will remove your message counts from all guild leaderboards and stats. Are you sure you want to proceed?", + color=discord.Color.orange() + ) + view = ConfirmView(interaction.user.id) + await interaction.response.send_message(embed=embed, view=view, ephemeral=True) + await view.wait() + + if not view.value: + return # cancelled + new_status = await user.toggle_message_tracking() - status_text = "enabled" if new_status else "disabled" - status_emoji = "✅" if new_status else "❌" - await interaction.followup.send(f"{status_emoji} Message tracking has been {status_text} for your account.", ephemeral=True) + if new_status: + embed = discord.Embed( + title="✅ Message Tracking Enabled", + description="You have enabled message tracking for your account. Your future messages will be counted in guild stats and leaderboards.", + color=discord.Color.green() + ) + else: + embed = discord.Embed( + title="❌ Message Tracking Disabled", + description="You have disabled message tracking for your account. Your message counts have been removed from all guild stats and leaderboards.", + color=discord.Color.red() + ) + + await interaction.followup.send(embed=embed, ephemeral=True) async def _show_user_stats( self, @@ -198,4 +246,5 @@ async def stats_error( ) else: raise error + Cog = Message \ No newline at end of file diff --git a/killua/utils/classes/guild.py b/killua/utils/classes/guild.py index 646371ea6..146470462 100644 --- a/killua/utils/classes/guild.py +++ b/killua/utils/classes/guild.py @@ -24,6 +24,7 @@ class Guild: added_on: Optional[datetime] = None message_stats: Dict[int, int] = field(default_factory=dict) # user_id: message_count message_tracking_enabled: bool = False + tracking_since: Optional[datetime] = None cache: ClassVar[Dict[int, Guild]] = {} @classmethod @@ -73,6 +74,7 @@ async def new(cls, guild_id: int, member_count: Optional[int] = None) -> Guild: message_stats_raw = raw.get("message_stats", {}) raw["message_stats"] = {int(k): v for k, v in message_stats_raw.items()} if message_stats_raw else {} raw["message_tracking_enabled"] = raw.get("message_tracking_enabled", False) + raw["tracking_since"] = raw.get("tracking_since", None) guild = cls.from_dict(raw) cls.cache[guild_id] = guild @@ -179,4 +181,14 @@ async def toggle_message_tracking(self) -> bool: """Toggles message tracking for the guild""" self.message_tracking_enabled = not self.message_tracking_enabled await self._update_val("message_tracking_enabled", self.message_tracking_enabled) + if self.message_tracking_enabled: + # Remove all users who have disabled tracking from message_stats + tracked_stats = {k: v for k, v in self.message_stats.items() if (await User.new(k)).message_tracking_enabled} + self.message_stats = tracked_stats + await self._update_val("message_stats", {str(k): v for k, v in tracked_stats.items()}) + self.tracking_since = datetime.now() + await self._update_val("tracking_since", self.tracking_since) + else: + self.message_stats = {} + await self._update_val("message_stats", {}) return self.message_tracking_enabled \ No newline at end of file diff --git a/killua/utils/classes/user.py b/killua/utils/classes/user.py index 92c1aaefd..5f6a15e12 100644 --- a/killua/utils/classes/user.py +++ b/killua/utils/classes/user.py @@ -852,4 +852,23 @@ async def toggle_message_tracking(self) -> bool: """Toggles message tracking for the user""" self.message_tracking_enabled = not self.message_tracking_enabled await self._update_val("message_tracking_enabled", self.message_tracking_enabled) - return self.message_tracking_enabled \ No newline at end of file + + if not self.message_tracking_enabled: + await self._delete_message_tracking_data() + + return self.message_tracking_enabled + + async def _delete_message_tracking_data(self) -> None: + """Deletes all message tracking data for the user""" + from killua.utils.classes.guild import Guild + + result = await DB.guilds.update_many( + {f"message_stats.{self.id}": {"$exists": True}}, + {"$unset": {f"message_stats.{self.id}": ""}}, + ) + + for guild_id, guild in Guild.cache.items(): + if self.id in guild.message_stats: + del guild.message_stats[self.id] + + logging.info(f"Deleted message tracking data for user {self.id} in {result.modified_count} guild(s)") \ No newline at end of file diff --git a/killua/utils/views.py b/killua/utils/views.py new file mode 100644 index 000000000..9cbb8319e --- /dev/null +++ b/killua/utils/views.py @@ -0,0 +1,31 @@ +import discord +from discord import app_commands +from discord.ext import commands +from discord.ui import View, Button +from typing import Optional + +class ConfirmView(View): + """Simple confirmation view with Yes/No buttons""" + + def __init__(self, user_id: int): + super().__init__(timeout=30) + self.user_id = user_id + self.value = None + + @discord.ui.button(label="Confirm", style=discord.ButtonStyle.danger) + async def confirm(self, interaction: discord.Interaction, button: Button): + if interaction.user.id != self.user_id: + await interaction.response.send_message("You cannot interact with this confirmation.", ephemeral=True) + return + self.value = True + self.stop() + await interaction.response.edit_message(content="✅ Confirmed.", view=None) + + @discord.ui.button(label="Cancel", style=discord.ButtonStyle.secondary) + async def cancel(self, interaction: discord.Interaction, button: Button): + if interaction.user.id != self.user_id: + await interaction.response.send_message("You cannot interact with this confirmation.", ephemeral=True) + return + self.value = False + self.stop() + await interaction.response.edit_message(content="❌ Cancelled.", view=None) \ No newline at end of file From 221192f3a65fa384beec041e7a88e001002469fb Mon Sep 17 00:00:00 2001 From: daniithompson Date: Sun, 21 Dec 2025 21:36:55 +0000 Subject: [PATCH 11/11] changed tests to use test_db --- killua/tests/groups/message.py | 257 +++++++++++++++++++----------- killua/tests/testing.py | 6 + killua/tests/types/interaction.py | 21 +++ killua/tests/types/member.py | 1 + killua/utils/test_db.py | 222 +++++++++++++++----------- 5 files changed, 315 insertions(+), 192 deletions(-) diff --git a/killua/tests/groups/message.py b/killua/tests/groups/message.py index e72c65c5b..87be1b31b 100644 --- a/killua/tests/groups/message.py +++ b/killua/tests/groups/message.py @@ -3,6 +3,7 @@ from ...utils.classes import * from ..testing import Testing, test from ...cogs.message import Message +from ...utils.test_db import TestingDatabase class TestingMessage(Testing): def __init__(self): @@ -14,120 +15,182 @@ def __init__(self): @test async def user_with_no_messages(self) -> None: - # mock interaction - interaction = MagicMock() - interaction.guild.id = 123 - interaction.user.display_name = "Tester" - interaction.followup.send = AsyncMock() - interaction.response.defer = AsyncMock() - - # mock member with no messages - member = MagicMock() - member.id = 42 - member.mention = "<@42>" - member.display_name = "TestUser" - - # mock guild with no messages for this user - mock_guild = MagicMock() - mock_guild.id = 123 - mock_guild.get_message_count = MagicMock(return_value=0) - mock_guild.get_user_rank = AsyncMock(return_value=None) - mock_guild.get_total_messages = AsyncMock(return_value=0) - - with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): - await self.cog.stats.callback(self.cog, interaction, user=member) - - interaction.followup.send.assert_awaited() - sent_embed = interaction.followup.send.call_args.kwargs.get("embed") - assert sent_embed is not None and "No Messages" in sent_embed.fields[0].name, "Expected 'No Messages' field in embed" + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() # enable tracking + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + await user.toggle_message_tracking() # enable tracking + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.stats.callback(self.cog, interaction, user=member) + + assert interaction.followup.sent[0]["embeds"][0] is not None, "Expected an embed to be sent" + field_names = [field.name for field in interaction.followup.sent[0]["embeds"][0].fields] + assert "No Messages" in field_names or any("no" in name.lower() for name in field_names), "Expected 'No Messages' field in embed" @test async def user_with_messages(self) -> None: - # mock interaction - interaction = MagicMock() - interaction.guild.id = 123 - interaction.user.display_name = "Tester" - interaction.followup.send = AsyncMock() - interaction.response.defer = AsyncMock() - - # mock member with messages - member = MagicMock() - member.id = 42 - member.mention = "<@42>" - member.display_name = "TestUser" - - # mock guild with messages for this user - mock_guild = MagicMock() - mock_guild.id = 123 - mock_guild.get_message_count = MagicMock(return_value=50) - mock_guild.get_user_rank = AsyncMock(return_value=5) - mock_guild.get_total_messages = AsyncMock(return_value=200) - - with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): - await self.cog.stats.callback(self.cog, interaction, user=member) - - interaction.followup.send.assert_awaited() - sent_embed = interaction.followup.send.call_args.kwargs.get("embed") + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() # enable tracking + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + await user.toggle_message_tracking() # enable tracking + + guild.message_stats[member.id] = 50 + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.stats.callback(self.cog, interaction, user=member) + + sent_embed = interaction.followup.sent[0]["embeds"][0] assert sent_embed is not None, "Expected an embed to be sent" field_names = [field.name for field in sent_embed.fields] assert "Messages Sent" in field_names, "Expected 'Messages Sent' field in embed" assert "Rank" in field_names, "Expected 'Rank' field in embed" assert "Percentage of Total Messages" in field_names, "Expected 'Percentage of Total Messages' field in embed" + @test + async def user_opted_out(self) -> None: + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() # enable tracking + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.stats.callback(self.cog, interaction, user=member) + + + message = interaction.followup.sent[0]["content"] + assert "disabled" in message.lower(), "Expected tracking disabled message" + class Leaderboard(TestingMessage): def __init__(self): super().__init__() @test async def leaderboard_renders(self) -> None: - # mock member - member = MagicMock() - member.id = 42 - member.mention = "<@42>" - member.display_name = "TestUser" - - # mock interaction - interaction = MagicMock() - interaction.guild.id = 123 - interaction.guild.name = "Test Guild" - interaction.guild.get_member = MagicMock(return_value=member) - interaction.guild.fetch_member = AsyncMock(side_effect=Exception()) - interaction.followup.send = AsyncMock() - interaction.response.defer = AsyncMock() - interaction.user.display_name = "Tester" - - # mock guild - mock_guild = MagicMock() - mock_guild.id = 123 - mock_guild.get_top_senders = AsyncMock(return_value=[(1, 100), (2, 80), (3, 60)]) - mock_guild.get_total_messages = AsyncMock(return_value=240) - - with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): - await self.cog.leaderboard.callback(self.cog, interaction, limit=3) - - interaction.followup.send.assert_awaited() - sent_embed = interaction.followup.send.call_args.kwargs.get("embed") + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() # enable tracking + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + await user.toggle_message_tracking() # enable tracking + + guild.message_stats = {1: 100, 2: 80, 3: 60} + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.leaderboard.callback(self.cog, interaction, limit=3) + + sent_embed = interaction.followup.sent[0]["embeds"][0] assert sent_embed is not None, "Expected an embed to be sent" @test async def leaderboard_no_messages(self) -> None: - # mock interaction - interaction = MagicMock() - interaction.guild.id = 123 - interaction.user.display_name = "Tester" - interaction.followup.send = AsyncMock() - interaction.response.defer = AsyncMock() - - # mock guild with no messages - mock_guild = MagicMock() - mock_guild.id = 123 - mock_guild.get_top_senders = AsyncMock(return_value=[]) - mock_guild.get_total_messages = AsyncMock(return_value=0) - - with patch("killua.utils.classes.guild.Guild.new", AsyncMock(return_value=mock_guild)): - await self.cog.leaderboard.callback(self.cog, interaction, limit=10) - - interaction.followup.send.assert_awaited() - sent_embed = interaction.followup.send.call_args.kwargs.get("embed") + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() # enable tracking + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + await user.toggle_message_tracking() # enable tracking + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.leaderboard.callback(self.cog, interaction, limit=3) + + sent_embed = interaction.followup.sent[0]["embeds"][0] assert sent_embed is not None, "Expected an embed to be sent" - assert "No message data available" in sent_embed.fields[0].value, "Expected 'No message data available' message in embed" \ No newline at end of file + assert "No message data available" in sent_embed.fields[0].value, "Expected 'No message data available' message in embed" + + @test + async def leaderboard_guild_opted_out(self) -> None: + guild = await Guild.new(self.base_guild.id) + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + await user.toggle_message_tracking() # enable tracking + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.leaderboard.callback(self.cog, interaction, limit=3) + + message = interaction.followup.sent[0]["content"] + assert "disabled" in message.lower(), "Expected tracking disabled message" +class User_Tracking(TestingMessage): + def __init__(self): + super().__init__() + + @test + async def toggle_tracking_enable(self) -> None: + user = await User.new(self.base_author.id) + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + interaction.user = self.base_author + + await self.cog.user_tracking.callback(self.cog, interaction) + + assert interaction.response.is_done() or len(interaction.followup.sent) > 0, "Expected a response" + + user_after = await User.new(self.base_author.id) + assert user_after.message_tracking_enabled == True, "Expected tracking to be enabled" + + @test + async def toggle_tracking_disable_with_confirmation(self) -> None: + user = await User.new(self.base_author.id) + await user.toggle_message_tracking() # enable tracking + + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() + guild.message_stats[self.base_author.id] = 50 + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + interaction.user = self.base_author + + # mock the confirm view acceptance + with patch("killua.cogs.message.ConfirmView") as MockView: + mock_view_instance = MagicMock() + mock_view_instance.value = True + mock_view_instance.wait = AsyncMock() + MockView.return_value = mock_view_instance + + await self.cog.user_tracking.callback(self.cog, interaction) + + # confirmation view response + assert interaction.response.is_done(), "Expected confirmation prompt" + + # after confirmation, followup is sent + assert len(interaction.followup.sent) > 0, "Expected followup message after confirmation" + + @test + async def toggle_tracking_disable_cancelled(self) -> None: + user = await User.new(self.base_author.id) + await user.toggle_message_tracking() # enable tracking + + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() + guild.message_stats[self.base_author.id] = 50 + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + interaction.user = self.base_author + + # mock the confirm view cancellation + with patch("killua.cogs.message.ConfirmView") as MockView: + mock_view_instance = MagicMock() + mock_view_instance.value = False # User cancels + mock_view_instance.wait = AsyncMock() + MockView.return_value = mock_view_instance + + await self.cog.user_tracking.callback(self.cog, interaction) + + # Should show confirmation view + assert interaction.response.is_done(), "Expected confirmation prompt" + + # tracking should still be enabled after cancellation + user_after = await User.new(self.base_author.id) + assert user_after.message_tracking_enabled == True, "Expected tracking to still be enabled after cancel" + assert guild.get_message_count(self.base_author.id) == 50, "Expected message count to remain unchanged after cancel" \ No newline at end of file diff --git a/killua/tests/testing.py b/killua/tests/testing.py index 3a8f9ee7d..3f058c018 100644 --- a/killua/tests/testing.py +++ b/killua/tests/testing.py @@ -125,6 +125,12 @@ def __init__(self, method): async def __call__(self, obj: Testing, *args, **kwargs): from .types import Result, ResultData + from ..utils.test_db import TestingDatabase + from ..utils.classes import Guild, User + + TestingDatabase.db.clear() + Guild.cache.clear() + User.cache.clear() try: logging.debug( diff --git a/killua/tests/types/interaction.py b/killua/tests/types/interaction.py index d50b7890b..37328b49a 100644 --- a/killua/tests/types/interaction.py +++ b/killua/tests/types/interaction.py @@ -39,6 +39,26 @@ async def send_modal(self, *args, **kwargs) -> None: def is_done(self) -> bool: return self._is_done +class ArgumentFollowup: + def __init__(self, interaction: ArgumentInteraction): + self.interaction = interaction + self.sent = [] # store sent followups + + async def send(self, content = None, /, *, embed=None, embeds=None, ephemeral: bool = False, **kwargs): + # Normalize embeds + if embed is not None: + embeds = [embed] + elif embeds is None: + embeds = [] + + self.sent.append( + { + "content": content, + "embeds": embeds, + "ephemeral": ephemeral, + "kwargs": kwargs, + } + ) class ArgumentInteraction: """This classes purpose is purely to be supplied to callbacks of message interactions""" @@ -48,6 +68,7 @@ def __init__(self, context: Context, **kwargs): self.context = context self.user = context.author self.response = ArgumentResponseInteraction(self) + self.followup = ArgumentFollowup(self) class TestingInteraction(Interaction): diff --git a/killua/tests/types/member.py b/killua/tests/types/member.py index 856f0161f..04c514bd9 100644 --- a/killua/tests/types/member.py +++ b/killua/tests/types/member.py @@ -26,6 +26,7 @@ def __init__(self, **kwargs): "communication_disabled_until", "" ) self.premium_since: Union[datetime, None] = kwargs.pop("premium_since", None) + self.display_avatar = self.avatar @property def display_name(self) -> str: diff --git a/killua/utils/test_db.py b/killua/utils/test_db.py index 952a6ecb1..e354b474a 100644 --- a/killua/utils/test_db.py +++ b/killua/utils/test_db.py @@ -1,5 +1,5 @@ from typing import Optional, List, Dict - +from random import randint class TestingDatabase: """A database class imitating pymongos collection classes""" @@ -15,129 +15,161 @@ def collection(self) -> str: self.db[self._collection] = [] return self._collection - # def _random_id(self) -> int: - # """Creates a random 8 digit number""" - # res = int(str(randint(0, 99999999)).zfill(8)) - # if res in [x["_id"] for x in self.db[self.collection]]: - # return self._random_id() - # else: - # return res + def _random_id(self) -> int: + """Creates a random 8 digit number""" + res = int(str(randint(0, 99999999)).zfill(8)) + ids = [x.get("_id") for x in self.db.get(self.collection, [])] + if res in ids: + return self._random_id() + return res def _normalize_dict(self, dictionary: dict) -> dict: - """Changes the {one.two: } to {one: {two: }}""" + """Changes the {one.two: } to {one: {two: }} for $set/$inc shape""" for _, d in dictionary.items(): if isinstance(d, dict): - for key, val in d.items(): + for key, val in list(d.items()): if "." in key: - k1 = key.split(".")[0] - k2 = key.split(".")[1] + k1, k2 = key.split(".", 1) + d.setdefault(k1, {}) d[k1][k2] = val del d[key] return dictionary + def _matches(self, doc: dict, where: dict) -> bool: + if not where: + return True + for k, v in where.items(): + if k not in doc: + return False + if isinstance(v, dict) and "$in" in v: + if doc[k] not in v["$in"]: + return False + else: + if doc[k] != v: + return False + return True + async def find_one(self, where: dict) -> Optional[dict]: coll = self.db[self.collection] for d in coll: - for key, value in d.items(): - if len([k for k, v in where.items() if k == key and v == value]) == len( - where - ): # When all conditions defined in "where" are met - return d + if self._matches(d, where): + return d async def find(self, where: dict) -> Optional[list]: coll = self.db[self.collection] - results = [] - - for d in coll: - for key, value in d.items(): - if [ - x - for x in list(where.values()) - if isinstance(x, dict) and "$in" in x.keys() - ]: - for k, v in [ - (k, v) - for k, v in list(where.items()) - if isinstance(v, dict) and "$in" in v.keys() - ]: - if k == key and value in v["$in"]: - results.append(d) - - elif len( - [k for k, v in where.items() if k == key and v == value] - ) == len( - where - ): # When all conditions defined in "where" are met - results.append(d) - - return results + return [d for d in coll if self._matches(d, where)] async def insert_one(self, object: dict) -> None: - self.db[self.collection].append(object) + obj = dict(object) # copy + if "_id" not in obj: + obj["_id"] = self._random_id() + self.db[self.collection].append(obj) async def insert_many(self, objects: List[dict]) -> None: for obj in objects: await self.insert_one(obj) - async def update_one(self, where: dict, update: Dict[str, dict]) -> dict: - # updated = False - operator = list(update.keys())[0] # This does not support multiple keys - - for v in update.values(): # Making sure it is all in the right format - v = self._normalize_dict(v) # lgtm [py/multiple-definition] + def _apply_update(self, item: dict, update: Dict[str, dict]) -> None: + operator = list(update.keys())[0] + + def _set_by_path(target: dict, dotted_key: str, value): + parts = dotted_key.split(".") + cur = target + for i, part in enumerate(parts): + if i == len(parts) - 1: + cur[part] = value + else: + cur.setdefault(part, {}) + cur = cur[part] + + if operator == "$set": + # Do not try to pick a single subkey; set the whole value, + # and support dotted paths. + for k, val in update[operator].items(): + _set_by_path(item, k, val) + + elif operator == "$push": + for k, val in update[operator].items(): + parts = k.split(".") + cur = item + for i, part in enumerate(parts): + if i == len(parts) - 1: + cur.setdefault(part, []) + cur[part].append(val) + else: + cur.setdefault(part, {}) + cur = cur[part] + + elif operator == "$pull": + for k, val in update[operator].items(): + parts = k.split(".") + cur = item + for i, part in enumerate(parts): + if i == len(parts) - 1: + if part in cur and isinstance(cur[part], list): + try: + cur[part].remove(val) + except ValueError: + pass + else: + cur = cur.get(part, {}) + if not isinstance(cur, dict): + break + + elif operator == "$inc": + for k, val in update[operator].items(): + parts = k.split(".") + cur = item + for i, part in enumerate(parts): + if i == len(parts) - 1: + cur[part] = (cur.get(part, 0) or 0) + val + else: + cur.setdefault(part, {}) + cur = cur[part] + + elif operator == "$unset": + for dotted_key, _ in update[operator].items(): + parts = dotted_key.split(".") + target = item + for i, part in enumerate(parts): + if i == len(parts) - 1: + if isinstance(target, dict) and part in target: + del target[part] + else: + target = target.get(part, {}) + if not isinstance(target, dict): + break + async def update_one(self, where: dict, update: Dict[str, dict]) -> dict: for p, item in enumerate(self.db[self.collection]): - for key, value in item.items(): - if len([k for k, v in where.items() if key == k and value == v]) == len( - where - ): - if operator == "$set": - for k, val in update[operator].items(): - if isinstance(val, dict): - self.db[self.collection][p][k][list(val.keys())[0]] = ( - list(val.values())[0] - ) - else: - self.db[self.collection][p][k] = val - if operator == "$push": - for k, val in update[operator].items(): - if isinstance(val, dict): - self.db[self.collection][p][k][ - list(val.keys())[0] - ].append(list(val.values())[0]) - else: - self.db[self.collection][p][k].append(val) - if operator == "$pull": - for k, val in update[operator].items(): - if isinstance(val, dict): - self.db[self.collection][p][k][ - list(val.keys())[0] - ].remove(list(val.values())[0]) - else: - self.db[self.collection][p][k].remove(val) - elif operator == "$inc": - for k, val in update[operator].items(): - if isinstance(val, dict): - self.db[self.collection][p][k][ - list(val.keys())[0] - ] += list(val.values())[0] - else: - self.db[self.collection][p][k] += val - # updated = True - - # if not updated: - # self.insert_one(update) - - return update # I only need this when the update would equal the object + if self._matches(item, where): + self._apply_update(self.db[self.collection][p], update) + return update + return update async def count_documents(self, where: dict = {}) -> int: return len(await self.find(where) or []) async def delete_one(self, where: dict) -> None: - ... # TODO: Implement this + coll = self.db[self.collection] + for i, d in enumerate(coll): + if self._matches(d, where): + del coll[i] + return async def delete_many(self, where: dict) -> None: - ... # TODO: Implement this + coll = self.db[self.collection] + self.db[self.collection] = [d for d in coll if not self._matches(d, where)] - async def update_many(self, where: dict, update: dict) -> None: - ... # TODO: Implement this + async def update_many(self, where: dict, update: dict) -> dict: + modified_count = 0 + for p, item in enumerate(self.db[self.collection]): + if self._matches(item, where): + self._apply_update(self.db[self.collection][p], update) + modified_count += 1 + + class UpdateManyResult: + def __init__(self, modified_count: int): + self.modified_count = modified_count + + return UpdateManyResult(modified_count) \ No newline at end of file