"""
Backfill segmentation for all threads with 4+ messages that haven't been
segmented yet.

Usage:
    venv/Scripts/python.exe scripts/backfill_segmentation.py [--limit N] [--force] [--dry-run]

Options:
    --limit N   Process at most N threads (default: all)
    --force     Re-segment even if already done (recompute all)
    --dry-run   Show what would be processed without doing it

Cost: ~$0.0004/thread (embedding + GPT-4o-mini labeling)
"""

import argparse
import sys
import time

sys.path.insert(0, ".")

from sqlalchemy import func

from supporthub.app.db import session_scope
from supporthub.app.models import AddonsMessages, SegmentationState
from supporthub.app.services.segmentation_service import segment_thread


def get_pending_threads(session, force: bool, min_messages: int = 4):
    """Return ``(thread_id, message_count)`` pairs for threads needing segmentation.

    A thread is eligible when it has at least ``min_messages`` rows in
    ``AddonsMessages``. Messages whose ``thread_id`` is NULL are excluded:
    SQL GROUP BY would otherwise lump them together into a single fake
    "thread" with a ``None`` id.

    Args:
        session: Open SQLAlchemy session used to run the queries.
        force: When True, return every eligible thread, including those
            that already have a ``SegmentationState`` row.
        min_messages: Minimum number of messages for a thread to be
            eligible. Defaults to 4, the threshold documented for this script.

    Returns:
        List of ``(thread_id, message_count)`` tuples, ordered by thread_id.
    """
    msg_count = func.count(AddonsMessages.id)
    eligible = (
        session.query(
            AddonsMessages.thread_id,
            msg_count.label("cnt"),
        )
        .filter(AddonsMessages.thread_id.isnot(None))
        .group_by(AddonsMessages.thread_id)
        .having(msg_count >= min_messages)
        .order_by(AddonsMessages.thread_id)
        .all()
    )

    if force:
        return [(r.thread_id, r.cnt) for r in eligible]

    # Excluded in Python rather than with SQL ``NOT IN (subquery)``. A NULL in
    # the subquery would make NOT IN match nothing at all.
    segmented = {
        r.thread_id
        for r in session.query(SegmentationState.thread_id).all()
    }
    return [(r.thread_id, r.cnt) for r in eligible if r.thread_id not in segmented]


def main():
    """Parse CLI options, report the pending workload, then segment it.

    Each thread is committed in its own transaction. Completed work
    therefore survives a crash or Ctrl-C partway through a long (and paid)
    run. A failure in one thread is rolled back, so it never leaks partial
    rows into the final commit or leaves the session unusable for the
    threads that follow.

    Returns:
        Process exit status: 0 on success (including ``--dry-run`` and the
        "nothing to do" case), 1 if any thread raised an error.
    """
    parser = argparse.ArgumentParser(description="Backfill thread segmentation")
    parser.add_argument("--limit", type=int, default=0, help="Max threads to process (0=all)")
    parser.add_argument("--force", action="store_true", help="Re-segment all, even if already done")
    parser.add_argument("--dry-run", action="store_true", help="Show pending threads without processing")
    args = parser.parse_args()

    with session_scope() as session:
        pending = get_pending_threads(session, args.force)
        pending.sort(key=lambda x: x[0])  # deterministic order by thread_id

        if args.limit > 0:
            pending = pending[: args.limit]

        total = len(pending)
        total_messages = sum(cnt for _, cnt in pending)
        est_cost = total * 0.0004  # per-thread estimate from the module docstring

        print(f"Threads to process: {total}")
        print(f"Total messages: {total_messages}")
        print(f"Estimated cost: ${est_cost:.2f}")

        if args.dry_run:
            for tid, cnt in pending[:20]:
                print(f"  thread {tid}: {cnt} messages")
            if total > 20:
                print(f"  ... and {total - 20} more")
            return 0

        if total == 0:
            print("Nothing to do.")
            return 0

        done = 0
        skipped = 0
        errors = 0
        t0 = time.time()

        for i, (tid, _cnt) in enumerate(pending, 1):
            try:
                result = segment_thread(tid, session, force=args.force)
                # Commit per thread. session.flush() would only send SQL
                # without ending the transaction. Per-thread commits keep
                # transactions small and make finished threads durable.
                # Presumably segment_thread records a SegmentationState row,
                # which a rerun without --force uses to skip finished threads
                # (TODO confirm).
                session.commit()
            except Exception as exc:
                # Discard only this thread's uncommitted changes. This also
                # resets the session after a failed flush/commit, so later
                # threads can still run.
                session.rollback()
                errors += 1
                print(f"  ERROR thread {tid}: {exc}")
            else:
                if result is None:
                    skipped += 1
                else:
                    done += 1

            # Progress every 25 threads (and after the last one)
            if i % 25 == 0 or i == total:
                elapsed = time.time() - t0
                rate = i / elapsed if elapsed > 0 else 0
                eta = (total - i) / rate if rate > 0 else 0
                print(
                    f"  [{i}/{total}] done={done} skipped={skipped} errors={errors} "
                    f"({rate:.1f} threads/s, ETA {eta:.0f}s)"
                )

        elapsed = time.time() - t0
        print(f"\nDone in {elapsed:.1f}s: {done} segmented, {skipped} skipped, {errors} errors")

    return 1 if errors else 0


if __name__ == "__main__":
    # Propagate main()'s status (None or 0 -> success), so that schedulers or
    # CI can detect a backfill run in which some threads failed.
    sys.exit(main())
