# Copyright (C) 2024 by Animath
# SPDX-License-Identifier: GPL-3.0-or-later

import logging

from channels.generic.websocket import AsyncJsonWebsocketConsumer
from django.contrib.auth.models import User
from participation.models import Team, Pool, Tournament
from registration.models import Registration

from .models import Channel, Message


class ChatConsumer(AsyncJsonWebsocketConsumer):
    """
    This consumer manages the websocket of the chat interface.
    """
    async def connect(self) -> None:
        """
        This function is called when a new websocket is trying to connect to the server.
        We accept only if this is a user of a team of the associated tournament, or a volunteer
        of the tournament.
        """
        if '_fake_user_id' in self.scope['session']:
            self.scope['user'] = await User.objects.aget(pk=self.scope['session']['_fake_user_id'])

        # Fetch the registration of the current user
        user = self.scope['user']
        if user.is_anonymous:
            # User is not authenticated
            await self.close()
            return

        reg = await Registration.objects.aget(user_id=user.id)
        self.registration = reg

        # Accept the connection
        await self.accept()

        channels = await Channel.get_accessible_channels(user, 'read')
        async for channel in channels.all():
            await self.channel_layer.group_add(f"chat-{channel.id}", self.channel_name)

    async def disconnect(self, close_code) -> None:
        """
        Called when the websocket got disconnected, for any reason.
        :param close_code: The error code.
        """
        if self.scope['user'].is_anonymous:
            # User is not authenticated
            return

        channels = await Channel.get_accessible_channels(self.scope['user'], 'read')
        async for channel in channels.all():
            await self.channel_layer.group_discard(f"chat-{channel.id}", self.channel_name)

    async def receive_json(self, content, **kwargs):
        """
        Called when the client sends us some data, parsed as JSON.
        :param content: The sent data, decoded from JSON text. Must content a `type` field.
        """
        match content['type']:
            case 'fetch_channels':
                await self.fetch_channels()
            case 'send_message':
                await self.receive_message(content)
            case 'fetch_messages':
                await self.fetch_messages(**content)
            case unknown:
                print("Unknown message type:", unknown)

    async def fetch_channels(self) -> None:
        user = self.scope['user']

        read_channels = await Channel.get_accessible_channels(user, 'read')
        write_channels = await Channel.get_accessible_channels(user, 'write')
        message = {
            'type': 'fetch_channels',
            'channels': [
                {
                    'id': channel.id,
                    'name': channel.name,
                    'read_access': True,
                    'write_access': await write_channels.acontains(channel),
                }
                async for channel in read_channels.all()
            ]
        }
        await self.send_json(message)

    async def receive_message(self, message: dict) -> None:
        user = self.scope['user']
        channel = await Channel.objects.prefetch_related('tournament__pools__juries', 'pool', 'team', 'invited') \
            .aget(id=message['channel_id'])
        write_channels = await Channel.get_accessible_channels(user, 'write')
        if not await write_channels.acontains(channel):
            return

        message = await Message.objects.acreate(
            author=user,
            channel=channel,
            content=message['content'],
        )

        await self.channel_layer.group_send(f'chat-{channel.id}', {
            'type': 'chat.send_message',
            'id': message.id,
            'channel_id': channel.id,
            'timestamp': message.created_at.isoformat(),
            'author': await message.aget_author_name(),
            'content': message.content,
        })

    async def fetch_messages(self, channel_id: int, offset: int = 0, limit: int = 50, **_kwargs) -> None:
        channel = await Channel.objects.aget(id=channel_id)
        read_channels = await Channel.get_accessible_channels(self.scope['user'], 'read')
        if not await read_channels.acontains(channel):
            return

        limit = min(limit, 200)  # Fetch only maximum 200 messages at the time

        messages = Message.objects.filter(channel=channel).order_by('-created_at')[offset:offset + limit].all()
        await self.send_json({
            'type': 'fetch_messages',
            'channel_id': channel_id,
            'messages': list(reversed([
                {
                    'id': message.id,
                    'timestamp': message.created_at.isoformat(),
                    'author': await message.aget_author_name(),
                    'content': message.content,
                }
                async for message in messages
            ]))
        })

    async def chat_send_message(self, message) -> None:
        await self.send_json({'type': 'send_message', 'id': message['id'], 'channel_id': message['channel_id'],
                              'timestamp': message['timestamp'], 'author': message['author'],
                              'content': message['content']})