# Copyright 2021-2022 The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Tests for consumers."""
import asyncio
from unittest import mock
from unittest.mock import PropertyMock

from channels.db import database_sync_to_async
from channels.routing import URLRouter
from channels.testing import WebsocketCommunicator

from django.test import TestCase, TransactionTestCase
from django.utils import timezone

from debusine.db.models import Token, WorkRequest, Worker
from debusine.server import consumers
from debusine.server.consumers import WorkerConsumer
from debusine.server.routing import websocket_urlpatterns
from debusine.test import DatabaseHelpersMixin


class WebSocketURLTests(TestCase):
    """Check the websocket URL routing configuration."""

    def test_websocket_urlpatterns(self):
        """URLRouter accepts websocket_urlpatterns without raising."""
        patterns = websocket_urlpatterns
        URLRouter(patterns)


class WorkerConsumerMixin(DatabaseHelpersMixin):
    """
    Methods used by different WorkerConsumer tests classes.

    Provides sample payloads, helpers to create workers and to open a
    websocket connection to the WorkerConsumer, and assertions about the
    messages the consumer sends.
    """

    # Sample data for an sbuild task.
    SAMPLE_SBUILD_DATA = {
        "input": {
            "source_package_url": (
                "http://deb.debian.org/pool/pillow_8.1.2+dfsg-0.3.dsc"
            )
        },
        "distribution": "bullseye",
        "host_architecture": "amd64",
        "build_components": [
            "any",
            "all",
        ],
        "sbuild_options": [
            "--post-build-commands=/usr/local/bin/post-process %SBUILD_CHANGES",
        ],
    }

    # Message received from the consumer asking for dynamic metadata.
    REQUEST_DYNAMIC_METADATA_PAYLOAD = {
        "type": "websocket.send",
        "text": "request_dynamic_metadata",
    }

    # Message received from the consumer announcing available work.
    WORK_REQUEST_AVAILABLE_PAYLOAD = {
        "type": "websocket.send",
        "text": "work_request_available",
    }

    @database_sync_to_async
    def acreate_worker(self):
        """Async version of create_worker(): return a new Worker."""
        return self.create_worker()

    def create_worker(self):
        """
        Return a new Worker.

        The worker ("computer.lan") uses a freshly created enabled token and
        has sample dynamic metadata set.
        """
        token = self.create_token_enabled()

        worker = Worker.objects.create_with_fqdn("computer.lan", token)

        worker.set_dynamic_metadata(
            {
                "cpu_cores": 4,
                "sbuild:version": 1,
                "sbuild:chroots": "bullseye-amd64",
            }
        )

        return worker

    async def connect(self, *, token_key):
        """
        Return WebsocketCommunicator connected to ws/1.0/worker/connect/.

        :param token_key: sent in the ``token`` header; when falsy (e.g.
          None) the header is omitted.
        :raises AssertionError: if the websocket handshake is not accepted.
        """
        headers = [(b'User-Agent', b'A user agent')]

        if token_key:
            headers.append((b'token', token_key.encode('utf-8')))

        communicator = WebsocketCommunicator(
            consumers.WorkerConsumer.as_asgi(),
            'ws/1.0/worker/connect/',
            headers=headers,
        )

        # The negotiated subprotocol is not relevant for these tests.
        connected, _ = await communicator.connect()
        self.assertTrue(connected)

        return communicator

    async def connect_to_new_worker(self):
        """
        Return (communicator, worker) for a newly created, connected worker.

        Asserts that the only message received after connecting is the
        request for dynamic metadata and that the worker is marked as
        connected in the database.
        """
        worker = await self.acreate_worker()
        communicator = await self.connect(token_key=worker.token.key)

        # Assert that only a request for dynamic metadata is sent
        await self.assertRequestDynamicMetadata(communicator)
        self.assertTrue(await communicator.receive_nothing())

        await database_sync_to_async(worker.refresh_from_db)()
        self.assertTrue(worker.connected())

        return communicator, worker

    def assertTaskExists(self, task_name):  # pragma: no cover
        """Fail if no task in asyncio.all_tasks() is named task_name."""
        if not any(
            task.get_name() == task_name for task in asyncio.all_tasks()
        ):
            self.fail(f"Asynchronous task '{task_name}' does not exist")

    def assertTaskNotExists(self, task_name):  # pragma: no cover
        """Fail if a task in asyncio.all_tasks() is named task_name."""
        if any(task.get_name() == task_name for task in asyncio.all_tasks()):
            self.fail(f"Asynchronous task '{task_name}' does exist")

    async def assertRequestDynamicMetadata(self, communicator):
        """Assert the next communicator msg is request for dynamic metadata."""
        self.assertEqual(
            await communicator.receive_json_from(),
            self.REQUEST_DYNAMIC_METADATA_PAYLOAD,
        )

    def patch_workerconsumer_REQUEST_DYNAMIC_METADATA_SECONDS_fast(self):
        """
        Patch WorkerConsumer.REQUEST_DYNAMIC_METADATA_SECONDS to be fast.

        The property returns 0 on its first access and 3600 on its second;
        a third access raises StopIteration because the side_effect list is
        exhausted. The patch is undone on test cleanup.

        :return: the PropertyMock replacing the attribute.
        """
        patcher = mock.patch(
            'debusine.server.consumers.'
            'WorkerConsumer.REQUEST_DYNAMIC_METADATA_SECONDS',
            new_callable=PropertyMock,
            side_effect=[0, 3600],
        )

        property_mock = patcher.start()
        self.addCleanup(patcher.stop)

        return property_mock

    async def assert_client_rejected(
        self, token_key, reason, *, reason_code=None
    ):
        """
        Assert that the client is rejected with a reason.

        Connects with token_key, expects a ``connection_rejected`` JSON
        message carrying ``reason`` (and ``reason_code`` only when given),
        followed by a ``websocket.close``.
        """
        communicator = await self.connect(token_key=token_key)

        msg = {
            'type': 'connection_rejected',
            'reason': reason,
        }

        if reason_code is not None:
            msg['reason_code'] = reason_code

        self.assertEqual(await communicator.receive_json_from(), msg)

        self.assertEqual(
            await communicator.receive_output(), {'type': 'websocket.close'}
        )


class WorkerConsumerTransactionTests(WorkerConsumerMixin, TransactionTestCase):
    """
    Tests for WorkerConsumer class.

    Some tests must be implemented in TransactionTestCase because of a bug
    using the database from a TestCase in async.

    See https://code.djangoproject.com/ticket/30448 and
    https://github.com/django/channels/issues/1091#issuecomment-701361358
    """

    async def test_disconnect(self):
        """Disconnect marks the Worker as disconnected."""
        communicator, worker = await self.connect_to_new_worker()

        await communicator.disconnect()

        await database_sync_to_async(worker.refresh_from_db)()

        self.assertFalse(worker.connected())

    async def assert_messages_send_on_connect(self, token_key, expected_msgs):
        """
        Connect authenticating with token_key and assert expected_msgs arrive.

        Messages are read until receive_json_from() times out. The order of
        the received messages is not enforced, but received and expected
        messages must contain the same items the same number of times.
        The communicator is disconnected afterwards.
        """
        communicator = await self.connect(token_key=token_key)

        received_messages = []

        while True:
            try:
                received_messages.append(await communicator.receive_json_from())
            except asyncio.exceptions.TimeoutError:
                # No further message arrived within the receive timeout.
                break

        # Order-insensitive comparison that also checks multiplicities.
        self.assertCountEqual(received_messages, expected_msgs)

        await communicator.disconnect()

    async def test_connect_valid_token(self):
        """Connect succeeds and a request for dynamic metadata is received."""
        worker = await self.acreate_worker()

        await self.assert_messages_send_on_connect(
            worker.token.key, [self.REQUEST_DYNAMIC_METADATA_PAYLOAD]
        )

    async def test_connected_send_work_request_available_running(self):
        """
        Connect succeeds and consumer sends the expected two messages.

        The worker has a RUNNING work request: besides the request for
        dynamic metadata, work_request_available is expected.
        """
        worker = await self.acreate_worker()

        await database_sync_to_async(WorkRequest.objects.create)(
            status=WorkRequest.Statuses.RUNNING, worker=worker
        )

        await self.assert_messages_send_on_connect(
            worker.token.key,
            [
                self.REQUEST_DYNAMIC_METADATA_PAYLOAD,
                self.WORK_REQUEST_AVAILABLE_PAYLOAD,
            ],
        )

    async def test_connected_send_work_request_available_pending(self):
        """
        Connect succeeds and consumer sends the expected two messages.

        The worker has a PENDING work request: besides the request for
        dynamic metadata, work_request_available is expected.
        """
        worker = await self.acreate_worker()

        await database_sync_to_async(WorkRequest.objects.create)(
            status=WorkRequest.Statuses.PENDING, worker=worker
        )

        await self.assert_messages_send_on_connect(
            worker.token.key,
            [
                self.REQUEST_DYNAMIC_METADATA_PAYLOAD,
                self.WORK_REQUEST_AVAILABLE_PAYLOAD,
            ],
        )

    async def test_connected_token_without_worker(self):
        """
        Connect succeeds and an error message is returned.

        Reason of disconnection: token did not have a worker associated.
        """
        token = await database_sync_to_async(Token.objects.create)()

        await database_sync_to_async(token.enable)()

        # Assert that the server rejects a token with no associated worker
        await self.assert_client_rejected(
            token.key,
            f'Token without associated worker (token: "{token.key}")',
        )

    async def test_connected_token_not_enabled(self):
        """
        Connect succeeds and an error message is returned.

        Reason of disconnection: token was not enabled.
        """
        token = await database_sync_to_async(Token.objects.create)()
        await database_sync_to_async(Worker.objects.create)(
            registered_at=timezone.now(), token=token
        )

        await self.assert_client_rejected(
            token.key,
            f'Token is not enabled (token: "{token.key}")',
            reason_code='TOKEN_DISABLED',
        )

    async def test_disconnect_cancel_request_dynamic_metadata(self):
        """Disconnect cancels the dynamic metadata refresher."""
        self.assertTaskNotExists('dynamic_metadata_refresher')

        communicator, worker = await self.connect_to_new_worker()

        self.assertTaskExists('dynamic_metadata_refresher')

        await communicator.disconnect()

        self.assertTaskNotExists('dynamic_metadata_refresher')

    async def test_disconnect_leaves_work_request_as_running(self):
        """Worker disconnect does not interrupt the associated WorkRequest."""
        communicator, worker = await self.connect_to_new_worker()

        work_request: WorkRequest = await database_sync_to_async(
            WorkRequest.objects.create
        )(task_name="sbuild")

        await database_sync_to_async(work_request.assign_worker)(worker)
        await database_sync_to_async(work_request.mark_running)()

        self.assertEqual(work_request.status, work_request.Statuses.RUNNING)

        await communicator.disconnect()

        await database_sync_to_async(worker.refresh_from_db)()

        await database_sync_to_async(work_request.refresh_from_db)()

        self.assertFalse(worker.connected())

        # Assert that work_request is left as running. The Worker
        # is expected to re-connect and pick it up or to update the
        # status to COMPLETED (via the API).
        #
        # During the execution of a task, debusine-server will send
        # pings to the debusine-worker (via daphne defaults). debusine-worker
        # will not be able to respond until the task has finished.
        # This causes debusine server to disconnect the worker during
        # the execution of the task.
        #
        # If the status of the work_request changed from RUNNING to
        # PENDING (for example) when debusine-worker tries to update
        # the status to COMPLETED, debusine-server would reject it because
        # a work-request cannot transition from PENDING to COMPLETED.
        self.assertEqual(work_request.status, work_request.Statuses.RUNNING)

    async def test_request_dynamic_metadata_after_connect(self):
        """Debusine sends request_dynamic_metadata again after connect."""
        worker = await self.acreate_worker()

        self.patch_workerconsumer_REQUEST_DYNAMIC_METADATA_SECONDS_fast()

        communicator = await self.connect(token_key=worker.token.key)

        # debusine sends a request for dynamic metadata because the worker
        # just connected
        await self.assertRequestDynamicMetadata(communicator)

        # sends another one because in this test the method
        # self.patch_workerconsumer_REQUEST_DYNAMIC_METADATA_SECONDS_fast()
        # was called and there is no waiting time
        await self.assertRequestDynamicMetadata(communicator)

        # Nothing else is received for now (the next request for dynamic
        # metadata would happen in 3600 seconds, as per
        # self.patch_workerconsumer_REQUEST_DYNAMIC_METADATA_SECONDS_fast())
        self.assertTrue(await communicator.receive_nothing())

    async def test_connect_redis_not_available(self):
        """
        Connect succeeds and an error message is returned.

        For example the Redis server was not available at the time of
        connection.
        """
        patcher = mock.patch(
            'debusine.server.consumers.WorkerConsumer._channel_layer_group_add',
            side_effect=OSError,
        )
        mocked = patcher.start()
        self.addCleanup(patcher.stop)

        token = await database_sync_to_async(self.create_token_enabled)()
        await database_sync_to_async(Worker.objects.create)(
            registered_at=timezone.now(), token=token
        )

        await self.assert_client_rejected(token.key, 'Service unavailable')
        mocked.assert_called()

    async def test_worker_disabled(self):
        """Disabling the worker sends a connection_closed msg to the worker."""
        communicator, worker = await self.connect_to_new_worker()

        await communicator.send_input({"type": "worker.disabled"})

        key = await database_sync_to_async(lambda: worker.token.key)()

        self.assertEqual(
            await communicator.receive_json_from(),
            {
                "type": "connection_closed",
                "reason": f'Token has been disabled (token: "{key}")',
                "reason_code": "TOKEN_DISABLED",
            },
        )

    async def test_work_request_assigned(self):
        """Assigning a work request sends work_request_available."""
        communicator, worker = await self.connect_to_new_worker()

        def assign_work_request(worker):
            work_request = WorkRequest.objects.create(task_name="sbuild")
            work_request.assign_worker(worker)

        await database_sync_to_async(assign_work_request)(worker)

        self.assertEqual(
            await communicator.receive_json_from(),
            self.WORK_REQUEST_AVAILABLE_PAYLOAD,
        )


class WorkerConsumerTests(WorkerConsumerMixin, TestCase):
    """WorkerConsumer tests that can run inside a regular TestCase."""

    async def test_connect_invalid_token(self):
        """
        Connection is accepted, then rejected with an error message.

        Cause: the token sent by the client is not in the database.
        """
        expected_reason = 'Token not found: "does-not-exist"'
        await self.assert_client_rejected('does-not-exist', expected_reason)

    async def test_connect_without_token(self):
        """
        Connection is accepted, then rejected with an error message.

        Cause: the client did not send the 'token' header.
        """
        expected_reason = 'Missing required header: "token"'
        await self.assert_client_rejected(None, expected_reason)

    def test_get_header_value_key_found(self):
        """_get_header_value decodes and returns the header value (str)."""
        expected = '1fb371ea69dca7b'
        consumer = WorkerConsumer()
        consumer.scope = {'headers': [(b'token', expected.encode('utf-8'))]}

        value = consumer._get_header_value('token')

        self.assertEqual(value, expected)

    def test_get_header_value_key_not_found(self):
        """_get_header_value returns None when the header is absent."""
        consumer = WorkerConsumer()
        consumer.scope = {'headers': []}

        value = consumer._get_header_value('token')

        self.assertIsNone(value)

    async def test_disconnect_while_not_connected(self):
        """
        WorkerConsumer.disconnect tolerates a consumer that never connected.

        Calling disconnect() on a freshly built WorkerConsumer must complete
        without raising.
        """
        consumer = WorkerConsumer()
        await consumer.disconnect(0)
