# Copyright 2019, 2021-2022 The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Unit tests for the models."""
from datetime import datetime

from django.test import TestCase
from django.utils import timezone

from debusine.db.models import (
    Token,
    WorkRequest,
    Worker,
    WorkerManager,
)
from debusine.test.django import ChannelsHelpersMixin


class TokenTests(TestCase, ChannelsHelpersMixin):
    """Unit tests of the ``Token`` model."""

    def test_save(self):
        """The model creates a Token.key on save or keeps it if it existed."""
        token = Token.objects.create(owner='John')

        self.assertIsNotNone(token.id)
        self.assertEqual(len(token.key), 64)

        # Saving an existing token must not regenerate its key
        key = token.key
        token.save()
        self.assertEqual(token.key, key)

    def test_str(self):
        """Token.__str__ returns the token key."""
        token = Token.objects.create(owner='John')
        # Go through str() so the test exercises the same path as callers
        self.assertEqual(str(token), token.key)

    def test_get_token_or_none_found(self):
        """get_token_or_none looks up a token and returns it."""
        token = Token.objects.create(key='some_key')

        self.assertEqual(token, Token.objects.get_token_or_none('some_key'))

    def test_get_token_or_none_not_found(self):
        """get_token_or_none cannot find a token and returns None."""
        self.assertIsNone(Token.objects.get_token_or_none('a_non_existing_key'))

    def test_enable(self):
        """enable() enables the token."""
        token = Token.objects.create()

        # Newly created tokens are disabled by default
        self.assertFalse(token.enabled)

        token.enable()
        token.refresh_from_db()

        self.assertTrue(token.enabled)

    def test_disable(self):
        """disable() disables the token and notifies the worker's channel."""
        token = Token.objects.create(enabled=True)
        Worker.objects.create(token=token, registered_at=timezone.now())

        channel = self.create_channel(token.key)

        token.disable()

        self.assert_channel_received(channel, {"type": "worker.disabled"})
        token.refresh_from_db()

        self.assertFalse(token.enabled)


class TokenManagerTests(TestCase):
    """Unit tests for the ``TokenManager`` class."""

    @classmethod
    def setUpTestData(cls):
        """Create two tokens (owned by John and by Bev) shared by the tests."""
        cls.token_john = Token.objects.create(owner='John')
        cls.token_bev = Token.objects.create(owner='Bev')

    def test_get_tokens_all(self):
        """get_tokens returns every token when called without filters."""
        all_tokens = Token.objects.get_tokens()
        self.assertQuerysetEqual(
            all_tokens.order_by('owner'),
            [self.token_bev, self.token_john],
        )

    def test_get_tokens_by_owner(self):
        """get_tokens filtered by owner returns only that owner's tokens."""
        expected_by_owner = {
            'John': [self.token_john],
            'Bev': [self.token_bev],
            'Someone': [],
        }
        for owner, expected in expected_by_owner.items():
            self.assertQuerysetEqual(
                Token.objects.get_tokens(owner=owner), expected
            )

    def test_get_tokens_by_key(self):
        """get_tokens filtered by key returns the matching token, if any."""
        cases = (
            (self.token_john.key, [self.token_john]),
            ('non-existing-key', []),
        )
        for key, expected in cases:
            self.assertQuerysetEqual(
                Token.objects.get_tokens(key=key), expected
            )

    def test_get_tokens_by_key_owner_empty(self):
        """
        get_tokens returns nothing when key and owner belong to different tokens.

        Both the key and the owner exist on their own, but no single token
        has that combination, so the combined filter matches nothing.
        """
        mismatched = Token.objects.get_tokens(
            key=self.token_john.key, owner=self.token_bev.owner
        )
        self.assertQuerysetEqual(mismatched, [])


class WorkerManagerTests(TestCase):
    """Tests for the WorkerManager."""

    def test_connected(self):
        """WorkerManager.connected() returns only the connected Workers."""
        worker_connected = Worker.objects.create_with_fqdn(
            'connected-worker', Token.objects.create()
        )

        worker_connected.mark_connected()

        # This worker is never marked as connected
        Worker.objects.create_with_fqdn(
            'not-connected-worker',
            Token.objects.create(),
        )

        self.assertQuerysetEqual(Worker.objects.connected(), [worker_connected])

    def test_waiting_for_work_request(self):
        """WorkerManager.waiting_for_work_request() returns idle Workers."""
        worker = Worker.objects.create_with_fqdn(
            'worker-a', Token.objects.create(enabled=True)
        )

        # WorkerManager.waiting_for_work_request: returns no workers because
        # the worker is not connected
        self.assertQuerysetEqual(Worker.objects.waiting_for_work_request(), [])

        worker.mark_connected()

        # Now the Worker is ready to have a task assigned
        self.assertQuerysetEqual(
            Worker.objects.waiting_for_work_request(), [worker]
        )

        # A task is assigned to the worker
        work_request = WorkRequest.objects.create(worker=worker)

        # The worker is not ready (it is busy with a task assigned)
        self.assertQuerysetEqual(Worker.objects.waiting_for_work_request(), [])

        # The task finished
        work_request.status = WorkRequest.Statuses.COMPLETED
        work_request.save()

        # The worker is ready again: its assigned WorkRequest has finished
        self.assertQuerysetEqual(
            Worker.objects.waiting_for_work_request(), [worker]
        )

    def test_waiting_for_work_request_no_return_disabled_workers(self):
        """WorkerManager.waiting_for_work_request() skips disabled Workers."""
        worker_enabled = Worker.objects.create_with_fqdn(
            "worker-enabled", Token.objects.create(enabled=True)
        )
        worker_enabled.mark_connected()

        # Connected, but its token is disabled: must not be returned
        worker_disabled = Worker.objects.create_with_fqdn(
            "worker-disabled", Token.objects.create(enabled=False)
        )
        worker_disabled.mark_connected()

        self.assertQuerysetEqual(
            Worker.objects.waiting_for_work_request(), [worker_enabled]
        )

    def test_create_with_fqdn_new_fqdn(self):
        """WorkerManager.create_with_fqdn() returns a saved Worker."""
        token = Token.objects.create()
        worker = Worker.objects.create_with_fqdn(
            'a-new-and-unique-name', token=token
        )

        self.assertEqual(worker.name, 'a-new-and-unique-name')
        self.assertEqual(worker.token, token)
        self.assertIsNotNone(worker.pk)

    def test_create_with_fqdn_duplicate_fqdn(self):
        """
        WorkerManager.create_with_fqdn() returns a Worker with a unique name.

        The name ends with -2 because 'connected-worker' is already used.
        """
        Worker.objects.create_with_fqdn(
            'connected-worker', token=Token.objects.create()
        )

        token = Token.objects.create()
        worker = Worker.objects.create_with_fqdn(
            'connected-worker', token=token
        )

        self.assertEqual(worker.name, 'connected-worker-2')
        self.assertEqual(worker.token, token)
        self.assertIsNotNone(worker.pk)

    def test_slugify_with_suffix_counter_1(self):
        """WorkerManager._generate_unique_name does not append '-1'."""
        self.assertEqual(
            WorkerManager._generate_unique_name('worker.lan', 1), 'worker-lan'
        )

    def test_slugify_with_suffix_counter_3(self):
        """WorkerManager._generate_unique_name appends '-3'."""
        self.assertEqual(
            WorkerManager._generate_unique_name('worker.lan', 3), 'worker-lan-3'
        )

    def test_get_worker_by_token_or_none_return_none(self):
        """WorkerManager.get_worker_by_token_key_or_none() returns None."""
        self.assertIsNone(
            Worker.objects.get_worker_by_token_key_or_none('non-existing-key')
        )

    def test_get_worker_by_token_or_none_return_worker(self):
        """WorkerManager.get_worker_by_token_key_or_none() returns the Worker."""
        token = Token.objects.create()

        worker = Worker.objects.create_with_fqdn('worker-a', token)

        self.assertEqual(
            Worker.objects.get_worker_by_token_key_or_none(token.key), worker
        )

    def test_get_worker_or_none_return_worker(self):
        """WorkerManager.get_worker_or_none() returns the Worker."""
        token = Token.objects.create()

        worker = Worker.objects.create_with_fqdn('worker-a', token)

        self.assertEqual(Worker.objects.get_worker_or_none('worker-a'), worker)

    def test_get_worker_or_none_return_none(self):
        """WorkerManager.get_worker_or_none() returns None."""
        self.assertIsNone(Worker.objects.get_worker_or_none('does-not-exist'))


class WorkerTests(TestCase):
    """Tests for the Worker model."""

    @classmethod
    def setUpTestData(cls):
        """Set up a Worker with static and dynamic metadata for the tests."""
        cls.worker = Worker.objects.create_with_fqdn(
            "computer.lan", Token.objects.create()
        )
        cls.worker.static_metadata = {"os": "debian"}
        cls.worker.set_dynamic_metadata({"cpu_cores": "4"})
        cls.worker.save()

    def test_mark_connected(self):
        """Test mark_connected method."""
        time_before = timezone.now()
        self.assertIsNone(self.worker.connected_at)

        self.worker.mark_connected()

        # connected_at must be set to a timestamp taken during the call
        self.assertGreaterEqual(self.worker.connected_at, time_before)
        self.assertLessEqual(self.worker.connected_at, timezone.now())

    def test_mark_disconnected(self):
        """Test mark_disconnected method."""
        self.worker.mark_connected()

        self.assertTrue(self.worker.connected())
        self.worker.mark_disconnected()

        self.assertFalse(self.worker.connected())
        self.assertIsNone(self.worker.connected_at)

    def test_connected(self):
        """Test connected method: it reflects whether connected_at is set."""
        self.assertFalse(self.worker.connected())

        self.worker.connected_at = timezone.now()

        self.assertTrue(self.worker.connected())

    def test_metadata_no_conflict(self):
        """Test metadata method: return all the metadata."""
        self.assertEqual(
            self.worker.metadata(), {'cpu_cores': '4', 'os': 'debian'}
        )

    def test_metadata_with_conflict(self):
        """
        Test metadata method: return all the metadata.

        static_metadata has priority over dynamic_metadata
        """
        # Assert initial state
        self.assertEqual(self.worker.dynamic_metadata['cpu_cores'], '4')

        # Add a static_metadata key that clashes with a dynamic one
        self.worker.static_metadata['cpu_cores'] = '8'

        self.assertEqual(
            self.worker.metadata(), {'cpu_cores': '8', 'os': 'debian'}
        )

    def test_metadata_is_deep_copy(self):
        """Test metadata does a deep copy."""
        self.worker.dynamic_metadata['schroots'] = ['buster', 'bullseye']
        # Mutating the returned structure must not leak into the Worker
        self.worker.metadata()['schroots'].append('bookworm')

        self.assertEqual(
            self.worker.dynamic_metadata['schroots'], ['buster', 'bullseye']
        )

    def test_set_dynamic_metadata(self):
        """Worker.set_dynamic_metadata sets the dynamic metadata."""
        # Reset the state created in setUpTestData
        self.worker.dynamic_metadata = {}
        self.worker.dynamic_metadata_updated_at = None
        self.worker.save()

        dynamic_metadata = {"cpu_cores": 4, "ram": "16"}
        self.worker.set_dynamic_metadata(dynamic_metadata)

        self.worker.refresh_from_db()

        self.assertEqual(self.worker.dynamic_metadata, dynamic_metadata)
        self.assertLessEqual(
            self.worker.dynamic_metadata_updated_at, timezone.now()
        )

    def test_str(self):
        """Test Worker.__str__."""
        self.assertEqual(str(self.worker), 'computer-lan')


class WorkRequestManagerTests(TestCase):
    """Tests for WorkRequestManager."""

    def test_pending(self):
        """WorkRequestManager.pending() returns pending WorkRequests."""
        pending_status = WorkRequest.Statuses.PENDING
        older = WorkRequest.objects.create(status=pending_status)
        newer = WorkRequest.objects.create(status=pending_status)
        # Aborted requests must not be listed as pending
        WorkRequest.objects.create(status=WorkRequest.Statuses.ABORTED)

        self.assertQuerysetEqual(WorkRequest.objects.pending(), [older, newer])

        # Refreshing created_at of the first request moves it to the end
        older.created_at = timezone.now()
        older.save()

        self.assertQuerysetEqual(WorkRequest.objects.pending(), [newer, older])

    def test_pending_filter_by_worker(self):
        """WorkRequestManager.pending() returns WorkRequest for the worker."""
        # Pending but not assigned to any worker: must be filtered out
        WorkRequest.objects.create(
            status=WorkRequest.Statuses.PENDING, task_name="sbuild"
        )

        worker = Worker.objects.create_with_fqdn(
            "computer.lan", Token.objects.create()
        )
        assigned = WorkRequest.objects.create(
            status=WorkRequest.Statuses.PENDING, worker=worker
        )

        result = WorkRequest.objects.pending(worker=worker)
        self.assertQuerysetEqual(result, [assigned])

    def test_raise_value_error_exclude_assigned_and_worker(self):
        """WorkRequestManager.pending() raises ValueError."""
        worker = Worker.objects.create_with_fqdn(
            "computer.lan", Token.objects.create()
        )
        expected_message = "Cannot exclude_assigned and filter by worker"

        with self.assertRaisesRegex(ValueError, expected_message):
            WorkRequest.objects.pending(exclude_assigned=True, worker=worker)

    def test_pending_exclude_assigned(self):
        """
        Test WorkRequestManager.pending(exclude_assigned=True).

        Work requests that already have a worker are left out of the result.
        """
        pending = WorkRequest.objects.pending
        work_request = WorkRequest.objects.create(
            status=WorkRequest.Statuses.PENDING, task_name="sbuild"
        )

        # While unassigned, the request is part of the result
        self.assertQuerysetEqual(pending(exclude_assigned=True), [work_request])

        # Hand the request over to a freshly created worker
        work_request.worker = Worker.objects.create_with_fqdn(
            'test', token=Token.objects.create()
        )
        work_request.save()

        # Now exclude_assigned=True hides it...
        self.assertQuerysetEqual(pending(exclude_assigned=True), [])

        # ...while exclude_assigned=False still lists it
        self.assertQuerysetEqual(
            pending(exclude_assigned=False), [work_request]
        )

    def test_running(self):
        """WorkRequestManager.running() returns running WorkRequests."""
        running = WorkRequest.objects.create(
            status=WorkRequest.Statuses.RUNNING
        )
        WorkRequest.objects.create(status=WorkRequest.Statuses.ABORTED)

        self.assertQuerysetEqual(WorkRequest.objects.running(), [running])

    def test_running_or_pending_exists(self):
        """Generic test for WorkRequestManager.running_or_pending_exists()."""
        worker = Worker.objects.create_with_fqdn(
            "test", token=Token.objects.create()
        )

        def exists():
            return WorkRequest.objects.running_or_pending_exists(worker)

        # No work request has been created for the worker yet
        self.assertFalse(exists())

        work_request = WorkRequest.objects.create(worker=worker)

        # A freshly created request starts out as PENDING
        self.assertEqual(work_request.status, WorkRequest.Statuses.PENDING)
        self.assertTrue(exists())

        work_request.assign_worker(worker)
        self.assertTrue(exists())

        work_request.mark_running()
        self.assertTrue(exists())

        # Once aborted the request is neither running nor pending
        work_request.mark_aborted()
        self.assertFalse(exists())

    def test_completed(self):
        """WorkRequestManager.completed() returns completed WorkRequests."""
        completed = WorkRequest.objects.create(
            status=WorkRequest.Statuses.COMPLETED
        )
        WorkRequest.objects.create(status=WorkRequest.Statuses.RUNNING)

        self.assertQuerysetEqual(WorkRequest.objects.completed(), [completed])

    def test_aborted(self):
        """WorkRequestManager.aborted() returns aborted WorkRequests."""
        aborted = WorkRequest.objects.create(
            status=WorkRequest.Statuses.ABORTED
        )
        WorkRequest.objects.create(status=WorkRequest.Statuses.RUNNING)

        self.assertQuerysetEqual(WorkRequest.objects.aborted(), [aborted])


class WorkRequestTests(TestCase, ChannelsHelpersMixin):
    """Tests for the WorkRequest class."""

    def setUp(self):
        """Create a pending WorkRequest assigned to a new Worker."""
        worker = Worker.objects.create_with_fqdn(
            "computer.lan", Token.objects.create()
        )

        self.work_request = WorkRequest.objects.create(
            task_name='request-01',
            worker=worker,
            status=WorkRequest.Statuses.PENDING,
        )

    def test_str(self):
        """Test WorkRequest.__str__ returns WorkRequest.task_name."""
        self.assertEqual(str(self.work_request), 'request-01')

    def test_mark_running_from_aborted(self):
        """Test WorkRequest.mark_running() doesn't change (was aborted)."""
        self.work_request.status = WorkRequest.Statuses.ABORTED
        self.work_request.save()

        self.assertFalse(self.work_request.mark_running())

        self.work_request.refresh_from_db()
        self.assertEqual(self.work_request.status, WorkRequest.Statuses.ABORTED)

    def test_mark_running(self):
        """Test WorkRequest.mark_running() change status to running."""
        self.work_request.status = WorkRequest.Statuses.PENDING
        self.work_request.save()
        self.assertIsNone(self.work_request.started_at)

        self.assertTrue(self.work_request.mark_running())

        self.work_request.refresh_from_db()
        self.assertEqual(self.work_request.status, WorkRequest.Statuses.RUNNING)
        # LessEqual: on coarse clocks the two timestamps may be identical
        self.assertLessEqual(self.work_request.started_at, timezone.now())

        # Marking as running again (a running WorkRequest) is a no-op
        started_at = self.work_request.started_at
        self.assertTrue(self.work_request.mark_running())

        self.work_request.refresh_from_db()
        self.assertEqual(self.work_request.status, WorkRequest.Statuses.RUNNING)
        self.assertEqual(self.work_request.started_at, started_at)

    def test_mark_running_fails_worker_already_running(self):
        """WorkRequest.mark_running() return False: worker already running."""
        self.work_request.status = WorkRequest.Statuses.PENDING
        self.work_request.save()

        # Another request is already running on the same worker
        WorkRequest.objects.create(
            status=WorkRequest.Statuses.RUNNING, worker=self.work_request.worker
        )

        self.assertFalse(self.work_request.mark_running())

    def test_mark_running_fails_no_assigned_worker(self):
        """WorkRequest.mark_running() return False: no worker assigned."""
        self.work_request.status = WorkRequest.Statuses.PENDING
        self.work_request.worker = None
        self.work_request.save()

        self.assertFalse(self.work_request.mark_running())

    def test_mark_completed_from_aborted(self):
        """Test WorkRequest.mark_completed() doesn't change (was aborted)."""
        self.work_request.status = WorkRequest.Statuses.ABORTED
        self.work_request.save()

        self.assertFalse(
            self.work_request.mark_completed(WorkRequest.Results.SUCCESS)
        )

        self.work_request.refresh_from_db()
        self.assertEqual(self.work_request.status, WorkRequest.Statuses.ABORTED)

    def test_mark_completed(self):
        """Test WorkRequest.mark_completed() changes status to completed."""
        self.work_request.status = WorkRequest.Statuses.RUNNING
        self.work_request.save()

        self.assertIsNone(self.work_request.completed_at)
        self.assertEqual(self.work_request.result, WorkRequest.Results.NONE)

        self.assertTrue(
            self.work_request.mark_completed(WorkRequest.Results.SUCCESS)
        )

        self.work_request.refresh_from_db()
        self.assertEqual(
            self.work_request.status, WorkRequest.Statuses.COMPLETED
        )
        self.assertEqual(self.work_request.result, WorkRequest.Results.SUCCESS)
        self.assertLessEqual(self.work_request.completed_at, timezone.now())

    def test_mark_aborted(self):
        """Test WorkRequest.mark_aborted() changes status to aborted."""
        self.assertIsNone(self.work_request.completed_at)

        self.assertTrue(self.work_request.mark_aborted())

        self.work_request.refresh_from_db()
        self.assertEqual(self.work_request.status, WorkRequest.Statuses.ABORTED)
        self.assertLessEqual(self.work_request.completed_at, timezone.now())

    def test_assign_worker(self):
        """Assign Worker to WorkRequest."""
        worker = self.work_request.worker

        self.work_request.worker = None
        self.work_request.save()

        self.work_request.refresh_from_db()

        # Initial status (no worker)
        self.assertIsNone(self.work_request.worker)
        self.assertEqual(self.work_request.status, WorkRequest.Statuses.PENDING)

        channel = self.create_channel(worker.token.key)

        # Assign the worker to the WorkRequest
        self.work_request.assign_worker(worker)

        # The worker is notified through its channel
        self.assert_channel_received(channel, {"type": "work_request.assigned"})

        self.work_request.refresh_from_db()

        # Assert final status: assigned but still pending
        self.assertEqual(self.work_request.worker, worker)
        self.assertEqual(self.work_request.status, WorkRequest.Statuses.PENDING)

    def test_duration(self):
        """duration is the number of seconds between start and completion."""
        # Use aware datetimes, matching what the model stores
        self.work_request.started_at = timezone.make_aware(
            datetime(2022, 3, 7, 10, 51)
        )
        self.work_request.completed_at = timezone.make_aware(
            datetime(2022, 3, 9, 12, 53)
        )
        duration = self.work_request.completed_at - self.work_request.started_at

        self.assertEqual(
            self.work_request.duration,
            duration.total_seconds(),
        )

    def test_duration_is_none_completed_at_is_none(self):
        """duration is None because completed_at is None."""
        self.work_request.started_at = timezone.make_aware(
            datetime(2022, 3, 7, 10, 51)
        )
        self.work_request.completed_at = None

        self.assertIsNone(self.work_request.duration)

    def test_duration_is_none_started_at_is_none(self):
        """duration is None because started_at is None."""
        self.work_request.started_at = None
        self.work_request.completed_at = timezone.make_aware(
            datetime(2022, 3, 7, 10, 51)
        )

        self.assertIsNone(self.work_request.duration)
