"""Dashboard Router - Stats and Analytics (MySQL/SQLAlchemy)"""
from datetime import datetime, timezone, timedelta
from typing import List

from fastapi import APIRouter, Depends, Query
from sqlalchemy import case, desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from auth import get_current_active_user
from database import get_db
from models import Conversation, RiskAnalysis, ReviewAction, User
from schemas import DashboardStats, RiskTrendItem

# Router for the dashboard endpoints below: every path is mounted under the
# "/dashboard" prefix and grouped under the "Dashboard" tag in the OpenAPI docs.
router = APIRouter(prefix="/dashboard", tags=["Dashboard"])

@router.get("/stats", response_model=DashboardStats)
async def get_dashboard_stats(
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db)
):
    """Get dashboard statistics.

    Returns the total number of conversations, the total number of risk
    analyses and their breakdown by risk level (HIGH / MEDIUM / LOW), the
    number of analyses whose status is "pending", and the number of review
    actions timestamped at or after 00:00 UTC today.

    ``current_user`` is not read in the body; it is declared so the endpoint
    requires whatever ``get_current_active_user`` enforces (presumably an
    authenticated, active user -- confirm in ``auth``).
    """

    def count_where(condition):
        # COUNT(CASE WHEN cond THEN 1 END): the CASE yields NULL for rows
        # that do not match and COUNT skips NULLs, so only matches count.
        return func.count(case((condition, 1)))

    # Every RiskAnalysis figure comes from ONE aggregate statement instead of
    # five round-trips; being a single statement, the numbers are also
    # mutually consistent regardless of transaction isolation level.
    risk_stmt = select(
        func.count().label("total"),
        count_where(RiskAnalysis.risk_level == "HIGH").label("high"),
        count_where(RiskAnalysis.risk_level == "MEDIUM").label("medium"),
        count_where(RiskAnalysis.risk_level == "LOW").label("low"),
        count_where(RiskAnalysis.status == "pending").label("pending"),
    ).select_from(RiskAnalysis)
    risk = (await db.execute(risk_stmt)).one()

    total_conversations = await db.scalar(
        select(func.count()).select_from(Conversation)
    )

    # Naive UTC midnight -- the same value the deprecated datetime.utcnow()
    # produced. NOTE(review): assumes ReviewAction.timestamp stores naive
    # UTC datetimes; confirm against the model.
    today_start = datetime.now(timezone.utc).replace(
        tzinfo=None, hour=0, minute=0, second=0, microsecond=0
    )
    reviewed_today = await db.scalar(
        select(func.count())
        .select_from(ReviewAction)
        .where(ReviewAction.timestamp >= today_start)
    )

    return DashboardStats(
        total_conversations=total_conversations or 0,
        total_flagged=risk.total or 0,
        high_risk_count=risk.high or 0,
        medium_risk_count=risk.medium or 0,
        low_risk_count=risk.low or 0,
        pending_reviews=risk.pending or 0,
        reviewed_today=reviewed_today or 0
    )

@router.get("/trends", response_model=List[RiskTrendItem])
async def get_risk_trends(
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db)
):
    """Get risk trends for the last 7 days.

    Returns seven ``RiskTrendItem`` entries, one per UTC calendar day, ordered
    oldest first and ending with today. Each entry holds the date as
    "YYYY-MM-DD" and the number of HIGH, MEDIUM and LOW risk analyses whose
    ``analyzed_at`` falls within that day. Days without analyses are still
    present, with zero counts.
    """
    # Read the clock once. Re-reading it on every iteration could straddle
    # midnight and yield a duplicated day and a skipped one.
    # NOTE(review): naive UTC, matching the former datetime.utcnow(); assumes
    # RiskAnalysis.analyzed_at stores naive UTC -- confirm against the model.
    today_start = datetime.now(timezone.utc).replace(
        tzinfo=None, hour=0, minute=0, second=0, microsecond=0
    )

    def count_level(level):
        # COUNT(CASE WHEN risk_level = level THEN 1 END) counts matching rows only.
        return func.count(case((RiskAnalysis.risk_level == level, 1)))

    trends = []
    # Walk the days in Python so empty days still produce a row (a SQL
    # GROUP BY on the date would omit them). Iterating from 6 days ago down
    # to today yields oldest-first order directly.
    for days_ago in range(6, -1, -1):
        day_start = today_start - timedelta(days=days_ago)
        day_end = day_start + timedelta(days=1)

        # One query per day counts all three levels at once (7 queries
        # total instead of 21).
        stmt = select(
            count_level("HIGH"),
            count_level("MEDIUM"),
            count_level("LOW"),
        ).where(
            RiskAnalysis.analyzed_at >= day_start,
            RiskAnalysis.analyzed_at < day_end,
        )
        high, medium, low = (await db.execute(stmt)).one()

        trends.append(RiskTrendItem(
            date=day_start.strftime("%Y-%m-%d"),
            high=high or 0,
            medium=medium or 0,
            low=low or 0
        ))

    return trends

@router.get("/recent-alerts")
async def get_recent_alerts(
    limit: int = Query(5, ge=0),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db)
):
    """Get the most recent HIGH- and MEDIUM-risk analyses with their conversation.

    Args:
        limit: Maximum number of analyses to return (default 5). Negative
            values are rejected with a 422 validation error; previously they
            reached the database as ``LIMIT -1`` and failed with a 500.
            NOTE(review): there is no upper bound -- consider ``le=...`` if
            all callers can accept one.

    Returns:
        ``RiskAnalysis`` ORM objects ordered by ``analyzed_at`` descending,
        each with its ``conversation`` relationship eager-loaded.
    """
    from sqlalchemy.orm import selectinload

    stmt = (
        select(RiskAnalysis)
        # Eager-load the relationship: an AsyncSession cannot perform an
        # implicit lazy load, which serialization would otherwise trigger.
        .options(selectinload(RiskAnalysis.conversation))
        .where(RiskAnalysis.risk_level.in_(["HIGH", "MEDIUM"]))
        .order_by(desc(RiskAnalysis.analyzed_at))
        .limit(limit)
    )

    result = await db.execute(stmt)
    alerts = result.scalars().all()

    # No response_model is declared on this route, so the payload is NOT
    # validated against any schema: FastAPI encodes the ORM objects as-is,
    # and the JSON shape is whatever attributes they have loaded.
    return alerts