"""Fix ai_usage_log schema to match current model

The table was created with legacy columns (action, prompt_tokens, completion_tokens,
total_tokens). The model now uses call_type, tokens_used, used_own_key.

Revision ID: aa1fix
Revises: pr001, z7a8b9c0d1e2
Create Date: 2026-05-05
"""
# Revision identifiers read by Alembic.
revision = 'aa1fix'
# Two parent revisions: in Alembic a tuple here marks a merge point, so this
# migration also joins the 'pr001' and 'z7a8b9c0d1e2' branches.
down_revision = ('pr001', 'z7a8b9c0d1e2')
branch_labels = None
depends_on = None

from alembic import op
import sqlalchemy as sa


def upgrade():
    """Reshape ``ai_usage_log`` from the legacy columns to the current ones.

    Every step is guarded by a check against the columns that exist when the
    migration starts, so a table that already has some of the new columns is
    left untouched for those columns:

    * ``call_type`` -- renamed from ``action`` when that column exists,
      otherwise added as NOT NULL with a temporary server default of
      ``'classify'``.
    * ``tokens_used`` -- added as NOT NULL with a temporary server default of
      0, then backfilled from the legacy counters: the sum of
      ``prompt_tokens`` and ``completion_tokens`` when both exist, otherwise
      ``total_tokens`` when that exists.
    * ``used_own_key`` -- added as NOT NULL with a temporary server default of
      false.
    * ``prompt_tokens``, ``completion_tokens`` and ``total_tokens`` are
      dropped.

    Finally the server defaults on the three current columns are cleared.

    The column lookup queries ``information_schema`` with ``current_schema()``
    on the live connection, so this cannot run in offline (``--sql``) mode.
    NOTE(review): the SQL looks PostgreSQL-specific -- confirm the target
    database.
    """
    conn = op.get_bind()

    # Snapshot of the column names present before anything is changed; all
    # guards below test against this snapshot, not the evolving table.
    result = conn.execute(sa.text(
        "SELECT column_name FROM information_schema.columns "
        "WHERE table_name='ai_usage_log' AND table_schema=current_schema()"
    ))
    existing = {r[0] for r in result}

    # call_type: prefer renaming 'action' so its existing values are kept.
    if 'call_type' not in existing:
        if 'action' in existing:
            op.alter_column('ai_usage_log', 'action', new_column_name='call_type')
        else:
            op.add_column('ai_usage_log', sa.Column('call_type', sa.String(20), nullable=False, server_default='classify'))

    # tokens_used: add the column, then carry over the legacy counts before
    # the legacy columns are dropped below.
    if 'tokens_used' not in existing:
        op.add_column('ai_usage_log', sa.Column('tokens_used', sa.Integer, nullable=False, server_default='0'))
        if 'prompt_tokens' in existing and 'completion_tokens' in existing:
            conn.execute(sa.text(
                "UPDATE ai_usage_log SET tokens_used = COALESCE(prompt_tokens, 0) + COALESCE(completion_tokens, 0)"
            ))
        elif 'total_tokens' in existing:
            # Without both component columns the sum cannot be computed, but
            # total_tokens is about to be dropped; copy it rather than leave
            # every row at 0 and lose the recorded usage.
            conn.execute(sa.text(
                "UPDATE ai_usage_log SET tokens_used = COALESCE(total_tokens, 0)"
            ))

    # used_own_key: new flag, existing rows get false via the server default.
    if 'used_own_key' not in existing:
        op.add_column('ai_usage_log', sa.Column('used_own_key', sa.Boolean, nullable=False, server_default='false'))

    # Drop legacy columns
    for col in ('prompt_tokens', 'completion_tokens', 'total_tokens'):
        if col in existing:
            op.drop_column('ai_usage_log', col)

    # Remove server defaults now that data is populated. This runs for all
    # three columns regardless of whether this migration added them.
    for col in ('call_type', 'tokens_used', 'used_own_key'):
        op.alter_column('ai_usage_log', col, server_default=None)


def downgrade():
    """Do nothing: the changes made by ``upgrade()`` are not reverted.

    Downgrading past this revision leaves ``ai_usage_log`` exactly as
    ``upgrade()`` left it; the dropped legacy columns are not recreated.
    """
