#  Copyright 2022 The Debusine developers
#  See the AUTHORS file at the top-level directory of this distribution
#
#  This file is part of Debusine. It is subject to the license terms
#  in the LICENSE file found in the top-level directory of this
#  distribution. No part of Debusine, including this file, may be copied,
#  modified, propagated, or distributed except according to the terms
#  contained in the LICENSE file.
"""Classes and functions supporting database tests."""

import threading

from django.db import connections, transaction


class RunInParallelTransaction:
    """
    Execute a callable in a separate transaction that gets started in a thread.

    :py:meth:`start_transaction` starts a new thread in which it starts
    the transaction and executes the callable. The transaction then stays
    open and the thread waits until :py:meth:`stop_transaction` is called.

    If the worker thread raises (for example because the callable fails),
    the exception is re-raised in the calling thread: by
    :py:meth:`start_transaction` when it happens before the transaction is
    reported ready, otherwise by :py:meth:`stop_transaction`. The caller is
    therefore never left blocked waiting for a transaction that will not
    become ready.
    """

    def __init__(self, to_run):
        """
        Create a RunInParallelTransaction object.

        :param to_run: the function/callable object that gets executed in
          a transaction started in a thread.
        """
        self._to_run = to_run
        # Set by the worker thread once to_run() has returned (or failed).
        self._transaction_ready = threading.Event()
        # Set by stop_transaction() to let the worker thread finish.
        self._transaction_close = threading.Event()
        self._thread = None
        # Exception raised in the worker thread, forwarded to the caller.
        self._exception = None

    def _run_in_thread(self):
        """Worker thread body: run the callable inside an open transaction."""
        try:
            with transaction.atomic():
                try:
                    self._to_run()
                    self._transaction_ready.set()
                    self._transaction_close.wait()
                finally:
                    # Close the DB connection that was opened in this thread
                    # (otherwise the connection is left opened and the test
                    # runner might not be able to destroy the debusine-test
                    # DB). This runs in a ``finally`` so the connection is
                    # also closed when to_run() fails, and it stays inside
                    # the atomic block exactly as before.
                    connections.close_all()
        except BaseException as exc:  # forwarded to the caller, not swallowed
            self._exception = exc
        finally:
            # Never leave start_transaction() waiting, even on failure.
            self._transaction_ready.set()

    def _join_and_reraise(self):
        """Wait for the worker thread, then re-raise any exception it hit."""
        self._thread.join()
        self._thread = None
        exception, self._exception = self._exception, None
        if exception is not None:
            raise exception

    def start_transaction(self):
        """
        Start a transaction and execute the registered callable.

        It does this in a new thread that then waits until
        :py:meth:`stop_transaction` gets called.

        :raises: whatever the worker thread raised before the transaction
          was ready (e.g. an exception from the callable); in that case the
          thread has already finished.
        """
        # Reset state so the object can be started again after a stop.
        self._transaction_ready.clear()
        self._transaction_close.clear()
        self._exception = None

        self._thread = threading.Thread(target=self._run_in_thread)
        self._thread.start()
        self._transaction_ready.wait()

        # The worker records the exception before setting the ready event,
        # so it is visible here once wait() returns.
        if self._exception is not None:
            self._join_and_reraise()

    def stop_transaction(self):
        """
        Close the open transaction and let the thread exit.

        Does nothing if no thread is currently running (never started,
        already stopped, or start_transaction() failed).

        :raises: whatever the worker thread raised while finishing.
        """
        if self._thread is None:
            return
        self._transaction_close.set()
        self._join_and_reraise()
