"""One-time migration: copy all data from SQLite (local dev) to PostgreSQL tenant schema.

Reads every table listed in TABLES_IN_ORDER from the SQLite source and bulk-inserts
into the target PostgreSQL tenant schema using ON CONFLICT DO NOTHING batches.

Usage:
    venv/bin/python scripts/migrate_sqlite_to_postgres.py [--sqlite PATH] [--schema SCHEMA] [--dry-run]

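Defaults:
    --sqlite  supporthub_local.db   (resolved relative to the current working directory)
    --schema  tenant_internal
    --dry-run off

The PostgreSQL connection URL is read from the app config (supporthub/app/config.py).
In live mode, the target schema's column sequences are then advanced past the
migrated IDs, and row counts for a few key tables are printed for verification.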
"""

from __future__ import annotations

import argparse
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from sqlalchemy import create_engine, text, MetaData, Table

# Tables to migrate, in FK-safe order.
# Order matters: each table is inserted (and committed batch by batch) before the
# next one starts, so a table must come after every table it references.
# Tables listed here but absent from the SQLite file are skipped at runtime.
# 'users_admin' and 'alembic_version' are intentionally excluded.
TABLES_IN_ORDER: list[str] = [
    "addons_products",
    "addons_customers",
    "addons_orders",
    "addons_threads",
    "addons_messages",
    "addons_attachments",
    "tickets",
    "tags",
    "ticket_tags",
    "predefined_messages",
    "product_predefined_messages",
    "app_settings",
    "doc_sources",
    "doc_pages",
    "thread_embeddings",
    "product_knowledge",
    "ai_suggestions",
    "ai_suggestion_feedback",
    "message_translations",
    "message_topic_segments",
    "segmentation_state",
    "generated_guides",
    "website_checks",
    "customer_sites",
    "site_credentials",
    "sync_history",
    "audit_log",
    "url_safety_checks",
]

# For these tables, force specific columns to NULL regardless of source value.
# Maps table name -> set of column names to blank out before insert.
# audit_log.admin_user_id → FK target (users_admin) no longer exists in PG.
FORCE_NULL: dict[str, set[str]] = {
    "audit_log": {"admin_user_id"},
}

# Number of rows sent per INSERT (executemany) call; each batch is committed
# on its own, so a failure leaves earlier batches in place.
BATCH_SIZE: int = 500

# Key tables whose PostgreSQL row counts are printed after a live migration.
VERIFY_TABLES: list[str] = ["tickets", "addons_messages", "addons_threads"]


def get_pg_columns(dst_conn, schema: str, table_name: str) -> dict[str, int | None]:
    """Describe the columns of ``schema.table_name`` in the PostgreSQL target.

    Returns an insertion-ordered dict (by ``ordinal_position``) mapping each
    column name to ``information_schema.columns.character_maximum_length``:
    the declared length limit for length-limited string columns such as
    VARCHAR(n), or None when the column has no such limit. An empty dict means
    no matching columns were found (e.g. the table is not in that schema).
    """
    column_query = text(
        """
            SELECT column_name, character_maximum_length
            FROM information_schema.columns
            WHERE table_schema = :schema
              AND table_name = :table
            ORDER BY ordinal_position
            """
    )
    bind_params = {"schema": schema, "table": table_name}

    limits_by_column: dict[str, int | None] = {}
    for column_name, max_length in dst_conn.execute(column_query, bind_params):
        limits_by_column[column_name] = max_length
    return limits_by_column


def migrate_table(
    src_conn,
    dst_conn,
    table_name: str,
    schema: str,
    dry_run: bool,
) -> int:
    """Copy all rows of one SQLite table into the PostgreSQL table ``schema.table_name``.

    Only columns present in both the SQLite source and the PostgreSQL target are
    copied. For each row, three transformations are applied:

    - Columns listed in ``FORCE_NULL`` for this table are set to NULL.
    - 0/1 values in SQLite BOOLEAN-declared columns become Python bools.
    - Strings longer than the target column's length limit are truncated. The
      number of truncated values is reported.

    Rows are written with ``INSERT ... ON CONFLICT DO NOTHING`` in batches of
    ``BATCH_SIZE``. Each batch is committed on its own, so a re-run skips rows
    that conflict with ones already copied.

    Args:
        src_conn: SQLAlchemy connection to the SQLite source.
        dst_conn: SQLAlchemy connection to the PostgreSQL target.
        table_name: Table name, identical in source and target.
        schema: Target PostgreSQL schema name (matched exactly, case-sensitive).
        dry_run: If true, only count source rows; nothing is written.

    Returns:
        The number of rows read from the SQLite table. Returns 0 when the table
        is missing from SQLite or empty. In live mode it also returns 0 when the
        table is missing from the target schema or shares no columns with it.
    """
    # Local imports, matching how this function already pulled in Boolean.
    from sqlalchemy import Boolean
    from sqlalchemy.exc import NoSuchTableError

    # Reflect the source table. Only a genuinely missing table is skipped;
    # any other reflection failure should abort rather than silently drop data.
    try:
        src_table = Table(table_name, MetaData(), autoload_with=src_conn.engine)
    except NoSuchTableError as exc:
        print(f"  [{table_name}] SKIP (not in SQLite): {exc}")
        return 0

    src_columns_all = [c.name for c in src_table.columns]

    # SQLAlchemy reflects SQLite BOOLEAN-declared columns as Boolean, but the raw
    # SELECT below returns their values as 0/1 integers, so coerce them later.
    bool_cols: set[str] = {
        c.name for c in src_table.columns if isinstance(c.type, Boolean)
    }

    rows = src_conn.execute(text(f'SELECT * FROM "{table_name}"')).fetchall()
    total_rows = len(rows)

    if total_rows == 0:
        print(f"  [{table_name}] 0 rows — nothing to do")
        return 0

    if dry_run:
        print(f"  [{table_name}] {total_rows} rows — skipped (dry-run)")
        return total_rows

    # PG column names (to drop SQLite-only columns) and varchar-style length limits.
    pg_col_info = get_pg_columns(dst_conn, schema, table_name)
    if not pg_col_info:
        # Without this guard the INSERT below would have an empty column list,
        # a syntax error that aborts the whole migration.
        print(
            f"  [{table_name}] SKIP (not in PostgreSQL schema {schema}): "
            f"{total_rows} source rows not copied"
        )
        return 0
    pg_col_names: set[str] = set(pg_col_info)
    varchar_limits: dict[str, int] = {
        k: v for k, v in pg_col_info.items() if v is not None
    }

    # Only insert columns that exist in BOTH source and target (schema divergence).
    columns = [c for c in src_columns_all if c in pg_col_names]
    dropped = set(src_columns_all) - pg_col_names
    if dropped:
        print(f"    (skipping source-only columns: {sorted(dropped)})")
    if not columns:
        print(
            f"  [{table_name}] SKIP (no columns shared with PostgreSQL): "
            f"{total_rows} source rows not copied"
        )
        return 0

    # Quote identifiers the same way get_pg_columns matches them: exactly.
    quote = dst_conn.dialect.identifier_preparer.quote_identifier
    col_list = ", ".join(quote(c) for c in columns)
    placeholders = ", ".join(f":{c}" for c in columns)
    insert_sql = text(
        f"INSERT INTO {quote(schema)}.{quote(table_name)} ({col_list}) "
        f"VALUES ({placeholders}) ON CONFLICT DO NOTHING"
    )

    force_null_cols = FORCE_NULL.get(table_name, set())

    # "submitted" counts rows sent to PG; rows skipped by ON CONFLICT DO NOTHING
    # are still included, since executemany row counts are not reliable.
    submitted = 0
    truncated = 0
    for start in range(0, total_rows, BATCH_SIZE):
        batch = []
        for row in rows[start : start + BATCH_SIZE]:
            # SELECT * returns columns in the same order reflection reports them.
            full_record = dict(zip(src_columns_all, row))
            record = {k: full_record[k] for k in columns}
            # Force specific columns to NULL (e.g. dead FK references).
            for col in force_null_cols:
                if col in record:
                    record[col] = None
            # SQLite integer booleans (0/1) → Python bool for PG BOOLEAN columns.
            for col in bool_cols:
                if record.get(col) is not None:
                    record[col] = bool(record[col])
            # SQLite enforces no length; truncate to PG limits, but count it.
            for col, max_len in varchar_limits.items():
                value = record.get(col)
                if isinstance(value, str) and len(value) > max_len:
                    record[col] = value[:max_len]
                    truncated += 1
            # bytes and None are passed through unchanged to the driver.
            batch.append(record)

        dst_conn.execute(insert_sql, batch)
        dst_conn.commit()
        submitted += len(batch)

    if truncated:
        print(f"    (truncated {truncated} over-length string value(s) to PG column limits)")
    print(f"  [{table_name}] {submitted}/{total_rows} rows — done")
    return total_rows


def reset_sequences(dst_conn, schema: str) -> None:
    """Advance every column-owned sequence in *schema* past its column's current MAX.

    Bulk-inserting rows with explicit IDs leaves sequences at their old value,
    so the next default-generated ID would collide. This covers every serial
    column (``nextval(...)`` default) and every IDENTITY column in the schema.
    For each one, the sequence is set so that the next ``nextval`` returns
    ``MAX(column) + 1``, or 1 for an empty table.

    The loop runs in Python rather than inside a ``DO $$ ... $$`` block. A DO
    body cannot take bind parameters, so that approach only worked with drivers
    that splice parameters into the SQL text client-side.
    """
    quote = dst_conn.dialect.identifier_preparer.quote_identifier

    # %I quotes table names properly for pg_get_serial_sequence's text argument.
    seq_rows = dst_conn.execute(
        text(
            """
            SELECT c.table_name,
                   c.column_name,
                   pg_get_serial_sequence(
                       format('%I.%I', c.table_schema, c.table_name), c.column_name
                   ) AS seq
            FROM information_schema.columns c
            WHERE c.table_schema = :schema
              AND (c.column_default LIKE 'nextval%' OR c.is_identity = 'YES')
            """
        ),
        {"schema": schema},
    ).fetchall()

    for table_name, column_name, seq in seq_rows:
        if seq is None:
            continue
        # is_called=false: the next nextval() returns exactly the value set here.
        dst_conn.execute(
            text(
                "SELECT setval(CAST(:seq AS regclass), "
                f"COALESCE((SELECT MAX({quote(column_name)}) "
                f"FROM {quote(schema)}.{quote(table_name)}), 0) + 1, false)"
            ),
            {"seq": seq},
        )
    dst_conn.commit()
    print("  Sequences reset.")


def verify_counts(dst_conn, schema: str) -> None:
    """Print the PostgreSQL row count of each table in ``VERIFY_TABLES``.

    Schema and table names are quoted, so they are matched exactly
    (case-sensitive), the same way the schema is looked up in
    information_schema elsewhere in this script.
    """
    quote = dst_conn.dialect.identifier_preparer.quote_identifier
    print("\n--- Verification: PostgreSQL row counts ---")
    for table in VERIFY_TABLES:
        count = dst_conn.execute(
            text(f"SELECT COUNT(*) FROM {quote(schema)}.{quote(table)}")
        ).scalar()
        print(f"  {table}: {count} rows")


def main() -> None:
    """Command-line entry point: copy every table in TABLES_IN_ORDER from SQLite to PostgreSQL.

    The steps are:

    1. Parse arguments.
    2. Validate the SQLite path; exit with status 1 if it is not a file.
    3. Read the PostgreSQL URL from ``Config.database_url()``.
    4. Migrate each table in order.
    5. Unless ``--dry-run`` is given, reset the target schema's sequences and
       print verification row counts.

    Both engines are disposed on exit, including on error.
    """
    parser = argparse.ArgumentParser(
        description="Migrate SQLite data to PostgreSQL tenant schema"
    )
    parser.add_argument(
        "--sqlite",
        default="supporthub_local.db",
        help="Path to SQLite database file (default: supporthub_local.db)",
    )
    parser.add_argument(
        "--schema",
        default="tenant_internal",
        help="Target PostgreSQL schema name (default: tenant_internal)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print row counts without inserting anything",
    )
    args = parser.parse_args()

    sqlite_path = os.path.abspath(args.sqlite)
    # isfile rather than exists: a directory must not pass. The check matters
    # because SQLite's default open mode would otherwise create a new, empty file.
    if not os.path.isfile(sqlite_path):
        print(f"ERROR: SQLite file not found: {sqlite_path}", file=sys.stderr)
        sys.exit(1)

    from supporthub.app.config import Config

    pg_url = Config.database_url()
    sqlite_url = f"sqlite:///{sqlite_path}"

    print(f"Source : {sqlite_url}")
    # Show only the host/db part; str() also handles a URL object.
    print(f"Target : {str(pg_url).split('@')[-1]} (schema: {args.schema})")
    if args.dry_run:
        print("Mode   : DRY-RUN (no data will be written)\n")
    else:
        print("Mode   : LIVE (data will be inserted)\n")

    sqlite_engine = create_engine(sqlite_url)
    pg_engine = create_engine(pg_url, pool_pre_ping=True)

    total_rows = 0
    try:
        with sqlite_engine.connect() as src_conn, pg_engine.connect() as dst_conn:
            # search_path is per-session, so set it on the connection that does
            # the work. The schema is quoted to match the exact-name lookups.
            quoted_schema = dst_conn.dialect.identifier_preparer.quote_identifier(
                args.schema
            )
            dst_conn.execute(text(f"SET search_path TO {quoted_schema}, public"))
            dst_conn.commit()

            for table in TABLES_IN_ORDER:
                total_rows += migrate_table(
                    src_conn, dst_conn, table, args.schema, args.dry_run
                )

            if not args.dry_run:
                print("\nResetting PostgreSQL sequences ...")
                reset_sequences(dst_conn, args.schema)
                verify_counts(dst_conn, args.schema)
    finally:
        sqlite_engine.dispose()
        pg_engine.dispose()

    print(f"\nDone. {total_rows} total source rows processed.")


if __name__ == "__main__":
    main()
