# Copyright 2021-2022 The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Tests for the management command of tokens application."""
import datetime
import io
import os
import re
import tempfile
from unittest.mock import patch

from django.core.management import CommandError, call_command
from django.test import TestCase
from django.utils import timezone

from debusine.db.models import Token, WorkRequest, Worker
from debusine.server.management import utils
from debusine.server.management.commands.edit_worker_metadata import (
    WorkerStaticMetadataEditor,
)
from debusine.server.management.commands.remove_tokens import (
    Command as RemoveTokensCommand,
)


def _call_command(*args):
    """
    Run a management command and capture what it writes and how it exits.

    Return a ``(stdout, stderr, exit_code)`` tuple.  ``exit_code`` is the
    code carried by ``SystemExit`` when the command calls ``sys.exit()``;
    otherwise it is the sentinel string ``"command did not call sys.exit"``.
    Any other exception raised by the command propagates to the caller.
    """
    out_buffer, err_buffer = io.StringIO(), io.StringIO()

    try:
        call_command(*args, stdout=out_buffer, stderr=err_buffer)
    except SystemExit as system_exit:
        code = system_exit.code
    else:
        code = "command did not call sys.exit"

    return out_buffer.getvalue(), err_buffer.getvalue(), code


class ListTokensCommandTests(TestCase):
    """Exercise the ``list_tokens`` management command."""

    def test_list_tokens_no_filtering(self):
        """Without filters, list_tokens shows every column of a token."""
        token = Token.objects.create(owner='John', comment='Test comment')

        output = _call_command('list_tokens')[0]

        for fragment in (
            token.key,
            token.owner,
            token.created_at.isoformat(),
        ):
            self.assertIn(fragment, output)
        # The enabled flag is printed exactly once (one token listed).
        self.assertEqual(output.count(str(token.enabled)), 1)
        self.assertIn(token.comment, output)

    def test_list_tokens_filtered_by_owner(self):
        """With --owner, list_tokens only shows that owner's tokens."""
        john_token = Token.objects.create(owner='John')
        bev_token = Token.objects.create(owner='Bev')

        output = _call_command('list_tokens', '--owner', 'John')[0]

        self.assertIn(john_token.key, output)
        self.assertNotIn(bev_token.key, output)


class RemoveTokensCommandTests(TestCase):
    """Tests for the remove_tokens command."""

    def test_remove_tokens(self):
        """remove_tokens deletes the token and prints the deleted key."""
        token = Token.objects.create()

        stdout, _, _ = _call_command(
            'remove_tokens', '--yes', '--token', token.key
        )

        self.assertEqual(Token.objects.filter(key=token.key).count(), 0)
        self.assertIn(token.key, stdout)

    def test_remove_tokens_no_tokens(self):
        """
        remove_tokens fails when no token matches.

        It raises CommandError with returncode 3 and an explanatory message.
        """
        with self.assertRaises(CommandError) as cm:
            _call_command(
                'remove_tokens',
                '--token',
                '9deefd915ecd1009bf7598c1d4acf9a3bbfb8bd9e0c08b4bdc9',
            )

        self.assertEqual(
            cm.exception.args[0], 'There are no tokens to be removed'
        )
        self.assertEqual(cm.exception.returncode, 3)

    def test_remove_tokens_no_tokens_force(self):
        """With --force, remove_tokens only reports that nothing matched."""
        stdout, _, _ = _call_command(
            'remove_tokens',
            '--force',
            '--token',
            '9deefd915ecd1009bf7598c1d4acf9a3bbfb8bd9e0c08b4bdc93d099b9a38aa2',
        )

        self.assertEqual('There are no tokens to be removed\n', stdout)

    def test_remove_tokens_confirmation(self):
        """remove_tokens doesn't delete the token (user does not confirm)."""
        token = Token.objects.create()

        cmd = RemoveTokensCommand()
        # Simulate the user answering "N" at the confirmation prompt.
        cmd.input_file = io.StringIO('N\n')

        _call_command(cmd, '--token', token.key)

        self.assertQuerysetEqual(Token.objects.filter(key=token.key), [token])

    def test_remove_tokens_confirmed(self):
        """remove_tokens deletes the token (confirmed by the user)."""
        token = Token.objects.create()

        cmd = RemoveTokensCommand()
        # Simulate the user answering "y" at the confirmation prompt.
        cmd.input_file = io.StringIO('y\n')

        _call_command(cmd, '--token', token.key)

        self.assertEqual(Token.objects.filter(key=token.key).count(), 0)


class CreateTokenCommandTest(TestCase):
    """Exercise the ``create_token`` management command."""

    def test_create_token(self):
        """create_token stores an enabled token and prints only its key."""
        output = _call_command('create_token', 'james')[0]

        created_token = Token.objects.first()

        self.assertEqual(output, created_token.key + '\n')
        self.assertTrue(created_token.enabled)


class ListWorkersCommandTests(TestCase):
    """Exercise the ``list_workers`` management command."""

    def test_list_workers_connected(self):
        """
        list_workers reports a worker that is connected.

        Name, registration and connection timestamps, token key and token
        state all appear in the output.
        """
        token = Token.objects.create()

        worker = Worker.objects.create_with_fqdn('recent-ping', token)
        worker.mark_connected()

        output, _, _ = _call_command('list_workers')

        expected_fragments = [
            worker.name,
            worker.registered_at.isoformat(),
            worker.connected_at.isoformat(),
            token.key,
        ]
        for fragment in expected_fragments:
            self.assertIn(fragment, output)
        self.assertEqual(output.count(str(token.enabled)), 1)
        self.assertIn('Number of workers: 1', output)

    def test_list_workers_not_connected(self):
        """
        list_workers reports a worker that has not connected.

        The connection column is rendered as "-".
        """
        token = Token.objects.create()

        worker = Worker.objects.create_with_fqdn('recent-ping', token=token)
        output, _, _ = _call_command('list_workers')

        registered_at = worker.registered_at.isoformat()

        self.assertIn(worker.name, output)
        self.assertIn(registered_at, output)

        # The column after "registered at" holds a "-" placeholder
        self.assertIn(f'{registered_at}  -', output)

        self.assertIn(token.key, output)
        self.assertIn('Number of workers: 1', output)


class ListWorkRequestsCommandTests(TestCase):
    """Test for listing work requests."""

    def setUp(self):
        """Create a Work Request (not assigned to any worker)."""
        self.work_request = WorkRequest.objects.create(task_name="sbuild")

    def test_list_work_requests_not_assigned(self):
        """Test a non-assigned work request output."""
        stdout, _, _ = _call_command("list_work_requests")

        self.assertIn(self.work_request.status, stdout)

        self.assertIn(
            utils.datetime_to_isoformat(self.work_request.created_at), stdout
        )

    def test_list_work_requests_assigned_finished(self):
        """Test an assigned, completed work request output."""
        worker = Worker.objects.create_with_fqdn(
            "neptune", Token.objects.create()
        )
        self.work_request.worker = worker
        self.work_request.created_at = timezone.now()
        # Space the timestamps one second apart so each is distinct.
        one_sec = datetime.timedelta(seconds=1)
        self.work_request.started_at = self.work_request.created_at + one_sec
        self.work_request.completed_at = self.work_request.started_at + one_sec
        self.work_request.result = WorkRequest.Results.SUCCESS
        self.work_request.save()

        stdout, _, _ = _call_command("list_work_requests")

        self.assertIn("neptune", stdout)
        self.assertIn(
            utils.datetime_to_isoformat(self.work_request.created_at), stdout
        )
        self.assertIn(
            utils.datetime_to_isoformat(self.work_request.started_at), stdout
        )
        self.assertIn(
            utils.datetime_to_isoformat(self.work_request.completed_at), stdout
        )
        self.assertIn(self.work_request.result, stdout)

        self.assertIn('Number of work requests: 1', stdout)


class EditWorkerMetadataTests(TestCase):
    """Test for edit worker metadata command."""

    def setUp(self):
        """Create a default Token and Worker."""
        self.token = Token.objects.create()
        self.worker = Worker.objects.create_with_fqdn(
            "worker-01.lan", self.token
        )

    def test_set_worker_metadata(self):
        """Set the worker metadata from the YAML file."""
        with tempfile.NamedTemporaryFile(suffix='.yml') as static_metadata_file:
            static_metadata_file.write(
                b'sbuild:\n  architectures:\n    - amd64\n    - i386\n'
                b'  distributions:\n    - jessie\n    - stretch'
            )
            # Make sure the contents hit the disk: the command re-opens
            # the file by name.
            static_metadata_file.flush()

            _call_command(
                'edit_worker_metadata',
                '--set',
                static_metadata_file.name,
                self.worker.name,
            )

            self.worker.refresh_from_db()

            self.assertEqual(
                self.worker.static_metadata,
                {
                    'sbuild': {
                        'architectures': ['amd64', 'i386'],
                        'distributions': ['jessie', 'stretch'],
                    }
                },
            )

    def test_set_worker_metadata_worker_does_not_exist(self):
        """Error message in stderr if the worker name does not exist."""
        worker_name = 'name-of-the-worker-does-not-exist.lan'

        stdout, stderr, exit_code = _call_command(
            'edit_worker_metadata',
            '--set',
            '/tmp/some_file_not_accessed_because_worker_does_not_exist.yaml',
            worker_name,
        )

        self.assertEqual(
            stderr,
            f"Error: worker '{worker_name}' is not registered\n"
            "Use the command 'list_workers' to list the "
            "existing workers\n",
        )
        self.assertEqual(exit_code, 3)

    def test_set_worker_metadata_file_not_found(self):
        """Error message in stderr if the YAML file is not found."""
        file_path = '/tmp/does_not_exist_file.yaml'

        stdout, stderr, exit_code = _call_command(
            'edit_worker_metadata', '--set', file_path, self.worker.name
        )

        self.assertEqual(
            stderr,
            "Error: cannot open worker configuration "
            f"file '{file_path}': [Errno 2] No such file or "
            f"directory: '{file_path}'\n",
        )
        self.assertEqual(exit_code, 3)

    def test_set_worker_metadata_cannot_open_file(self):
        """Error message in stderr if the YAML file cannot be opened."""
        # Passing a directory as the "file" makes open() fail.
        with tempfile.TemporaryDirectory() as temp_directory:
            stdout, stderr, exit_code = _call_command(
                'edit_worker_metadata',
                '--set',
                temp_directory,
                self.worker.name,
            )

        self.assertEqual(
            stderr,
            f"Error: cannot open worker configuration file '{temp_directory}': "
            f"[Errno 21] Is a directory: '{temp_directory}'\n",
        )
        self.assertEqual(exit_code, 3)

    def test_set_worker_metadata_file_is_invalid(self):
        """Error message in stderr if the YAML file is not valid YAML."""
        with tempfile.NamedTemporaryFile(suffix='.yml') as static_metadata_file:
            static_metadata_file.write(b'a\nb:')
            static_metadata_file.flush()

            stdout, stderr, exit_code = _call_command(
                'edit_worker_metadata',
                '--set',
                static_metadata_file.name,
                self.worker.name,
            )

        self.assertRegex(
            stderr, 'Invalid YAML: mapping values are not allowed here.*'
        )
        self.assertEqual(exit_code, 3)

    def test_edit_interactive_worker_valid_yaml(self):
        """Editor changes the contents of Worker's YAML metadata."""
        stdout = io.StringIO()
        stderr = io.StringIO()

        worker_editor = WorkerStaticMetadataEditor(
            self.worker, yaml_file=None, stdout=stdout, stderr=stderr
        )

        def _set_valid_file_contents(file_path):
            with open(file_path, 'w') as file:
                file.write('distributions:\n - buster\n - bullseye')

        # Replace the interactive editor with a function writing the file.
        with patch.object(
            worker_editor, '_open_editor', new=_set_valid_file_contents
        ):
            worker_editor.edit()

        self.worker.refresh_from_db()

        self.assertEqual(
            self.worker.static_metadata,
            {'distributions': ['buster', 'bullseye']},
        )

        self.assertEqual(
            stdout.getvalue(), 'debusine: metadata set for worker-01-lan'
        )
        self.assertEqual(stderr.getvalue(), '')

    def test_edit_worker_metadata_not_dictionary(self):
        """Return error when the metadata file is a string."""
        with tempfile.NamedTemporaryFile(suffix='.yml') as static_metadata_file:
            static_metadata_file.write(b'bullseye')
            static_metadata_file.flush()

            stdout, stderr, exit_code = _call_command(
                'edit_worker_metadata',
                '--set',
                static_metadata_file.name,
                self.worker.name,
            )

            self.assertEqual(
                stderr, 'Worker metadata must be a dictionary or empty\n'
            )
            self.assertEqual(exit_code, 3)

    def test_edit_interactive_worker_invalid_yaml(self):
        """
        Editor changes the contents of Worker's static data to invalid YAML.

        Then it aborts the editing.
        """
        stdout = io.StringIO()
        stderr = io.StringIO()

        self.worker.static_metadata = {'distributions': ['potato']}
        self.worker.save()

        worker_editor = WorkerStaticMetadataEditor(
            self.worker, yaml_file=None, stdout=stdout, stderr=stderr
        )

        def _set_invalid_file_contents(file_path):
            with open(file_path, 'w') as file:
                file.write('"')

        # Answer "n" when asked whether to retry the edit.
        with patch.object(
            worker_editor, '_open_editor', new=_set_invalid_file_contents
        ), patch.object(worker_editor, '_input', return_value='n'):
            worker_editor.edit()

        self.worker.refresh_from_db()

        # Metadata is unchanged because the edit was aborted.
        self.assertEqual(
            self.worker.static_metadata, {'distributions': ['potato']}
        )

        stdout = stdout.getvalue()
        stderr = stderr.getvalue()

        self.assertRegex(
            stdout,
            r'Do you want to retry the same edit\? \(y/n\)\n'
            r'debusine: edits left in (.+?)\.yaml',
        )

        self.assertRegex(
            stderr, 'Invalid YAML: while scanning a quoted scalar.*'
        )

        # Check contents of the file left behind with the rejected edits
        m = re.search(r'debusine: edits left in (.+?\.yaml)$', stdout)
        self.assertIsNotNone(m)
        edits_left_file_path = m.group(1)

        self.assertTrue(os.path.exists(edits_left_file_path))
        # Remove the file even if the content assertion below fails.
        self.addCleanup(os.remove, edits_left_file_path)

        with open(edits_left_file_path) as fd:
            self.assertEqual(fd.read(), '"')

    def test_edit_interactive_worker_empty_yaml(self):
        """Editor sets the contents of Worker's YAML to an empty file."""
        stdout = io.StringIO()
        stderr = io.StringIO()

        worker_editor = WorkerStaticMetadataEditor(
            self.worker, yaml_file=None, stdout=stdout, stderr=stderr
        )

        def _set_empty_yaml(file_path):
            with open(file_path, 'w') as file:
                file.write('')

        def _assert_not_called():
            self.fail("should not have been called")  # pragma: no cover

        # An empty file is accepted, so the user must never be prompted.
        with patch.object(
            worker_editor, '_open_editor', new=_set_empty_yaml
        ), patch.object(worker_editor, '_input', new=_assert_not_called):
            worker_editor.edit()

        self.worker.refresh_from_db()
        self.assertEqual(self.worker.static_metadata, {})

    def test_edit_interactive_worker_invalid_then_valid_yaml(self):
        """
        Editor changes the contents of Worker's static data to invalid YAML.

        Then it fixes the YAML file.
        """
        self.worker.static_metadata = {'distributions': ['potato']}
        self.worker.save()

        stdout = io.StringIO()
        stderr = io.StringIO()

        worker_editor = WorkerStaticMetadataEditor(
            self.worker, yaml_file=None, stdout=stdout, stderr=stderr
        )

        execution_count = 0

        def _verify_and_edit(file_path):
            """
            Write test YAML into file_path.

            The first time it is called it checks the current metadata and
            writes invalid YAML; later calls check that the previous invalid
            edit was kept and write valid YAML.
            """
            nonlocal execution_count
            execution_count += 1

            with open(file_path) as file:
                current_contents = file.read()

            if execution_count == 1:
                self.assertEqual(
                    current_contents, 'distributions:\n- potato\n'
                )
                contents = '&'  # Invalid YAML
            else:
                self.assertEqual(current_contents, '&')
                contents = 'distributions:\n - bookworm'

            with open(file_path, 'w') as file:
                file.write(contents)

        # Answer "y" when asked whether to retry the edit.
        with patch.object(
            worker_editor, '_open_editor', new=_verify_and_edit
        ), patch.object(worker_editor, '_input', return_value='y'):
            worker_editor.edit()

        # One invalid edit followed by one valid edit.
        self.assertEqual(execution_count, 2)

        self.worker.refresh_from_db()

        self.assertEqual(
            self.worker.static_metadata, {'distributions': ['bookworm']}
        )

        stdout = stdout.getvalue()
        stderr = stderr.getvalue()

        self.assertRegex(stderr, 'Invalid YAML: while scanning an anchor.*')

        self.assertEqual(
            'Do you want to retry the same edit? (y/n)\n'
            'debusine: metadata set for worker-01-lan',
            stdout,
        )


class ManageTokenTests(TestCase):
    """Exercise the ``manage_token`` management command."""

    def setUp(self):
        """Create the Token that the tests operate on."""
        self.token = Token.objects.create()

    def _manage(self, action):
        """Run ``manage_token <action> <key>`` and reload the token."""
        _call_command('manage_token', action, self.token.key)
        self.token.refresh_from_db()

    def test_enable_token(self):
        """'manage_token enable <token>' turns a disabled token on."""
        self.assertFalse(self.token.enabled)

        self._manage('enable')

        self.assertTrue(self.token.enabled)

    def test_disable_token(self):
        """'manage_token disable <token>' turns an enabled token off."""
        self.token.enable()
        self.assertTrue(self.token.enabled)

        self._manage('disable')

        self.assertFalse(self.token.enabled)

    def test_enable_token_not_found(self):
        """An unknown token key makes manage_token raise CommandError."""
        with self.assertRaises(CommandError) as raised:
            _call_command('manage_token', 'enable', 'token-key-does-not-exist')

        self.assertEqual(raised.exception.args[0], 'Token not found')


class ManageWorkerTests(TestCase):
    """Tests for manage_worker command."""

    def setUp(self):
        """Set up a default token and worker."""
        token = Token.objects.create()
        self.worker = Worker.objects.create(
            name='worker-a', token=token, registered_at=timezone.now()
        )

    def test_enable_worker(self):
        """'manage_worker enable <worker>' enables the worker's token."""
        self.assertFalse(self.worker.token.enabled)

        _call_command('manage_worker', 'enable', self.worker.name)

        self.worker.token.refresh_from_db()
        self.assertTrue(self.worker.token.enabled)

    def test_disable_worker(self):
        """'manage_worker disable <worker>' disables the worker's token."""
        self.worker.token.enable()
        self.worker.token.refresh_from_db()
        self.assertTrue(self.worker.token.enabled)

        _call_command('manage_worker', 'disable', self.worker.name)

        # Worker is disabled (its token is no longer enabled)
        self.worker.token.refresh_from_db()
        self.assertFalse(self.worker.token.enabled)

    def test_enable_worker_not_found(self):
        """Worker not found raises CommandError."""
        with self.assertRaises(CommandError) as cm:
            _call_command('manage_worker', 'enable', 'worker-does-not-exist')

        self.assertEqual(cm.exception.args[0], 'Worker not found')
