# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Celery integration for debusine tasks."""

import logging
from contextlib import ExitStack

import pgtransaction
from celery import shared_task
from django.db import transaction

from debusine.artifacts.models import TaskTypes
from debusine.db.context import context
from debusine.db.models import WorkRequest
from debusine.db.models.work_requests import (
    compute_workflow_last_activity,
    compute_workflow_runtime_status,
    workflow_ancestors,
)
from debusine.server.tasks import BaseServerTask
from debusine.tasks import TaskConfigError
from debusine.tasks.models import OutputData, OutputDataError

logger = logging.getLogger(__name__)


class WorkRequestNotPending(Exception):
    """Raised when asked to run a work request whose status is not pending."""


class ServerTaskRunError(Exception):
    """
    Running a server task failed.

    :ivar message: human-readable description of the failure.
    :ivar code: short machine-readable error code, such as
      ``"wrong-task-type"`` or ``"execute-failed"``.
    """

    def __init__(self, message: str, code: str) -> None:
        """Construct the exception."""
        # Pass both values up so that ``args`` mirrors the constructor
        # signature; this keeps the exception picklable.
        super().__init__(message, code)
        self.message = message
        self.code = code

    def __str__(self) -> str:
        """Return the human-readable message rather than the args tuple."""
        return self.message


def _run_server_task_or_error(work_request: WorkRequest) -> bool:
    """
    Run a server task, raising :py:class:`ServerTaskRunError` on errors.

    :param work_request: the work request to run; it must be a server task.
    :raises ServerTaskRunError: with code ``wrong-task-type`` if the work
      request is not a server task, ``setup-failed`` or ``configure-failed``
      if its task could not be obtained, or ``execute-failed`` if executing
      the task raised an exception.
    :return: the task's result (the work request is marked as succeeded if
      it is true and failed otherwise), or False if the task was aborted,
      in which case this function records nothing in the database.
    """
    if work_request.task_type != TaskTypes.SERVER:
        raise ServerTaskRunError(
            "Cannot run on a Celery worker", "wrong-task-type"
        )

    task_name = work_request.task_name
    try:
        task = work_request.get_task()
    except ValueError as error:
        raise ServerTaskRunError(str(error), "setup-failed") from error
    except TaskConfigError as error:
        raise ServerTaskRunError(
            f"Failed to configure: {error}", "configure-failed"
        ) from error
    assert isinstance(task, BaseServerTask)

    try:
        with ExitStack() as stack:
            # Tasks that do not manage their own transactions are run
            # inside a single atomic block.
            if not task.TASK_MANAGES_TRANSACTIONS:
                stack.enter_context(transaction.atomic())
            succeeded = task.execute_logging_exceptions()
    except Exception as error:
        raise ServerTaskRunError(
            f"Execution failed: {error}", "execute-failed"
        ) from error

    # Everything below is deliberately outside the try above, so failures
    # while recording the outcome are not reported as "execute-failed".
    if task.aborted:
        logger.info("Task: %s has been aborted", task_name)
        # No need to update DB state.
        return False

    # Lock the whole workflow, then record the outcome.
    with transaction.atomic():
        WorkRequest.objects.filter(
            id=work_request.id
        ).lock_workflows_for_update()
        outcome = (
            WorkRequest.Results.SUCCESS
            if succeeded
            else WorkRequest.Results.FAILURE
        )
        work_request.mark_completed(outcome)
    return succeeded


# mypy complains that celery.shared_task is untyped, which is true, but we
# can't fix that here.
@shared_task  # type: ignore[misc]
# Unlike some other Celery tasks, this must _not_ be decorated with
# @transaction.atomic, because it needs to do finer-grained transaction
# control.
def run_server_task(work_request_id: int) -> bool:
    """
    Run the :py:class:`BaseServerTask` of a work request via Celery.

    :param work_request_id: ID of the work request to run.
    :raises WorkRequest.DoesNotExist: if no work request has that ID.
    :raises WorkRequestNotPending: if the work request's status is not
      "pending".
    :return: True if the server task succeeded; False if it failed, was
      aborted, or raised an error while being set up or executed (in the
      last case the work request is marked as completed with an error
      result).
    """
    try:
        work_request = WorkRequest.objects.get(pk=work_request_id)
    except WorkRequest.DoesNotExist:
        logger.error("Work request %d does not exist", work_request_id)
        raise

    current_status = work_request.status
    if current_status != WorkRequest.Statuses.PENDING:
        logger.error(
            "Work request %d is in status %s, not pending",
            work_request_id,
            current_status,
        )
        raise WorkRequestNotPending

    # Normally we would lock the whole workflow as early as possible and
    # hold the lock until the end of the transaction, to minimize deadlock
    # risk when updating several work requests in a row.  Server tasks can
    # run for a long time, though, and some manage their own transactions.
    #
    # Marking the work request as running does not cascade to other work
    # requests, and we are in Django's default autocommit mode here, so at
    # worst this fails to lock this one work request; it cannot deadlock.
    work_request.mark_running()

    context.reset()
    work_request.set_current()

    try:
        return _run_server_task_or_error(work_request)
    except ServerTaskRunError as failure:
        with transaction.atomic():
            logger.error(
                "Error running work request %s/%s (%s): %s",
                work_request.task_type,
                work_request.task_name,
                work_request.id,
                failure.message,
            )
            # Lock the whole workflow before recording the error.
            WorkRequest.objects.filter(
                id=work_request.id
            ).lock_workflows_for_update()
            error_output = OutputData(
                errors=[
                    OutputDataError(
                        message=failure.message, code=failure.code
                    )
                ]
            )
            work_request.mark_completed(
                WorkRequest.Results.ERROR, output_data=error_output
            )
            return False


# mypy complains that celery.shared_task is untyped, which is true, but we
# can't fix that here.
@shared_task  # type: ignore[misc]
# We may encounter serialization failures, but if so it's probably OK to
# just let the next run of this task catch up.
@pgtransaction.atomic(isolation_level=pgtransaction.REPEATABLE_READ)
def update_workflows() -> None:
    """
    Refresh the expensive-to-compute properties of flagged workflows.

    For the workflows returned by ``workflow_ancestors`` for all work
    requests flagged with ``workflows_need_update``, recompute and save
    ``workflow_last_activity_at`` and ``workflow_runtime_status``, which
    the web UI relies on; then clear the flag on the flagged work requests.
    """
    flagged = WorkRequest.objects.filter(
        workflows_need_update=True
    ).lock_workflows_for_update()
    for ancestor in workflow_ancestors(flagged):
        assert ancestor.task_type == TaskTypes.WORKFLOW
        last_activity = compute_workflow_last_activity(ancestor)
        ancestor.workflow_last_activity_at = last_activity
        runtime_status = compute_workflow_runtime_status(ancestor)
        ancestor.workflow_runtime_status = runtime_status
        ancestor.save()
    flagged.update(workflows_need_update=False)
