"""Sync Service - The Brain of Aera Risk Detection (MySQL/SQLAlchemy)"""
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.sql import desc

# Import SQLAlchemy Models
from models import (
    AuditLog, Conversation, Message, RiskAnalysis
)
from schemas import RiskLevel

# Import Services
from services.claude_service import claude_service
from services.graph_service import graph_service

# Module logger: sync start/progress, flagged conversations and sync errors are reported here.
logger = logging.getLogger(__name__)

class SyncService:
    """Sync Microsoft Teams conversations into the database and flag risky ones.

    For each chat returned by ``graph_service.fetch_conversations_and_messages()``
    this service upserts a ``Conversation``, inserts new ``Message`` rows, asks
    ``claude_service`` for a risk assessment (optionally enriched with a summary
    of the primary employee's other recent chats) and stores a ``RiskAnalysis``
    only when the returned level is one of ``FLAGGED_RISK_LEVELS``.

    Every timestamp this service writes is a naive ``datetime`` in UTC. The
    service keeps no per-instance state; the database session is passed in.
    """

    # Participant email domains treated as employees when picking whose other
    # chats feed the cross-chat context. Compared against the domain part of the
    # address (exact domain or a subdomain of it), case-insensitively.
    EMPLOYEE_DOMAINS = ("company.com", "graycell.com")

    # Risk levels that produce a stored RiskAnalysis; other levels are analyzed
    # but not persisted.
    FLAGGED_RISK_LEVELS = ("MEDIUM", "HIGH")

    # Topic given to new conversations whose chat data has no (or an empty) topic.
    DEFAULT_TOPIC = "Teams Chat"

    async def sync_and_analyze(self, db: AsyncSession, user_id: str, user_email: str) -> Dict[str, Any]:
        """Fetch all chats, persist them, and run risk analysis on each.

        Args:
            db: Session used for every read/write; helpers commit after each step.
            user_id: ID of the user who triggered the sync (audit-log actor).
            user_email: Email of that user (audit-log actor, log messages).

        Returns:
            Dict with ``success``, ``conversations_synced``, ``risks_detected``
            and a human-readable ``message``. If any step after the initial
            audit-log write fails, the exception is logged, the session is
            rolled back, and ``success`` is False with the counters reflecting
            progress made before the error. Exceptions from the initial
            SYNC_INITIATED audit-log write propagate to the caller.
        """
        logger.info("Starting sync for user %s", user_email)

        await self._create_audit_log(db, user_id, user_email, "SYNC_INITIATED", {})

        conversations_synced = 0
        risks_detected = 0

        try:
            # Step 1: fetch every chat (with its messages).
            chats_data = await graph_service.fetch_conversations_and_messages()
            logger.info("Processing %d conversations...", len(chats_data))

            for chat_data in chats_data:
                # Step 2: upsert the conversation (also done for chats without messages).
                conversation_record = await self._store_conversation(db, chat_data)
                # Read the id once; later commits may expire the ORM instance.
                conversation_id = conversation_record.id

                # Step 3: store new messages. Chats with none are neither counted nor analyzed.
                msgs_data = chat_data.get("messages") or []
                if not msgs_data:
                    continue
                await self._store_messages(db, conversation_id, msgs_data)

                # Analysis always runs, even when no new messages were stored.
                conversations_synced += 1

                # Step 4: summarize the primary employee's other recent chats.
                participants = chat_data.get("participants") or []
                cross_chat_context = await self._build_cross_chat_context(
                    db, participants, conversation_id
                )

                # Step 5: analyze the full stored history, not just this sync's batch.
                serialized_messages = await self._load_serialized_messages(db, conversation_id)
                analysis_result = await claude_service.analyze_conversation(
                    conversation_messages=serialized_messages,
                    participants=participants,
                    cross_chat_context=cross_chat_context,
                )

                # Step 6: persist only flagged (MEDIUM/HIGH) results.
                risk_level = self._risk_level_value(analysis_result["risk_level"])
                if risk_level in self.FLAGGED_RISK_LEVELS:
                    await self._store_risk_analysis(
                        db,
                        conversation_id,
                        analysis_result,
                        cross_chat_context,
                    )
                    risks_detected += 1
                    logger.info("🚨 Flagged conversation %s as %s", conversation_id, risk_level)

            await self._create_audit_log(
                db,
                user_id, user_email, "SYNC_COMPLETED",
                {"conversations_synced": conversations_synced, "risks_detected": risks_detected}
            )

            return {
                "success": True,
                "conversations_synced": conversations_synced,
                "risks_detected": risks_detected,
                "message": f"Synced {conversations_synced} chats, found {risks_detected} risks."
            }

        except Exception as e:
            # Top-level boundary: report the failure to the caller instead of raising,
            # but keep the traceback in the log.
            logger.exception("Sync error: %s", e)
            # A failed flush/commit leaves the session unusable until it is rolled back.
            try:
                await db.rollback()
            except Exception:
                logger.exception("Rollback after sync failure also failed")
            return {
                "success": False,
                "conversations_synced": conversations_synced,
                "risks_detected": risks_detected,
                "message": f"Sync failed: {str(e)}"
            }

    async def _build_cross_chat_context(
        self,
        db: AsyncSession,
        participants: List[Dict[str, Any]],
        conversation_id: str,
    ) -> str:
        """Summarize the first employee participant's other recent conversations.

        Returns "" when no participant has an employee email (see
        ``EMPLOYEE_DOMAINS``) or that employee has no other matching
        conversations; otherwise whatever
        ``claude_service.generate_cross_chat_summary`` returns.
        """
        primary_employee = next(
            (p.get("email") for p in participants if self._is_employee_email(p.get("email"))),
            None,
        )
        if primary_employee is None:
            return ""

        other_convs = await self._get_employee_recent_conversations(
            db,
            primary_employee,
            exclude_conversation_id=conversation_id,
        )
        if not other_convs:
            return ""

        return await claude_service.generate_cross_chat_summary(primary_employee, other_convs)

    async def _load_serialized_messages(self, db: AsyncSession, conversation_id: str) -> List[Dict[str, Any]]:
        """Return every stored message of a conversation, oldest first, as plain dicts.

        Each dict has ``sender_name``, ``content`` and an ISO-8601 ``timestamp``.
        """
        stmt = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.timestamp)
        )
        result = await db.execute(stmt)
        return [
            {"sender_name": m.sender_name, "content": m.content, "timestamp": m.timestamp.isoformat()}
            for m in result.scalars().all()
        ]

    async def _store_conversation(self, db: AsyncSession, chat_data: Dict[str, Any]) -> Conversation:
        """Insert or update the Conversation keyed by ``chat_data["id"]`` and commit.

        Existing rows get fresh ``participants``; new rows also get a topic
        (``DEFAULT_TOPIC`` when the chat has none). Both paths stamp
        ``last_synced_at``. The instance is refreshed after commit so its
        attributes (notably ``id``) are loaded for the caller.
        """
        teams_thread_id = chat_data.get("id")
        participants = chat_data.get("participants", [])

        result = await db.execute(
            select(Conversation).where(Conversation.teams_thread_id == teams_thread_id)
        )
        conversation = result.scalars().first()

        if conversation is None:
            conversation = Conversation(
                teams_thread_id=teams_thread_id,
                participants=participants,
                # `or` rather than a .get() default: an explicit None/"" topic must
                # also fall back (one-on-one Teams chats typically have no topic).
                topic=chat_data.get("topic") or self.DEFAULT_TOPIC,
            )
            db.add(conversation)
        else:
            conversation.participants = participants

        # Stamp new rows too: _get_employee_recent_conversations orders by this column.
        conversation.last_synced_at = self._utcnow_naive()
        await db.commit()
        # Reload after commit so attribute access does not need lazy IO,
        # regardless of the session's expire_on_commit setting.
        await db.refresh(conversation)
        return conversation

    async def _store_messages(self, db: AsyncSession, conversation_id: str, raw_messages: List[Dict[str, Any]]) -> List[Message]:
        """Insert messages not already stored for this conversation; commit if any were added.

        A message is a duplicate when a row with the same conversation, content
        and (whole-second, naive UTC) timestamp exists. Messages without a
        usable ``created_at`` are stamped with the current time and deduplicated
        on content alone, so re-syncing them does not insert a new copy each time.

        Returns:
            The newly added Message instances.
        """
        newly_added: List[Message] = []

        for msg in raw_messages:
            content = msg.get("content")
            created_at = self._parse_message_timestamp(msg.get("created_at"))

            criteria = [
                Message.conversation_id == conversation_id,
                Message.content == content,
            ]
            if created_at is not None:
                criteria.append(Message.timestamp == created_at)

            # Rows added earlier in this loop are only visible to this query if the
            # session autoflushes (SQLAlchemy's default).
            result = await db.execute(select(Message).where(*criteria).limit(1))
            if result.scalars().first() is not None:
                continue

            new_msg = Message(
                conversation_id=conversation_id,
                sender_email=msg.get("sender_email"),
                sender_name=msg.get("sender_name"),
                content=content,
                timestamp=(
                    created_at
                    if created_at is not None
                    else self._utcnow_naive().replace(microsecond=0)
                ),
            )
            db.add(new_msg)
            newly_added.append(new_msg)

        if newly_added:
            await db.commit()
        return newly_added

    async def _get_employee_recent_conversations(
        self,
        db: AsyncSession,
        employee_email: str,
        exclude_conversation_id: str,
        scan_limit: int = 20,
        max_conversations: int = 3,
        messages_per_conversation: int = 5,
    ) -> List[Dict[str, Any]]:
        """Collect recent messages from other conversations involving ``employee_email``.

        Only the ``scan_limit`` most recently synced conversations (excluding
        ``exclude_conversation_id``) are considered; participant matching is
        done in Python, case-insensitively. Returns at most
        ``max_conversations`` dicts of the form
        ``{"id", "topic", "messages": [{"sender", "content"}, ...]}``, where
        ``messages`` holds the latest ``messages_per_conversation`` messages,
        oldest first.
        """
        target = (employee_email or "").strip().lower()
        if not target:
            # An empty target would "match" every participant lacking an email.
            return []

        stmt = (
            select(Conversation)
            .where(Conversation.id != exclude_conversation_id)
            .order_by(desc(Conversation.last_synced_at))
            .limit(scan_limit)
        )
        result = await db.execute(stmt)
        candidates = result.scalars().all()

        relevant_convs: List[Dict[str, Any]] = []
        for conv in candidates:
            if len(relevant_convs) >= max_conversations:
                break

            participants = conv.participants or []
            if not any((p.get("email") or "").strip().lower() == target for p in participants):
                continue

            msg_stmt = (
                select(Message)
                .where(Message.conversation_id == conv.id)
                .order_by(desc(Message.timestamp))
                .limit(messages_per_conversation)
            )
            msg_result = await db.execute(msg_stmt)
            newest_first = msg_result.scalars().all()

            relevant_convs.append({
                "id": conv.id,
                "topic": conv.topic,
                # The query returns newest first; present them chronologically.
                "messages": [
                    {"sender": m.sender_name, "content": m.content}
                    for m in reversed(newest_first)
                ],
            })
        return relevant_convs

    async def _store_risk_analysis(self, db: AsyncSession, conversation_id: str, analysis_result: Dict[str, Any], cross_chat_context: str):
        """Upsert the single RiskAnalysis row for a conversation and commit.

        Reads ``risk_score``, ``risk_level`` and ``ai_explanation`` (required)
        and ``risk_indicators`` (optional, default []) from ``analysis_result``.
        An existing row is overwritten and its ``analyzed_at`` bumped (its
        ``status`` is left unchanged); a new row starts with status "pending".
        """
        result = await db.execute(
            select(RiskAnalysis).where(RiskAnalysis.conversation_id == conversation_id)
        )
        existing = result.scalars().first()

        risk_level_val = self._risk_level_value(analysis_result["risk_level"])
        risk_indicators = analysis_result.get("risk_indicators", [])

        if existing is not None:
            existing.risk_score = analysis_result["risk_score"]
            existing.risk_level = risk_level_val
            existing.ai_explanation = analysis_result["ai_explanation"]
            existing.risk_indicators = risk_indicators
            existing.cross_chat_context = cross_chat_context
            existing.analyzed_at = self._utcnow_naive()
        else:
            db.add(RiskAnalysis(
                conversation_id=conversation_id,
                risk_score=analysis_result["risk_score"],
                risk_level=risk_level_val,
                ai_explanation=analysis_result["ai_explanation"],
                risk_indicators=risk_indicators,
                cross_chat_context=cross_chat_context,
                status="pending"
            ))
        await db.commit()

    async def _create_audit_log(self, db: AsyncSession, actor_id: str, actor_email: str, action: str, details: Dict[str, Any]):
        """Append an AuditLog row (timestamped now, naive UTC) and commit."""
        db.add(AuditLog(
            actor_id=actor_id,
            actor_email=actor_email,
            action=action,
            details=details,
            timestamp=self._utcnow_naive()
        ))
        await db.commit()

    @classmethod
    def _is_employee_email(cls, email: Optional[str]) -> bool:
        """True when ``email``'s domain is one of ``EMPLOYEE_DOMAINS`` or a subdomain of one."""
        if not email or "@" not in email:
            return False
        domain = email.rsplit("@", 1)[1].strip().lower()
        return any(
            domain == allowed or domain.endswith("." + allowed)
            for allowed in cls.EMPLOYEE_DOMAINS
        )

    @staticmethod
    def _risk_level_value(risk_level: Any) -> Any:
        """Unwrap an enum-like risk level (anything with ``.value``) to its raw value."""
        return getattr(risk_level, "value", risk_level)

    @staticmethod
    def _utcnow_naive() -> datetime:
        """Current UTC time as a naive datetime (same value the deprecated utcnow() gave)."""
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _parse_message_timestamp(raw: Any) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp into a naive UTC datetime truncated to whole seconds.

        A trailing "Z" is accepted; naive input is assumed to already be UTC.
        Returns None when ``raw`` is not a non-empty string or cannot be parsed.
        """
        if not isinstance(raw, str) or not raw.strip():
            return None

        text = raw.strip().replace("Z", "+00:00")
        # Drop fractional seconds before parsing: they are discarded anyway, and
        # datetime.fromisoformat() before Python 3.11 rejects fractions that are
        # not exactly 3 or 6 digits (e.g. 7-digit values).
        head, dot, tail = text.partition(".")
        if dot:
            frac_len = len(tail) - len(tail.lstrip("0123456789"))
            text = head + tail[frac_len:]

        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            logger.warning("Unparseable message timestamp %r; treating it as missing", raw)
            return None

        if parsed.tzinfo is not None:
            parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
        return parsed.replace(microsecond=0)

# Shared module-level instance. SyncService defines no __init__ and holds no
# per-instance state (the DB session is passed to each call), so a single
# instance can be reused; presumably callers import `sync_service` — confirm.
sync_service = SyncService()