"""One-time migration: copy all data from MySQL (single-tenant) to PostgreSQL tenant schema.

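This script reads the tables listed in TABLES_IN_ORDER from the existing MySQL
database and inserts their rows into the PostgreSQL schema of the organization
selected with --org-id (default 1, normally the first tenant). Tables missing from
MySQL are skipped, and rows that conflict with rows already present in Postgres are
left untouched (ON CONFLICT DO NOTHING). Run AFTER the Postgres DB is initialized
and the target organization + schema have been provisioned.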

Usage:
    MYSQL_URL="mysql+pymysql://user:pass@host/dbname" \
    venv/bin/python scripts/migrate_mysql_to_postgres.py --org-id 1

Prerequisites:
    pip install pymysql  (temporarily, just for this script)
    The target org's schema must already exist (run run_tenant_migrations.py first).
"""
import argparse
import sys
import os
import datetime

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from sqlalchemy import create_engine, text, MetaData, Table
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import Session

# Tables are copied one after another in exactly this order. The order appears
# to load referenced tables (products, customers, threads, ...) before the
# tables that point at them — presumably so foreign-key checks in the target
# schema pass; confirm against the Postgres DDL before reordering.
TABLES_IN_ORDER = [
    # Commerce entities.
    "addons_products",
    "addons_customers",
    "addons_orders",
    # Conversations and their child rows.
    "addons_threads",
    "addons_messages",
    "message_attachments",
    "addons_thread_tags",
    # Settings, templates and knowledge-base content.
    "app_settings",
    "reply_templates",
    "doc_sources",
    "doc_chunks",
    "thread_embeddings",
    # Secrets and bookkeeping.
    "credentials",
    "audit_log",
    "sync_history",
]

# Per-table lists of JSON-valued column names.
# NOTE(review): nothing in this file reads this mapping; JSON values are
# passed through to Postgres unchanged.
TABLES_WITH_JSON = dict(
    addons_threads=["tags"],
    app_settings=[],
)


def migrate_table(
    src_conn,
    dst_conn,
    table_name: str,
    schema: str,
    batch_size: int = 500,
) -> int:
    src_meta = MetaData()
    try:
        src_table = Table(table_name, src_meta, autoload_with=src_conn.engine)
    except Exception as exc:
        print(f"  Skipping {table_name}: {exc}")
        return 0

    columns = [c.name for c in src_table.columns]
    rows = src_conn.execute(text(f"SELECT * FROM `{table_name}`")).fetchall()
    if not rows:
        print(f"  {table_name}: 0 rows (empty)")
        return 0

    dst_table = f"{schema}.{table_name}"
    col_list = ", ".join(f'"{c}"' for c in columns)
    placeholders = ", ".join(f":{c}" for c in columns)
    insert_sql = text(f"INSERT INTO {dst_table} ({col_list}) VALUES ({placeholders}) ON CONFLICT DO NOTHING")

    inserted = 0
    for i in range(0, len(rows), batch_size):
        batch = [dict(zip(columns, row)) for row in rows[i : i + batch_size]]
        for record in batch:
            for k, v in record.items():
                if isinstance(v, datetime.datetime):
                    record[k] = v
                elif v is None:
                    record[k] = None
        dst_conn.execute(insert_sql, batch)
        dst_conn.commit()
        inserted += len(batch)

    print(f"  {table_name}: {inserted} rows migrated")
    return inserted


def reset_sequences(dst_conn, schema: str) -> None:
    """Reset PostgreSQL sequences after bulk insert (so auto-increment works correctly).

    ``migrate_table`` inserts explicit values for every column, including
    sequence-backed ones, which does not advance the sequences. Here every such
    sequence in ``schema`` is set so that its next value is ``MAX(column) + 1``
    (or 1 for an empty table).

    Both ``serial``-style columns (``nextval(...)`` default) and
    ``GENERATED ... AS IDENTITY`` columns are covered.

    Plain parameterized statements are used rather than a ``DO`` block. Bind
    parameters are not available inside a ``DO`` body when the driver binds
    server-side.

    Args:
        dst_conn: SQLAlchemy connection to the PostgreSQL target.
        schema: Schema whose sequences are reset.
    """
    quote = dst_conn.dialect.identifier_preparer.quote_identifier

    # pg_get_serial_sequence parses its first argument as a (possibly
    # schema-qualified) identifier, so both parts are quoted with %I to keep
    # their exact case.
    sequence_rows = dst_conn.execute(
        text(
            "SELECT table_name, column_name, "
            "pg_get_serial_sequence(format('%I.%I', table_schema, table_name), column_name) AS seq "
            "FROM information_schema.columns "
            "WHERE table_schema = :schema "
            "AND (column_default LIKE 'nextval%' OR is_identity = 'YES')"
        ),
        {"schema": schema},
    ).fetchall()

    for table_name, column_name, seq in sequence_rows:
        if seq is None:
            continue
        # Identifiers cannot be bound, so they are quoted by the dialect; the
        # sequence name (already quoted by pg_get_serial_sequence) is bound.
        dst_conn.execute(
            text(
                "SELECT setval(CAST(:seq AS regclass), "
                f"COALESCE((SELECT MAX({quote(column_name)}) "
                f"FROM {quote(schema)}.{quote(table_name)}), 0) + 1, false)"
            ),
            {"seq": seq},
        )

    dst_conn.commit()
    print("  Sequences reset.")


def main() -> None:
    """Command-line entry point: copy MySQL data into one tenant's Postgres schema.

    The MySQL URL comes from the ``MYSQL_URL`` environment variable and the
    Postgres URL from ``Config.database_url()``. The target schema is looked up
    for ``--org-id`` in ``public.organizations``. Each table in
    ``TABLES_IN_ORDER`` is then copied and the sequences are reset.

    Exits with status 1 (message on stderr) when ``MYSQL_URL`` is unset or the
    org id is unknown. An invalid ``--batch-size`` is rejected by argparse.
    """
    parser = argparse.ArgumentParser(description="Migrate MySQL data to PostgreSQL tenant schema")
    parser.add_argument("--org-id", type=int, default=1, help="Target org ID in public.organizations")
    parser.add_argument("--batch-size", type=int, default=500)
    args = parser.parse_args()
    if args.batch_size < 1:
        # A non-positive batch size would either error deep in the copy loop
        # or copy nothing at all; reject it up front.
        parser.error("--batch-size must be >= 1")

    mysql_url = os.environ.get("MYSQL_URL")
    if not mysql_url:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("ERROR: Set MYSQL_URL environment variable (mysql+pymysql://user:pass@host/db)")

    from supporthub.app.config import Config
    pg_url = Config.database_url()

    mysql_engine = create_engine(mysql_url, pool_pre_ping=True)
    pg_engine = create_engine(pg_url, pool_pre_ping=True)
    try:
        # Fetch target schema name
        with pg_engine.connect() as pg_conn:
            row = pg_conn.execute(
                text("SELECT schema_name FROM public.organizations WHERE id = :id"),
                {"id": args.org_id},
            ).fetchone()
        if not row:
            sys.exit(f"ERROR: No organization with id={args.org_id} in public.organizations")
        schema = row[0]

        print(f"Migrating MySQL → PostgreSQL schema '{schema}' (org_id={args.org_id})")

        with mysql_engine.connect() as src_conn, pg_engine.connect() as dst_conn:
            total = 0
            for table in TABLES_IN_ORDER:
                total += migrate_table(src_conn, dst_conn, table, schema, args.batch_size)

            reset_sequences(dst_conn, schema)
    finally:
        # Close pooled connections on every exit path, including sys.exit().
        mysql_engine.dispose()
        pg_engine.dispose()

    print(f"\nDone. {total} total rows migrated.")


if __name__ == "__main__":
    main()
