# Copyright 2021-2022 The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Consumers for the server application."""
import asyncio
import json
import logging
from typing import Optional

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer

from debusine.db.models import Token, WorkRequest, Worker

logger = logging.getLogger(__name__)


class WorkerConsumer(AsyncWebsocketConsumer):
    """
    Implement server-side of a Worker.

    After the client worker connects to the server: WorkerConsumer
    requests information, send tasks, etc.
    """

    # Call self._send_dynamic_metadata_request() every
    # REQUEST_DYNAMIC_METADATA_SECONDS
    REQUEST_DYNAMIC_METADATA_SECONDS = 3600

    def __init__(self, *args, **kwargs):
        """Initialise WorkerConsumer member variables."""
        super().__init__(*args, **kwargs)
        # Set by _has_permission_connect() once the token is validated.
        self._worker: Optional[Worker] = None
        # Background task started by connect(); cancelled in disconnect().
        self._request_dynamic_metadata_task: Optional[asyncio.Task] = None

    async def _channel_layer_group_add(self, token_key, channel_name):
        """
        Add channel_name to the channel-layer group named token_key.

        Kept as a separate method, presumably so tests can patch it to
        simulate channel-layer (Redis) failures.
        """
        await self.channel_layer.group_add(token_key, channel_name)

    async def connect(self):
        """
        Worker client is connecting.

        The websocket is accepted first so that a rejection reason can be
        sent to the client before closing. Then:

        1. validate the "token" header (see _has_permission_connect);
        2. join the channel-layer group named after the token key, so that
           group events (worker_disabled, work_request_assigned) reach
           this consumer;
        3. mark the worker as connected and start the periodic dynamic
           metadata requests;
        4. notify the worker if work requests are running or pending.
        """
        await self.accept()

        if not await self._has_permission_connect():
            return

        try:
            await self._channel_layer_group_add(
                self._worker.token.key, self.channel_name
            )
        except Exception as exc:
            logger.error(  # noqa: G200
                'Error adding worker to group (Redis): %s', exc
            )
            await self._reject_connection('Service unavailable')
            return

        logger.info("Worker connected: %s", self._worker.name)

        await database_sync_to_async(self._worker.mark_connected)()

        self._request_dynamic_metadata_task = asyncio.create_task(
            self._dynamic_metadata_refresher(),
            name='dynamic_metadata_refresher',
        )

        # If there are WorkRequests running or pending: send to the worker
        # "work-request-available" message
        if await database_sync_to_async(
            WorkRequest.objects.running_or_pending_exists
        )(self._worker):
            await self._send_work_request_available()

    async def _has_permission_connect(self):
        """
        Validate the connecting client's token; reject the connection if bad.

        On success self._worker is set to the token's worker.

        :return: True if the worker may connect. False if the connection
          was rejected (a connection_rejected message has been sent and
          the socket closed).
        """
        token_key = self._get_header_value('token')
        if token_key is None:
            await self._reject_connection('Missing required header: "token"')
            return False

        token = await database_sync_to_async(Token.objects.get_token_or_none)(
            token_key
        )
        if token is None:
            await self._reject_connection(f'Token not found: "{token_key}"')
            return False

        if not token.enabled:
            await self._reject_connection(
                f'Token is not enabled (token: "{token_key}")',
                reason_code='TOKEN_DISABLED',
            )
            return False

        self._worker = getattr(token, 'worker', None)

        if self._worker is None:
            await self._reject_connection(
                f'Token without associated worker (token: "{token_key}")'
            )
            return False

        logger.debug(
            'Worker can connect (name: %s, token: %s)',
            self._worker.name,
            token_key,
        )
        return True

    async def _reject_connection(self, reason, reason_code=None):
        """
        Send a connection_rejected message to the client and close.

        :param reason: human-readable reason, sent to the client and logged.
        :param reason_code: optional machine-readable code; included in the
          message only when not None.
        """
        msg = {
            'type': 'connection_rejected',
            'reason': reason,
        }

        if reason_code is not None:
            msg['reason_code'] = reason_code

        await self.send(text_data=json.dumps(msg), close=True)

        logger.info('Worker rejected. %s', reason)

    async def _send_work_request_available(self):
        """Tell the worker that a work request is available."""
        await self.send(
            text_data=json.dumps(
                {
                    "type": "websocket.send",
                    "text": "work_request_available",
                }
            )
        )

    async def _dynamic_metadata_refresher(self):
        """
        Request dynamic metadata now and then periodically, forever.

        Runs as a background task until cancelled by disconnect().
        """
        while True:
            await self._send_dynamic_metadata_request()
            await asyncio.sleep(self.REQUEST_DYNAMIC_METADATA_SECONDS)

    async def _send_dynamic_metadata_request(self):
        """Ask the worker to send its dynamic metadata."""
        await self.send(
            text_data=json.dumps(
                {
                    "type": "websocket.send",
                    "text": "request_dynamic_metadata",
                }
            )
        )

    async def disconnect(self, close_code):  # noqa: U100
        """
        Worker has disconnected. Cancel tasks, mark as disconnect, etc.

        Leaving the channel-layer group is best-effort. If the channel layer
        (Redis) is unavailable, the error is logged and the worker is still
        marked as disconnected, so it is not left flagged as connected.
        """
        if self._request_dynamic_metadata_task is not None:
            self._request_dynamic_metadata_task.cancel()

        if self._worker is not None:
            try:
                await self.channel_layer.group_discard(
                    self._worker.token.key, self.channel_name
                )
            except Exception as exc:
                # Same failure mode connect() handles for group_add; it must
                # not prevent mark_disconnected() below from running.
                logger.error(  # noqa: G200
                    'Error removing worker from group (Redis): %s', exc
                )

            await database_sync_to_async(self._worker.mark_disconnected)()
            logger.info(
                "Worker disconnected: %s (code: %s)",
                self._worker.name,
                close_code,
            )

    def _get_header_value(self, header_field_name: str) -> Optional[str]:
        """
        Return the value of header_field_name from self.scope['headers'].

        :param header_field_name: case-insensitive, utf-8 encoded.
        :return: None if header_field_name is not found in
          self.scope['headers'] or the header's content.
        """
        encoded_header_field_name = header_field_name.lower().encode('utf-8')

        for name, value in self.scope['headers']:
            if name == encoded_header_field_name:
                return value.decode('utf-8')

        return None

    async def worker_disabled(self, event):  # noqa: U100
        """
        Worker has been disabled. Send a connection_closed msg.

        Channel-layer event handler: a "worker_disabled" event sent to this
        worker's token group is dispatched here. The socket is then closed.
        """
        logger.info("Worker %s disabled", self._worker.name)

        msg = {
            "type": "connection_closed",
            "reason": 'Token has been disabled '
            f'(token: "{self._worker.token.key}")',
            "reason_code": "TOKEN_DISABLED",
        }
        await self.send(text_data=json.dumps(msg), close=True)

    async def work_request_assigned(self, event):  # noqa: U100
        """
        Work Request has been assigned to the worker. Send channel msg.

        Channel-layer event handler for "work_request_assigned" events sent
        to this worker's token group.
        """
        await self._send_work_request_available()
