"""
Combined backfill: segmentation + re-classification + embedding upgrade.

Runs each step with per-thread timeout protection and commit-per-thread
to avoid losing progress on failures.

Usage:
    venv/Scripts/python.exe scripts/backfill_all.py [--limit N] [--dry-run] [--skip-seg] [--skip-cls] [--skip-emb]
"""

import argparse
import json
import signal
import sys
import time

# Put the current working directory at the front of the import path so the
# `supporthub` imports below can be resolved from it.
# NOTE(review): this presumably means the script must be launched from the
# repository root (as in the usage line of the module docstring) — confirm.
sys.path.insert(0, ".")

from sqlalchemy import func
from supporthub.app.db import session_scope
from supporthub.app.models import (
    AddonsMessages, AISuggestions, MessageTopicSegment,
    SegmentationState, Tickets, ThreadEmbeddings,
)


class TimeoutError(Exception):
    """Raised by ``run_with_timeout`` when the callable exceeds its time budget.

    Within this module the name shadows the built-in ``TimeoutError``. Unlike
    the built-in, this class derives directly from ``Exception`` and is not an
    ``OSError`` subclass.
    """


def run_with_timeout(fn, timeout_sec=120):
    """Run ``fn()`` in a worker thread, waiting at most ``timeout_sec`` seconds.

    Uses a thread rather than signals, so it also works on Windows.

    Important: the worker is NOT cancelled on timeout. Python threads cannot
    be killed, so a timed-out ``fn`` keeps running in the background and may
    still complete its side effects later. The worker is a daemon thread so
    that such a hung call cannot keep the interpreter alive once the script
    has finished.

    Args:
        fn: Zero-argument callable to execute.
        timeout_sec: Maximum number of seconds to wait for ``fn`` to finish.

    Returns:
        Whatever ``fn()`` returned.

    Raises:
        TimeoutError: ``fn`` was still running after ``timeout_sec`` seconds.
        Exception: Any ``Exception`` raised by ``fn`` is re-raised in the
            calling thread.
    """
    import threading

    # Filled in by the worker: "value" on success, "error" on failure.
    outcome = {}

    def worker():
        try:
            outcome["value"] = fn()
        except Exception as exc:
            outcome["error"] = exc

    # daemon=True: a worker stuck past its timeout must not block process exit.
    t = threading.Thread(target=worker, daemon=True)
    t.start()
    t.join(timeout=timeout_sec)
    if t.is_alive():
        raise TimeoutError(f"Timed out after {timeout_sec}s")
    # Test for presence explicitly instead of relying on exception truthiness.
    if "error" in outcome:
        raise outcome["error"]
    return outcome.get("value")


def backfill_segmentation(limit=0, dry_run=False):
    """Segment threads that haven't been segmented yet.

    A thread is pending when it has at least 4 rows in ``AddonsMessages`` and
    no row in ``SegmentationState``. Pending threads are processed in
    ascending ``thread_id`` order, each in its own ``session_scope()`` and
    under a 90-second timeout, so one failing thread does not affect the
    others. Progress is printed every 25 threads and at the end.

    Args:
        limit: Maximum number of threads to process; values <= 0 mean "all".
        dry_run: When true, only print the pending count and the rough cost
            estimate (a hard-coded 0.0004 per thread), then return.

    Returns:
        None. Results are reported on stdout only.
    """
    from supporthub.app.services.segmentation_service import segment_thread

    with session_scope() as session:
        eligible = (
            session.query(
                AddonsMessages.thread_id,
                func.count(AddonsMessages.id).label("cnt"),
            )
            .group_by(AddonsMessages.thread_id)
            .having(func.count(AddonsMessages.id) >= 4)
            .all()
        )
        segmented = {
            r.thread_id
            for r in session.query(SegmentationState.thread_id).all()
        }
        pending = [(r.thread_id, r.cnt) for r in eligible if r.thread_id not in segmented]
        pending.sort(key=lambda x: x[0])

    if limit > 0:
        pending = pending[:limit]

    total = len(pending)
    est_cost = total * 0.0004
    print(f"\n=== SEGMENTATION: {total} threads, ~${est_cost:.2f} ===")

    if dry_run or total == 0:
        return

    done = errors = skipped = 0
    t0 = time.time()

    for i, (tid, cnt) in enumerate(pending, 1):
        # Bind the id as a default argument. A worker that timed out keeps
        # running in the background; with a plain closure over the loop
        # variable it could read a *later* iteration's id once it unblocks.
        def do_seg(thread_id=tid):
            with session_scope() as s:
                return segment_thread(thread_id, s, force=False)

        try:
            result = run_with_timeout(do_seg, timeout_sec=90)
            # segment_thread returning None is counted as "skipped".
            if result is not None:
                done += 1
            else:
                skipped += 1
        except TimeoutError:
            errors += 1
            print(f"  TIMEOUT thread {tid} ({cnt} msgs)")
        except Exception as exc:
            # Best-effort batch: report the failure and continue with the next thread.
            errors += 1
            print(f"  ERROR thread {tid}: {type(exc).__name__}: {str(exc)[:100]}")

        if i % 25 == 0 or i == total:
            elapsed = time.time() - t0
            rate = i / elapsed if elapsed > 0 else 0
            eta = (total - i) / rate if rate > 0 else 0
            print(f"  [{i}/{total}] done={done} skip={skipped} err={errors} ({rate:.1f}/s, ETA {eta:.0f}s)")

    elapsed = time.time() - t0
    print(f"Segmentation done in {elapsed:.1f}s: {done} ok, {skipped} skip, {errors} err")


def backfill_reclassify(limit=0, dry_run=False):
    """Re-classify AI suggestions missing new fields.

    A suggestion is pending when its ``classification_json`` cannot be parsed
    as JSON (including a NULL value) or when the parsed value does not contain
    ``"is_template_candidate"``. Each pending suggestion is re-run through
    ``AIService._run_classification`` in its own ``session_scope()`` under a
    30-second timeout, and its ``summary`` (truncated to 512 characters) and
    ``classification_json`` are overwritten with the new result. Progress is
    printed every 25 suggestions and at the end.

    Args:
        limit: Maximum number of suggestions to process; values <= 0 mean "all".
        dry_run: When true, only print the pending count and the rough cost
            estimate (a hard-coded 0.0002 per suggestion), then return.

    Returns:
        None. Results are reported on stdout only.
    """
    from supporthub.app.services.ai_service import AIService

    with session_scope() as session:
        all_sugg = session.query(AISuggestions.id, AISuggestions.classification_json).all()
        pending_ids = []
        for sid, cjson in all_sugg:
            try:
                data = json.loads(cjson)
                if "is_template_candidate" not in data:
                    pending_ids.append(sid)
            except (json.JSONDecodeError, TypeError):
                # Unparseable or NULL JSON: treat it as needing re-classification.
                pending_ids.append(sid)

    if limit > 0:
        pending_ids = pending_ids[:limit]

    total = len(pending_ids)
    est_cost = total * 0.0002
    print(f"\n=== RE-CLASSIFY: {total} suggestions, ~${est_cost:.4f} ===")

    if dry_run or total == 0:
        return

    ai = AIService()
    done = errors = 0
    t0 = time.time()

    for i, sid in enumerate(pending_ids, 1):
        # Bind the id as a default argument. A worker that timed out keeps
        # running in the background; with a plain closure over the loop
        # variable it could read a *later* iteration's id once it unblocks.
        def do_cls(suggestion_id=sid):
            with session_scope() as s:
                suggestion = s.query(AISuggestions).filter_by(id=suggestion_id).one()
                ticket = s.query(Tickets).filter_by(id=suggestion.ticket_id).one()
                message = s.query(AddonsMessages).filter_by(id=suggestion.message_id).one()
                # Only the first and third elements of the returned 4-tuple are used.
                summary, _, cls_json_str, _ = ai._run_classification(ticket, message, s)
                suggestion.summary = summary[:512]
                suggestion.classification_json = cls_json_str
            return True

        try:
            run_with_timeout(do_cls, timeout_sec=30)
            done += 1
        except TimeoutError:
            errors += 1
            print(f"  TIMEOUT suggestion {sid}")
        except Exception as exc:
            # Best-effort batch: report the failure and continue with the next suggestion.
            errors += 1
            print(f"  ERROR suggestion {sid}: {type(exc).__name__}: {str(exc)[:100]}")

        if i % 25 == 0 or i == total:
            elapsed = time.time() - t0
            rate = i / elapsed if elapsed > 0 else 0
            eta = (total - i) / rate if rate > 0 else 0
            print(f"  [{i}/{total}] done={done} err={errors} ({rate:.1f}/s, ETA {eta:.0f}s)")

    elapsed = time.time() - t0
    print(f"Re-classify done in {elapsed:.1f}s: {done} ok, {errors} err")


def backfill_embeddings(limit=0, dry_run=False):
    """Upgrade weak/no_reply embeddings and embed missing threads.

    Two sets of thread ids are collected:

    * "weak": threads whose ``thread_embeddings.embedding_quality`` is
      ``'no_reply'`` or ``'weak'`` and that have at least one message with
      ``direction = 'outbound'``;
    * "missing": rows of ``addons_threads`` with no ``thread_embeddings`` row.

    The weak ids are processed first, then the missing ones, each through
    ``EmbeddingService.embed_thread`` in its own ``session_scope()`` under a
    60-second timeout. Progress is printed every 50 threads and at the end.

    Args:
        limit: Maximum number of threads to process; values <= 0 mean "all".
            The limit is applied to the combined list, so weak upgrades are
            taken before missing threads. The weak/missing counts in the
            header line are the totals before the limit is applied.
        dry_run: When true, only print the counts and the rough cost estimate
            (a hard-coded 0.0001 per thread), then return.

    Returns:
        None. Results are reported on stdout only.
    """
    from supporthub.app.services.embedding_service import EmbeddingService

    svc = EmbeddingService()

    # Find weak embeddings that now have outbound replies
    with session_scope() as session:
        from sqlalchemy import text as sa_text
        weak_ids = [
            row[0] for row in session.execute(sa_text(
                "SELECT DISTINCT e.thread_id FROM thread_embeddings e "
                "JOIN addons_messages m ON m.thread_id = e.thread_id "
                "WHERE e.embedding_quality IN ('no_reply', 'weak') "
                "  AND m.direction = 'outbound'"
            ))
        ]
        missing_ids = [
            row[0] for row in session.execute(sa_text(
                "SELECT t.id FROM addons_threads t "
                "LEFT JOIN thread_embeddings e ON t.id = e.thread_id "
                "WHERE e.id IS NULL"
            ))
        ]

    total_weak = len(weak_ids)
    total_missing = len(missing_ids)
    all_ids = weak_ids + missing_ids

    if limit > 0:
        all_ids = all_ids[:limit]

    est_cost = len(all_ids) * 0.0001
    print(f"\n=== EMBEDDINGS: {total_weak} weak upgrades + {total_missing} missing = {len(all_ids)} total, ~${est_cost:.4f} ===")

    if dry_run or not all_ids:
        return

    done = errors = 0
    t0 = time.time()

    for i, tid in enumerate(all_ids, 1):
        # Bind the id as a default argument. A worker that timed out keeps
        # running in the background; with a plain closure over the loop
        # variable it could read a *later* iteration's id once it unblocks.
        def do_emb(thread_id=tid):
            with session_scope() as s:
                return svc.embed_thread(thread_id, session=s)

        try:
            result = run_with_timeout(do_emb, timeout_sec=60)
            # A falsy result is counted as neither done nor error, so
            # done + err can be lower than the number of threads processed.
            if result:
                done += 1
        except TimeoutError:
            errors += 1
            print(f"  TIMEOUT thread {tid}")
        except Exception as exc:
            # Best-effort batch: report the failure and continue with the next thread.
            errors += 1
            print(f"  ERROR thread {tid}: {type(exc).__name__}: {str(exc)[:100]}")

        if i % 50 == 0 or i == len(all_ids):
            elapsed = time.time() - t0
            rate = i / elapsed if elapsed > 0 else 0
            eta = (len(all_ids) - i) / rate if rate > 0 else 0
            print(f"  [{i}/{len(all_ids)}] done={done} err={errors} ({rate:.1f}/s, ETA {eta:.0f}s)")

    elapsed = time.time() - t0
    print(f"Embeddings done in {elapsed:.1f}s: {done} ok, {errors} err")


def main():
    """Parse the command line and run the enabled backfill steps in order.

    The steps run as segmentation, then re-classification, then embeddings.
    Each can be disabled with its ``--skip-*`` flag; ``--limit`` and
    ``--dry-run`` are passed unchanged to every step that runs. The total
    wall-clock time is printed at the end.
    """
    cli = argparse.ArgumentParser(description="Backfill all: segmentation + re-classify + embeddings")
    cli.add_argument("--limit", type=int, default=0, help="Max items per step (0=all)")
    # All remaining options are plain on/off switches.
    switches = (
        ("--dry-run", "Show counts without processing"),
        ("--skip-seg", "Skip segmentation"),
        ("--skip-cls", "Skip re-classification"),
        ("--skip-emb", "Skip embedding upgrades"),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)
    opts = cli.parse_args()

    started = time.time()

    # (skip requested?, step function) pairs, in execution order.
    pipeline = (
        (opts.skip_seg, backfill_segmentation),
        (opts.skip_cls, backfill_reclassify),
        (opts.skip_emb, backfill_embeddings),
    )
    for skip_requested, step in pipeline:
        if skip_requested:
            continue
        step(limit=opts.limit, dry_run=opts.dry_run)

    elapsed = time.time() - started
    print(f"\n=== ALL DONE in {elapsed:.1f}s ===")


# Run the backfill only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
