"""
Celery tasks for per-tenant AI operations.

Tasks receive the tenant schema name explicitly, since Flask's request-scoped
``g`` object is not available in the worker context.
"""

from __future__ import annotations

import logging

from ..celery_app import celery_app
from ..db import session_scope

logger = logging.getLogger(__name__)


@celery_app.task(bind=True, max_retries=2, default_retry_delay=30)
def generate_suggestion(self, ticket_id: int, schema: str, org_id: int | None = None) -> dict:
    """
    Generate an AI suggestion for a ticket (classify + draft).

    Replaces the threading.Thread approach used in the single-tenant app.

    Args:
        ticket_id: Primary key of the ticket in the tenant's ``Tickets`` table.
        schema: Tenant schema name handed to ``session_scope``; it must be
            passed explicitly because there is no Flask ``g`` in the worker.
        org_id: Accepted for callers but currently unused -- the AI service
            is always constructed with ``org=None`` (platform key).

    Returns:
        ``{"status": "ok", "suggestion_id": ...}`` on success,
        ``{"status": "error", "detail": "ticket not found"}`` if the ticket
        does not exist, or ``{"status": "skip", "detail": "no inbound message"}``
        if the ticket's thread has no inbound message.

    Raises:
        celery.exceptions.Retry: on any unexpected failure while retries
            remain. Once ``max_retries`` is exhausted, ``self.retry`` re-raises
            the original exception instead.
    """
    logger.info("[generate_suggestion] ticket_id=%s schema=%s", ticket_id, schema)
    try:
        with session_scope(schema=schema) as session:
            # Imported lazily (presumably to avoid import cycles at worker start-up).
            from ..models import Tickets, AddonsMessages
            from ..services.ai_service import AIService

            ticket = session.query(Tickets).filter_by(id=ticket_id).one_or_none()
            if not ticket:
                return {"status": "error", "detail": "ticket not found"}

            # Get most recent inbound message on the ticket's thread.
            message = (
                session.query(AddonsMessages)
                .filter_by(thread_id=ticket.thread_id, direction="inbound")
                .order_by(AddonsMessages.created_at.desc())
                .first()
            )
            if not message:
                return {"status": "skip", "detail": "no inbound message"}

            # TODO(review): org_id is not forwarded; load the org and pass it
            # here to enable BYOK. org=None uses the platform key.
            ai = AIService(org=None)
            suggestion = ai.generate_suggestion(ticket, message, session=session)
            # Read the id while the session is still open.
            return {"status": "ok", "suggestion_id": suggestion.id}

    except Exception as exc:
        # logger.exception records the traceback (logger.error with "%s" lost it),
        # which is the only diagnostic left once retries are exhausted.
        logger.exception("[generate_suggestion] Failed for ticket %s: %s", ticket_id, exc)
        raise self.retry(exc=exc)


@celery_app.task(bind=True, max_retries=2, default_retry_delay=60)
def generate_embeddings(self, thread_id: int, schema: str) -> dict:
    """
    Generate or refresh embeddings for a thread.

    Replaces the threading.Thread re-embed calls in main.py.

    Args:
        thread_id: Identifier of the thread to (re-)embed.
        schema: Tenant schema name handed to ``session_scope``; it must be
            passed explicitly because there is no Flask ``g`` in the worker.

    Returns:
        ``{"status": "ok", "result": ...}``, where ``result`` is whatever
        ``EmbeddingService.embed_thread`` returns.

    Raises:
        celery.exceptions.Retry: on any failure while retries remain. Once
            ``max_retries`` is exhausted, ``self.retry`` re-raises the
            original exception instead.
    """
    logger.info("[generate_embeddings] thread_id=%s schema=%s", thread_id, schema)
    try:
        with session_scope(schema=schema) as session:
            # Imported lazily (presumably to avoid import cycles at worker start-up).
            from ..services.embedding_service import EmbeddingService
            service = EmbeddingService()
            result = service.embed_thread(thread_id, session=session)
            return {"status": "ok", "result": result}
    except Exception as exc:
        # logger.exception records the traceback (logger.error with "%s" lost it).
        logger.exception("[generate_embeddings] Failed for thread %s: %s", thread_id, exc)
        raise self.retry(exc=exc)
