diff --git a/killua/cogs/events.py b/killua/cogs/events.py index a7f23f153..1e3b93f72 100644 --- a/killua/cogs/events.py +++ b/killua/cogs/events.py @@ -1140,5 +1140,31 @@ async def on_command_error(self, ctx: commands.Context, error): pass # This theoretically should be covered by all the cases above, # but handling it again here can't hurt + @commands.Cog.listener() + async def on_message(self,message: discord.Message): + # ignore bot messages + if message.author.bot: + return + + # do not track DMs + if not message.guild: + return + + # ignore system messages + if message.type != discord.MessageType.default: + return + + try: + guild = await Guild.new(message.guild.id) + if not guild.message_tracking_enabled: + return # guild has opted out + + user = await User.new(message.author.id) + if not user.message_tracking_enabled: + return # user has opted out + + await guild.increment_message_count(message.author.id) + except Exception as e: + logging.error(f"Failed to increment message count for user {message.author.id} in guild {message.guild.id}: {e}") Cog = Events diff --git a/killua/cogs/message.py b/killua/cogs/message.py new file mode 100644 index 000000000..4f8b7b3a9 --- /dev/null +++ b/killua/cogs/message.py @@ -0,0 +1,250 @@ +import discord +from discord import app_commands +from discord.ext import commands +from typing import Optional + +from killua.bot import BaseBot +from killua.utils.classes.guild import Guild +from killua.utils.classes.user import User +from killua.utils.views import ConfirmView + +TRACKING_SINCE = "2025-12-21" # The date message tracking was added MUST BE UPDATED IF DEPLOYED + +class Message(commands.GroupCog, group_name="message"): + """Cog to handle stats commands""" + + def __init__(self, client: BaseBot): + self.client = client + + @app_commands.command(name="stats", description="Get message stats for a user in this guild") + @app_commands.describe( + user="View stats for a specific user", + ) + @app_commands.guild_only() + @app_commands.checks.cooldown(1, 5.0) + async def stats( + self, + interaction: discord.Interaction, + user: discord.Member = None, + ): + """View message stats for a user or top users in the guild.""" + await interaction.response.defer() + + guild = await Guild.new(interaction.guild.id) + if not guild.message_tracking_enabled: + await interaction.followup.send("❌ Message tracking is disabled for this server.", ephemeral=True) + return + + if not user: + user = interaction.user + + member = await User.new(user.id) + if not member.message_tracking_enabled: + await interaction.followup.send(f"❌ {user.mention} has disabled message tracking for their account.", ephemeral=True) + return + + await self._show_user_stats(interaction, guild, user) + + @app_commands.command(name="leaderboard", description="Show the message leaderboard for this guild") + @app_commands.describe( + limit="Number of top users to display (max 25)", + ) + @app_commands.guild_only() + @app_commands.checks.cooldown(1, 10.0) + async def leaderboard( + self, + interaction: discord.Interaction, + limit: int = 10, + ): + """Display the message leaderboard for this guild.""" + await interaction.response.defer() + + guild = await Guild.new(interaction.guild.id) + if not guild.message_tracking_enabled: + await interaction.followup.send("❌ Message tracking is disabled for this server.", ephemeral=True) + return + await self._show_leaderboard(interaction, guild, limit) + + @app_commands.command(name="server_tracking", description="Toggle message tracking for this server") + @app_commands.guild_only() + @app_commands.default_permissions(manage_guild=True) + async def server_tracking( + self, + interaction: discord.Interaction, + ): + """Toggle message tracking for this server""" + await interaction.response.defer(ephemeral=True) + + guild = await Guild.new(interaction.guild.id) + + if guild.message_tracking_enabled: + embed = discord.Embed( + title="⚠️ Warning", + description="Disabling message tracking will remove all message counts from this server's stats and leaderboards. Are you sure you want to proceed?", + color=discord.Color.orange() + ) + view = ConfirmView(interaction.user.id) + await interaction.followup.send(embed=embed, view=view, ephemeral=True) + await view.wait() + + if not view.value: + return # cancelled + + new_status = await guild.toggle_message_tracking() + if new_status: + embed = discord.Embed( + title="✅ Message Tracking Enabled", + description="Message tracking has been enabled for this server. Future messages from users who have enabled message tracking will be counted in stats and leaderboards.", + color=discord.Color.green() + ) + else: + embed = discord.Embed( + title="❌ Message Tracking Disabled", + description="Message tracking has been disabled for this server. All message counts have been removed from stats and leaderboards.", + color=discord.Color.red() + ) + + await interaction.followup.send(embed=embed, ephemeral=True) + + @app_commands.command(name="user_tracking", description="Toggle message tracking for your account") + @app_commands.checks.cooldown(1, 10.0) + async def user_tracking( + self, + interaction: discord.Interaction, + ): + """Toggle message tracking for your account""" + user = await User.new(interaction.user.id) + + if user.message_tracking_enabled: + embed = discord.Embed( + title="⚠️ Warning", + description="Disabling message tracking will remove your message counts from all guild leaderboards and stats. Are you sure you want to proceed?", + color=discord.Color.orange() + ) + view = ConfirmView(interaction.user.id) + await interaction.response.send_message(embed=embed, view=view, ephemeral=True) + await view.wait() + + if not view.value: + return # cancelled + + new_status = await user.toggle_message_tracking() + + if new_status: + embed = discord.Embed( + title="✅ Message Tracking Enabled", + description="You have enabled message tracking for your account. Your future messages will be counted in guild stats and leaderboards.", + color=discord.Color.green() + ) + else: + embed = discord.Embed( + title="❌ Message Tracking Disabled", + description="You have disabled message tracking for your account. Your message counts have been removed from all guild stats and leaderboards.", + color=discord.Color.red() + ) + + await interaction.followup.send(embed=embed, ephemeral=True) + + async def _show_user_stats( + self, + interaction: discord.Interaction, + guild: Guild, + member: discord.Member + ): + """Display stats for a specific user""" + message_count = guild.get_message_count(member.id) + rank = await guild.get_user_rank(member.id) + total_messages = await guild.get_total_messages() + + embed = discord.Embed( + title="📊 Message Stats", + description=f"Stats for {member.mention} in **{interaction.guild.name}**", + color=discord.Color.blue() + ) + + if message_count == 0: + embed.add_field( + name="No Messages", + value=f"{member.mention} has not sent any messages in this guild.", + inline=False + ) + else: + percentage = (message_count / total_messages) * 100 if total_messages > 0 else 0 + + embed.add_field(name="Messages Sent", value=f"{message_count:,}", inline=True) + embed.add_field(name="Rank", value=f"#{rank}", inline=True) + embed.add_field(name="Percentage of Total Messages", value=f"{percentage:.2f}%", inline=True) + + embed.set_footer(text=f"Tracking since {TRACKING_SINCE} • Requested by {interaction.user.display_name}", icon_url=interaction.user.display_avatar.url) + await interaction.followup.send(embed=embed) + + async def _show_leaderboard( + self, + interaction: discord.Interaction, + guild: Guild, + limit: int + ): + """Display the message leaderboard for the guild""" + + if limit < 1 or limit > 25: + await interaction.followup.send("❌ Limit must be between 1 and 25.", ephemeral=True) + return + + top_senders = await guild.get_top_senders(limit) + total_messages = await guild.get_total_messages() + + embed = discord.Embed( + title="📊 Message Leaderboard", + description=f"Top {limit} message senders in **{interaction.guild.name}**", + color=discord.Color.blue() + ) + + if not top_senders: + embed.add_field( + name="No Data", + value="No message data available for this guild.", + inline=False + ) + else: + leaderboard = "" + for rank, (user_id, message_count) in enumerate(top_senders, start=1): + member = interaction.guild.get_member(user_id) + if member is None: + try: + member = await interaction.guild.fetch_member(user_id) + except (discord.NotFound, discord.HTTPException): + pass + member_name = member.display_name if member else f"User ID {user_id}" + percentage = (message_count / total_messages) * 100 if total_messages > 0 else 0 + + medal = f"#{rank}" + if rank == 1: + medal = "🥇 " + elif rank == 2: + medal = "🥈 " + elif rank == 3: + medal = "🥉 " + + leaderboard += f"**{medal}**- {member_name}: {message_count:,} messages ({percentage:.2f}%)\n" + + embed.add_field(name="Leaderboard", value=leaderboard, inline=False) + + embed.set_footer(text="Tracking since {TRACKING_SINCE} • Requested by " + interaction.user.display_name, icon_url=interaction.user.display_avatar.url) + await interaction.followup.send(embed=embed) + + @stats.error + async def stats_error( + self, + interaction: discord.Interaction, + error: app_commands.AppCommandError + ): + """Error handler for the stats command""" + if isinstance(error, app_commands.CommandOnCooldown): + await interaction.response.send_message( + f"⏳ This command is on cooldown. Please try again in {error.retry_after:.1f} seconds.", + ephemeral=True + ) + else: + raise error + +Cog = Message \ No newline at end of file diff --git a/killua/migrate.py b/killua/migrate.py index 981b7f13b..5fcd11f50 100644 --- a/killua/migrate.py +++ b/killua/migrate.py @@ -56,6 +56,25 @@ async def migrate(): logging.info("Migrated user achievements key to achievements successfully") + # Add message_stats field to all guilds + result = await DB.guilds.update_many( + {"message_stats": {"$exists": False}}, + {"$set": {"message_stats": {}}} + ) + logging.info(f"Added message_stats field to {result.modified_count} guilds") + + result = await DB.guilds.update_many( + {"message_tracking_enabled": {"$exists": False}}, + {"$set": {"message_tracking_enabled": False}} + ) + logging.info(f"Added message_tracking_enabled field to {result.modified_count} guilds") + + result = await DB.teams.update_many( + {"message_tracking_enabled": {"$exists": False}}, + {"$set": {"message_tracking_enabled": False}} + ) + logging.info(f"Added message_tracking_enabled field to {result.modified_count} users") + await DB.const.update_one( {"_id": "migrate"}, {"$set": {"value": True}}, diff --git a/killua/tests/groups/__init__.py b/killua/tests/groups/__init__.py index 800932907..bede26b06 100644 --- a/killua/tests/groups/__init__.py +++ b/killua/tests/groups/__init__.py @@ -1,7 +1,8 @@ from .actions import TestingActions from .cards import TestingCards from .dev import TestingDev +from .message import TestingMessage -tests = [TestingActions, TestingCards, TestingDev] +tests = [TestingActions, TestingCards, TestingDev, TestingMessage] __all__ = ["tests"] diff --git a/killua/tests/groups/message.py b/killua/tests/groups/message.py new file mode 100644 index 000000000..87be1b31b --- /dev/null +++ b/killua/tests/groups/message.py @@ -0,0 +1,196 @@ +from unittest.mock import AsyncMock, MagicMock, patch +from ..types import * +from ...utils.classes import * +from ..testing import Testing, test +from ...cogs.message import Message +from ...utils.test_db import TestingDatabase + +class TestingMessage(Testing): + def __init__(self): + super().__init__(cog=Message) + +class Stats(TestingMessage): + def __init__(self): + super().__init__() + + @test + async def user_with_no_messages(self) -> None: + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() # enable tracking + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + await user.toggle_message_tracking() # enable tracking + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.stats.callback(self.cog, interaction, user=member) + + assert interaction.followup.sent[0]["embeds"][0] is not None, "Expected an embed to be sent" + field_names = [field.name for field in interaction.followup.sent[0]["embeds"][0].fields] + assert "No Messages" in field_names or any("no" in name.lower() for name in field_names), "Expected 'No Messages' field in embed" + + @test + async def user_with_messages(self) -> None: + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() # enable tracking + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + await user.toggle_message_tracking() # enable tracking + + guild.message_stats[member.id] = 50 + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.stats.callback(self.cog, interaction, user=member) + + sent_embed = interaction.followup.sent[0]["embeds"][0] + assert sent_embed is not None, "Expected an embed to be sent" + field_names = [field.name for field in sent_embed.fields] + assert "Messages Sent" in field_names, "Expected 'Messages Sent' field in embed" + assert "Rank" in field_names, "Expected 'Rank' field in embed" + assert "Percentage of Total Messages" in field_names, "Expected 'Percentage of Total Messages' field in embed" + + @test + async def user_opted_out(self) -> None: + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() # enable tracking + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.stats.callback(self.cog, interaction, user=member) + + + message = interaction.followup.sent[0]["content"] + assert "disabled" in message.lower(), "Expected tracking disabled message" + +class Leaderboard(TestingMessage): + def __init__(self): + super().__init__() + + @test + async def leaderboard_renders(self) -> None: + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() # enable tracking + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + await user.toggle_message_tracking() # enable tracking + + guild.message_stats = {1: 100, 2: 80, 3: 60} + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.leaderboard.callback(self.cog, interaction, limit=3) + + sent_embed = interaction.followup.sent[0]["embeds"][0] + assert sent_embed is not None, "Expected an embed to be sent" + + @test + async def leaderboard_no_messages(self) -> None: + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() # enable tracking + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + await user.toggle_message_tracking() # enable tracking + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.leaderboard.callback(self.cog, interaction, limit=3) + + sent_embed = interaction.followup.sent[0]["embeds"][0] + assert sent_embed is not None, "Expected an embed to be sent" + assert "No message data available" in sent_embed.fields[0].value, "Expected 'No message data available' message in embed" + + @test + async def leaderboard_guild_opted_out(self) -> None: + guild = await Guild.new(self.base_guild.id) + + member = DiscordMember(guild=self.base_guild) + user = await User.new(member.id) + await user.toggle_message_tracking() # enable tracking + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + + await self.cog.leaderboard.callback(self.cog, interaction, limit=3) + + message = interaction.followup.sent[0]["content"] + assert "disabled" in message.lower(), "Expected tracking disabled message" +class User_Tracking(TestingMessage): + def __init__(self): + super().__init__() + + @test + async def toggle_tracking_enable(self) -> None: + user = await User.new(self.base_author.id) + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + interaction.user = self.base_author + + await self.cog.user_tracking.callback(self.cog, interaction) + + assert interaction.response.is_done() or len(interaction.followup.sent) > 0, "Expected a response" + + user_after = await User.new(self.base_author.id) + assert user_after.message_tracking_enabled == True, "Expected tracking to be enabled" + + @test + async def toggle_tracking_disable_with_confirmation(self) -> None: + user = await User.new(self.base_author.id) + await user.toggle_message_tracking() # enable tracking + + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() + guild.message_stats[self.base_author.id] = 50 + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + interaction.user = self.base_author + + # mock the confirm view acceptance + with patch("killua.cogs.message.ConfirmView") as MockView: + mock_view_instance = MagicMock() + mock_view_instance.value = True + mock_view_instance.wait = AsyncMock() + MockView.return_value = mock_view_instance + + await self.cog.user_tracking.callback(self.cog, interaction) + + # confirmation view response + assert interaction.response.is_done(), "Expected confirmation prompt" + + # after confirmation, followup is sent + assert len(interaction.followup.sent) > 0, "Expected followup message after confirmation" + + @test + async def toggle_tracking_disable_cancelled(self) -> None: + user = await User.new(self.base_author.id) + await user.toggle_message_tracking() # enable tracking + + guild = await Guild.new(self.base_guild.id) + await guild.toggle_message_tracking() + guild.message_stats[self.base_author.id] = 50 + + interaction = ArgumentInteraction(context=self.base_context, guild=self.base_guild) + interaction.user = self.base_author + + # mock the confirm view cancellation + with patch("killua.cogs.message.ConfirmView") as MockView: + mock_view_instance = MagicMock() + mock_view_instance.value = False # User cancels + mock_view_instance.wait = AsyncMock() + MockView.return_value = mock_view_instance + + await self.cog.user_tracking.callback(self.cog, interaction) + + # Should show confirmation view + assert interaction.response.is_done(), "Expected confirmation prompt" + + # tracking should still be enabled after cancellation + user_after = await User.new(self.base_author.id) + assert user_after.message_tracking_enabled == True, "Expected tracking to still be enabled after cancel" + assert guild.get_message_count(self.base_author.id) == 50, "Expected message count to remain unchanged after cancel" \ No newline at end of file diff --git a/killua/tests/testing.py b/killua/tests/testing.py index 09bd75d2e..3f058c018 100644 --- a/killua/tests/testing.py +++ b/killua/tests/testing.py @@ -50,12 +50,18 @@ def __init__(self, cog: Cog): def all_tests(self) -> List[Testing]: """Automatically checks what functions are test based on their name and the overlap with the Cog method names""" cog_methods = [] - for cmd in [(command.name, command) for command in self.cog.get_commands()]: - if hasattr(cmd[1], "walk_commands") and cmd[1].walk_commands(): - for child in cmd[1].walk_commands(): - cog_methods.append((child.name, child)) - else: - cog_methods.append(cmd) + + # handle app commands and normal commands + if hasattr(self.cog, '__cog_app_commands__'): + for cmd in self.cog.__cog_app_commands__: + cog_methods.append((cmd.name, cmd)) + else: + for cmd in [(command.name, command) for command in self.cog.get_commands()]: + if hasattr(cmd[1], "walk_commands") and cmd[1].walk_commands(): + for child in cmd[1].walk_commands(): + cog_methods.append((child.name, child)) + else: + cog_methods.append(cmd) command_classes: List[Testing] = [] @@ -119,6 +125,12 @@ def __init__(self, method): async def __call__(self, obj: Testing, *args, **kwargs): from .types import Result, ResultData + from ..utils.test_db import TestingDatabase + from ..utils.classes import Guild, User + + TestingDatabase.db.clear() + Guild.cache.clear() + User.cache.clear() try: logging.debug( diff --git a/killua/tests/types/interaction.py b/killua/tests/types/interaction.py index d50b7890b..37328b49a 100644 --- a/killua/tests/types/interaction.py +++ b/killua/tests/types/interaction.py @@ -39,6 +39,26 @@ async def send_modal(self, *args, **kwargs) -> None: def is_done(self) -> bool: return self._is_done +class ArgumentFollowup: + def __init__(self, interaction: ArgumentInteraction): + self.interaction = interaction + self.sent = [] # store sent followups + + async def send(self, content = None, /, *, embed=None, embeds=None, ephemeral: bool = False, **kwargs): + # Normalize embeds + if embed is not None: + embeds = [embed] + elif embeds is None: + embeds = [] + + self.sent.append( + { + "content": content, + "embeds": embeds, + "ephemeral": ephemeral, + "kwargs": kwargs, + } + ) class ArgumentInteraction: """This classes purpose is purely to be supplied to callbacks of message interactions""" @@ -48,6 +68,7 @@ def __init__(self, context: Context, **kwargs): self.context = context self.user = context.author self.response = ArgumentResponseInteraction(self) + self.followup = ArgumentFollowup(self) class TestingInteraction(Interaction): diff --git a/killua/tests/types/member.py b/killua/tests/types/member.py index 856f0161f..04c514bd9 100644 --- a/killua/tests/types/member.py +++ b/killua/tests/types/member.py @@ -26,6 +26,7 @@ def __init__(self, **kwargs): "communication_disabled_until", "" ) self.premium_since: Union[datetime, None] = kwargs.pop("premium_since", None) + self.display_avatar = self.avatar @property def display_name(self) -> str: diff --git a/killua/utils/classes/guild.py b/killua/utils/classes/guild.py index 599df6708..146470462 100644 --- a/killua/utils/classes/guild.py +++ b/killua/utils/classes/guild.py @@ -22,6 +22,9 @@ class Guild: polls: dict = field(default_factory=dict) tags: List[dict] = field(default_factory=list) added_on: Optional[datetime] = None + message_stats: Dict[int, int] = field(default_factory=dict) # user_id: message_count + message_tracking_enabled: bool = False + tracking_since: Optional[datetime] = None cache: ClassVar[Dict[int, Guild]] = {} @classmethod @@ -67,6 +70,12 @@ async def new(cls, guild_id: int, member_count: Optional[int] = None) -> Guild: raw["approximate_member_count"] = await cls._member_count_helper( guild_id, raw.get("approximate_member_count", None), member_count ) + # Convert message_stats string keys to int keys for in-memory use + message_stats_raw = raw.get("message_stats", {}) + raw["message_stats"] = {int(k): v for k, v in message_stats_raw.items()} if message_stats_raw else {} + raw["message_tracking_enabled"] = raw.get("message_tracking_enabled", False) + raw["tracking_since"] = raw.get("tracking_since", None) + guild = cls.from_dict(raw) cls.cache[guild_id] = guild @@ -80,7 +89,7 @@ def is_premium(self) -> bool: async def add_default(cls, guild_id: int, member_count: Optional[int]) -> None: """Adds a guild to the database""" await DB.guilds.insert_one( - {"id": guild_id, "points": 0, "items": "", "badges": [], "prefix": "k!", "approximate_member_count": member_count or 0, "added_on": datetime.now()} + {"id": guild_id, "points": 0, "items": "", "badges": [], "prefix": "k!", "approximate_member_count": member_count or 0, "added_on": datetime.now(), "message_stats": {}, "message_tracking_enabled": False} ) @classmethod @@ -136,3 +145,50 @@ async def update_poll_votes(self, id: int, updated: dict) -> None: """Updates the votes of a poll""" self.polls[str(id)]["votes"] = updated await self._update_val(f"polls.{id}.votes", updated) + + async def increment_message_count(self, user_id: int, amount: int = 1) -> None: + """Increments the message count for a user in this guild""" + user_id_str = str(user_id) + current_count = self.message_stats.get(user_id, 0) + self.message_stats[user_id] = current_count + amount + await self._update_val(f"message_stats.{user_id_str}", amount, "$inc") + + def get_message_count(self, user_id: int) -> int: + """Gets the message count for a user in this guild""" + return self.message_stats.get(user_id, 0) + + async def get_top_senders(self, limit: int = 10) -> List[tuple[int, int]]: + """Fetch the top message senders in the guild. Returns a list of (user_id, message_count) tuples""" + tracked_stats = {k: v for k, v in self.message_stats.items() if (await User.new(k)).message_tracking_enabled} + sorted_stats = sorted(tracked_stats.items(), key=lambda x: x[1], reverse=True) + return sorted_stats[:limit] + + async def get_total_messages(self) -> int: + """Gets the total number of messages sent in this guild""" + return sum(self.message_stats.values()) + + async def get_user_rank(self, user_id: int) -> Optional[int]: + """Gets the rank of a user in this guild based on message count""" + user_count = self.message_stats.get(user_id, 0) + + if user_count == 0: + return None + + rank = sum(1 for count in self.message_stats.values() if count > user_count) + return rank + 1 + + async def toggle_message_tracking(self) -> bool: + """Toggles message tracking for the guild""" + self.message_tracking_enabled = not self.message_tracking_enabled + await self._update_val("message_tracking_enabled", self.message_tracking_enabled) + if self.message_tracking_enabled: + # Remove all users who have disabled tracking from message_stats + tracked_stats = {k: v for k, v in self.message_stats.items() if (await User.new(k)).message_tracking_enabled} + self.message_stats = tracked_stats + await self._update_val("message_stats", {str(k): v for k, v in tracked_stats.items()}) + self.tracking_since = datetime.now() + await self._update_val("tracking_since", self.tracking_since) + else: + self.message_stats = {} + await self._update_val("message_stats", {}) + return self.message_tracking_enabled \ No newline at end of file diff --git a/killua/utils/classes/user.py b/killua/utils/classes/user.py index 4decbb1f9..5f6a15e12 100644 --- a/killua/utils/classes/user.py +++ b/killua/utils/classes/user.py @@ -2,7 +2,9 @@ from datetime import datetime, timedelta from typing import Any, ClassVar, Dict, List, Optional, Union, cast, Literal, Tuple -from dataclasses import dataclass +from dataclasses import dataclass\ + +import logging from killua.static.constants import ( DB, @@ -42,6 +44,7 @@ class User: has_user_installed: bool email: Optional[str] email_notifications: Dict[Literal["news", "updates", "posts"], bool] + message_tracking_enabled: bool = False cache: ClassVar[Dict[int, User]] = {} async def set_email(self, email: str) -> None: @@ -108,6 +111,7 @@ async def new(cls, user_id: int): "email_notifications", {"news": False, "updates": False, "posts": False}, ), + message_tracking_enabled=data.get("message_tracking_enabled", False), ) cls.cache[user_id] = instance @@ -253,6 +257,7 @@ async def add_empty(cls, user_id: int, cards: bool = True) -> None: "updates": False, "posts": False, }, + "message_tracking_enabled": False, } ) @@ -842,3 +847,28 @@ async def register_login(self) -> bool: self.achievements.append("logged_into_website") await self._update_val("achievements", self.achievements) return True + + async def toggle_message_tracking(self) -> bool: + """Toggles message tracking for the user""" + self.message_tracking_enabled = not self.message_tracking_enabled + await self._update_val("message_tracking_enabled", self.message_tracking_enabled) + + if not self.message_tracking_enabled: + await self._delete_message_tracking_data() + + return self.message_tracking_enabled + + async def _delete_message_tracking_data(self) -> None: + """Deletes all message tracking data for the user""" + from killua.utils.classes.guild import Guild + + result = await DB.guilds.update_many( + {f"message_stats.{self.id}": {"$exists": True}}, + {"$unset": {f"message_stats.{self.id}": ""}}, + ) + + for guild_id, guild in Guild.cache.items(): + if self.id in guild.message_stats: + del guild.message_stats[self.id] + + logging.info(f"Deleted message tracking data for user {self.id} in {result.modified_count} guild(s)") \ No newline at end of file diff --git a/killua/utils/test_db.py b/killua/utils/test_db.py index 952a6ecb1..e354b474a 100644 --- a/killua/utils/test_db.py +++ b/killua/utils/test_db.py @@ -1,5 +1,5 @@ from typing import Optional, List, Dict - +from random import randint class TestingDatabase: """A database class imitating pymongos collection classes""" @@ -15,129 +15,161 @@ def collection(self) -> str: self.db[self._collection] = [] return self._collection - # def _random_id(self) -> int: - # """Creates a random 8 digit number""" - # res = int(str(randint(0, 99999999)).zfill(8)) - # if res in [x["_id"] for x in self.db[self.collection]]: - # return self._random_id() - # else: - # return res + def _random_id(self) -> int: + """Creates a random 8 digit number""" + res = int(str(randint(0, 99999999)).zfill(8)) + ids = [x.get("_id") for x in self.db.get(self.collection, [])] + if res in ids: + return self._random_id() + return res def _normalize_dict(self, dictionary: dict) -> dict: - """Changes the {one.two: } to {one: {two: }}""" + """Changes the {one.two: } to {one: {two: }} for $set/$inc shape""" for _, d in dictionary.items(): if isinstance(d, dict): - for key, val in d.items(): + for key, val in list(d.items()): if "." in key: - k1 = key.split(".")[0] - k2 = key.split(".")[1] + k1, k2 = key.split(".", 1) + d.setdefault(k1, {}) d[k1][k2] = val del d[key] return dictionary + def _matches(self, doc: dict, where: dict) -> bool: + if not where: + return True + for k, v in where.items(): + if k not in doc: + return False + if isinstance(v, dict) and "$in" in v: + if doc[k] not in v["$in"]: + return False + else: + if doc[k] != v: + return False + return True + async def find_one(self, where: dict) -> Optional[dict]: coll = self.db[self.collection] for d in coll: - for key, value in d.items(): - if len([k for k, v in where.items() if k == key and v == value]) == len( - where - ): # When all conditions defined in "where" are met - return d + if self._matches(d, where): + return d async def find(self, where: dict) -> Optional[list]: coll = self.db[self.collection] - results = [] - - for d in coll: - for key, value in d.items(): - if [ - x - for x in list(where.values()) - if isinstance(x, dict) and "$in" in x.keys() - ]: - for k, v in [ - (k, v) - for k, v in list(where.items()) - if isinstance(v, dict) and "$in" in v.keys() - ]: - if k == key and value in v["$in"]: - results.append(d) - - elif len( - [k for k, v in where.items() if k == key and v == value] - ) == len( - where - ): # When all conditions defined in "where" are met - results.append(d) - - return results + return [d for d in coll if self._matches(d, where)] async def insert_one(self, object: dict) -> None: - self.db[self.collection].append(object) + obj = dict(object) # copy + if "_id" not in obj: + obj["_id"] = self._random_id() + self.db[self.collection].append(obj) async def insert_many(self, objects: List[dict]) -> None: for obj in objects: await self.insert_one(obj) - async def update_one(self, where: dict, update: Dict[str, dict]) -> dict: - # updated = False - operator = list(update.keys())[0] # This does not support multiple keys - - for v in update.values(): # Making sure it is all in the right format - v = self._normalize_dict(v) # lgtm [py/multiple-definition] + def _apply_update(self, item: dict, update: Dict[str, dict]) -> None: + operator = list(update.keys())[0] + + def _set_by_path(target: dict, dotted_key: str, value): + parts = dotted_key.split(".") + cur = target + for i, part in enumerate(parts): + if i == len(parts) - 1: + cur[part] = value + else: + cur.setdefault(part, {}) + cur = cur[part] + + if operator == "$set": + # Do not try to pick a single subkey; set the whole value, + # and support dotted paths. + for k, val in update[operator].items(): + _set_by_path(item, k, val) + + elif operator == "$push": + for k, val in update[operator].items(): + parts = k.split(".") + cur = item + for i, part in enumerate(parts): + if i == len(parts) - 1: + cur.setdefault(part, []) + cur[part].append(val) + else: + cur.setdefault(part, {}) + cur = cur[part] + + elif operator == "$pull": + for k, val in update[operator].items(): + parts = k.split(".") + cur = item + for i, part in enumerate(parts): + if i == len(parts) - 1: + if part in cur and isinstance(cur[part], list): + try: + cur[part].remove(val) + except ValueError: + pass + else: + cur = cur.get(part, {}) + if not isinstance(cur, dict): + break + + elif operator == "$inc": + for k, val in update[operator].items(): + parts = k.split(".") + cur = item + for i, part in enumerate(parts): + if i == len(parts) - 1: + cur[part] = (cur.get(part, 0) or 0) + val + else: + cur.setdefault(part, {}) + cur = cur[part] + + elif operator == "$unset": + for dotted_key, _ in update[operator].items(): + parts = dotted_key.split(".") + target = item + for i, part in enumerate(parts): + if i == len(parts) - 1: + if isinstance(target, dict) and part in target: + del target[part] + else: + target = target.get(part, {}) + if not isinstance(target, dict): + break + async def update_one(self, where: dict, update: Dict[str, dict]) -> dict: for p, item in enumerate(self.db[self.collection]): - for key, value in item.items(): - if len([k for k, v in where.items() if key == k and value == v]) == len( - where - ): - if operator == "$set": - for k, val in update[operator].items(): - if isinstance(val, dict): - self.db[self.collection][p][k][list(val.keys())[0]] = ( - list(val.values())[0] - ) - else: - self.db[self.collection][p][k] = val - if operator == "$push": - for k, val in update[operator].items(): - if isinstance(val, dict): - self.db[self.collection][p][k][ - list(val.keys())[0] - ].append(list(val.values())[0]) - else: - self.db[self.collection][p][k].append(val) - if operator == "$pull": - for k, val in update[operator].items(): - if isinstance(val, dict): - self.db[self.collection][p][k][ - list(val.keys())[0] - ].remove(list(val.values())[0]) - else: - self.db[self.collection][p][k].remove(val) - elif operator == "$inc": - for k, val in update[operator].items(): - if isinstance(val, dict): - self.db[self.collection][p][k][ - list(val.keys())[0] - ] += list(val.values())[0] - else: - self.db[self.collection][p][k] += val - # updated = True - - # if not updated: - # self.insert_one(update) - - return update # I only need this when the update would equal the object + if self._matches(item, where): + self._apply_update(self.db[self.collection][p], update) + return update + return update async def count_documents(self, where: dict = {}) -> int: return len(await self.find(where) or []) async def delete_one(self, where: dict) -> None: - ... # TODO: Implement this + coll = self.db[self.collection] + for i, d in enumerate(coll): + if self._matches(d, where): + del coll[i] + return async def delete_many(self, where: dict) -> None: - ... # TODO: Implement this + coll = self.db[self.collection] + self.db[self.collection] = [d for d in coll if not self._matches(d, where)] - async def update_many(self, where: dict, update: dict) -> None: - ... # TODO: Implement this + async def update_many(self, where: dict, update: dict) -> dict: + modified_count = 0 + for p, item in enumerate(self.db[self.collection]): + if self._matches(item, where): + self._apply_update(self.db[self.collection][p], update) + modified_count += 1 + + class UpdateManyResult: + def __init__(self, modified_count: int): + self.modified_count = modified_count + + return UpdateManyResult(modified_count) \ No newline at end of file diff --git a/killua/utils/views.py b/killua/utils/views.py new file mode 100644 index 000000000..9cbb8319e --- /dev/null +++ b/killua/utils/views.py @@ -0,0 +1,31 @@ +import discord +from discord import app_commands +from discord.ext import commands +from discord.ui import View, Button +from typing import Optional + +class ConfirmView(View): + """Simple confirmation view with Yes/No buttons""" + + def __init__(self, user_id: int): + super().__init__(timeout=30) + self.user_id = user_id + self.value = None + + @discord.ui.button(label="Confirm", style=discord.ButtonStyle.danger) + async def confirm(self, interaction: discord.Interaction, button: Button): + if interaction.user.id != self.user_id: + await interaction.response.send_message("You cannot interact with this confirmation.", ephemeral=True) + return + self.value = True + self.stop() + await interaction.response.edit_message(content="✅ Confirmed.", view=None) + + @discord.ui.button(label="Cancel", style=discord.ButtonStyle.secondary) + async def cancel(self, interaction: discord.Interaction, button: Button): + if interaction.user.id != self.user_id: + await interaction.response.send_message("You cannot interact with this confirmation.", ephemeral=True) + return + self.value = False + self.stop() + await interaction.response.edit_message(content="❌ Cancelled.", view=None) \ No newline at end of file