"""RAG-Lite Chatbot Service for Querying Risk Data"""
import asyncio
import logging
from typing import Dict, Any, List

import anthropic

from config import settings
from database import (
    risk_analyses_collection,
    conversations_collection,
    messages_collection,
    review_actions_collection,
    audit_logs_collection
)

# Module-level logger named after this module; used to report RAG synthesis errors.
logger = logging.getLogger(__name__)

class RAGService:
    """RAG-lite service that answers chatbot queries over stored risk data.

    A query is handled in two steps: recent records are pulled from the
    database using simple keyword heuristics, then Claude is asked to answer
    the question using only that retrieved context.
    """

    # Longest AI-explanation excerpt (in characters) included per analysis.
    _EXPLANATION_LIMIT = 300
    # Number of risk indicators listed per analysis in the prompt context.
    _INDICATOR_LIMIT = 3
    # Number of risk analyses reported back to the caller as sources.
    _SOURCE_LIMIT = 3

    def __init__(self):
        """Create the Anthropic client and select the model to query."""
        self.client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
        self.model = "claude-sonnet-4-20250514"

    async def query(self, user_query: str) -> Dict[str, Any]:
        """Answer a user query with the RAG-lite pipeline.

        Args:
            user_query: Free-text question from the user.

        Returns:
            A dict with ``"response"`` (the answer text, or an error message)
            and ``"sources"`` (a list of risk-analysis summaries).
        """
        # Step 1: retrieve relevant data.
        retrieved_data = await self._retrieve_relevant_data(user_query)

        # Step 2: synthesize the response from that data.
        return await self._synthesize_response(user_query, retrieved_data)

    async def _retrieve_relevant_data(self, query: str) -> Dict[str, Any]:
        """Fetch database records relevant to the query via keyword heuristics.

        Args:
            query: The raw user query; matching is case-insensitive.

        Returns:
            A dict with the keys ``"risk_analyses"``, ``"conversations"``,
            ``"review_actions"`` and ``"audit_logs"``. ``"conversations"`` is
            always an empty list; conversation documents are attached to each
            risk analysis under ``"conversation"`` instead.
        """
        query_lower = query.lower()
        retrieved = {
            "risk_analyses": [],
            "conversations": [],
            "review_actions": [],
            "audit_logs": []
        }

        # Narrow by risk level only when the query names one; otherwise take
        # the most recent analyses of any level.
        risk_filter = {}
        if "high risk" in query_lower or "critical" in query_lower:
            risk_filter["risk_level"] = "HIGH"
        elif "medium" in query_lower:
            risk_filter["risk_level"] = "MEDIUM"

        risk_analyses = await risk_analyses_collection.find(
            risk_filter, {"_id": 0}
        ).sort("analyzed_at", -1).limit(10).to_list(10)

        # Attach the conversation to each analysis and flag the ones whose
        # participants are mentioned in the query.
        for analysis in risk_analyses:
            conv = await conversations_collection.find_one(
                {"id": analysis.get("conversation_id")},
                {"_id": 0}
            )
            if not conv:
                continue
            analysis["conversation"] = conv
            participants = conv.get("participants") or []
            if any(self._participant_matches(p, query_lower) for p in participants):
                analysis["name_match"] = True

        # Stable sort: analyses that mention a queried participant come first,
        # and the newest-first order is preserved inside each group.
        risk_analyses.sort(key=lambda a: not a.get("name_match", False))
        retrieved["risk_analyses"] = risk_analyses

        # Recent review actions are always included.
        review_actions = await review_actions_collection.find(
            {}, {"_id": 0}
        ).sort("timestamp", -1).limit(5).to_list(5)
        retrieved["review_actions"] = review_actions

        # Audit logs are only fetched when the query mentions audit/activity/log.
        if "audit" in query_lower or "activity" in query_lower or "log" in query_lower:
            audit_logs = await audit_logs_collection.find(
                {}, {"_id": 0}
            ).sort("timestamp", -1).limit(10).to_list(10)
            retrieved["audit_logs"] = audit_logs

        return retrieved

    @staticmethod
    def _participant_matches(participant: Dict[str, Any], query_lower: str) -> bool:
        """Return True if the participant's name or email local part is in the query.

        Empty or missing values never match: an empty string is a substring of
        every query, so it must be excluded explicitly.
        """
        name = str(participant.get("name") or "").strip().lower()
        email = str(participant.get("email") or "").strip().lower()
        local_part = email.split("@")[0]
        if name and name in query_lower:
            return True
        return bool(local_part) and local_part in query_lower

    @staticmethod
    def _participant_label(participant: Dict[str, Any]) -> str:
        """Return a display label: name, else email, else ``"Unknown"``."""
        return str(participant.get("name") or participant.get("email") or "Unknown")

    @classmethod
    def _build_sources(cls, retrieved_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Summarize the leading risk analyses that have a conversation attached."""
        sources = []
        leading = (retrieved_data.get("risk_analyses") or [])[:cls._SOURCE_LIMIT]
        for analysis in leading:
            conv = analysis.get("conversation")
            if not conv:
                continue
            sources.append({
                "type": "risk_analysis",
                "risk_level": analysis.get("risk_level"),
                "participants": [
                    cls._participant_label(p)
                    for p in conv.get("participants") or []
                ],
                "id": analysis.get("id")
            })
        return sources

    async def _synthesize_response(
        self,
        query: str,
        retrieved_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Ask Claude to answer the query using only the retrieved context.

        Args:
            query: The user's question.
            retrieved_data: Output of ``_retrieve_relevant_data``.

        Returns:
            A dict with ``"response"`` and ``"sources"``. Failures while
            calling the model are logged and reported in ``"response"`` with
            an empty ``"sources"`` list rather than raised.
        """

        system_prompt = """You are the Aera AI Assistant, an expert compliance analysis assistant. You help compliance officers understand risk patterns and investigate potential policy violations in Microsoft Teams communications.

Your role:
- Answer questions about flagged conversations and risk analyses
- Provide evidence-based insights from the retrieved data
- Be concise but thorough
- Always cite specific data when making claims
- If you don't have enough data to answer, say so clearly

IMPORTANT: Only use information from the provided context. Do not make up data."""

        # Format retrieved data for context
        context = self._format_context(retrieved_data)

        user_prompt = f"""User Query: {query}

RETRIEVED CONTEXT:
{context}

Based on the above context, please answer the user's query. Be specific and cite evidence from the data."""

        try:
            # The client is synchronous; run the request in a worker thread so
            # the event loop stays free while waiting on the API.
            response = await asyncio.to_thread(
                self.client.messages.create,
                model=self.model,
                max_tokens=1000,
                messages=[
                    {"role": "user", "content": user_prompt}
                ],
                system=system_prompt
            )

            response_text = response.content[0].text.strip()

            return {
                "response": response_text,
                "sources": self._build_sources(retrieved_data)
            }

        except Exception as e:
            # Boundary handler: the chatbot always gets a reply; the traceback
            # goes to the log.
            logger.exception("RAG synthesis error: %s", e)
            return {
                "response": f"I encountered an error processing your query: {str(e)}",
                "sources": []
            }

    def _format_context(self, retrieved_data: Dict[str, Any]) -> str:
        """Render retrieved records as plain-text context for Claude.

        Missing or ``None`` fields are tolerated. The explanation gets a
        trailing ``...`` only when it was actually truncated.

        Args:
            retrieved_data: Dict that may contain ``"risk_analyses"``,
                ``"review_actions"`` and ``"audit_logs"`` lists.

        Returns:
            The sections joined by newlines, or a fixed "no data" sentence
            when every section is empty.
        """
        sections = []

        # Risk Analyses
        analyses = retrieved_data.get("risk_analyses")
        if analyses:
            parts = ["RISK ANALYSES:\n"]
            for i, analysis in enumerate(analyses, 1):
                conv = analysis.get("conversation") or {}
                participants = [
                    self._participant_label(p)
                    for p in conv.get("participants") or []
                ]
                explanation = str(analysis.get("ai_explanation") or "N/A")
                if len(explanation) > self._EXPLANATION_LIMIT:
                    explanation = explanation[:self._EXPLANATION_LIMIT] + "..."
                raw_indicators = analysis.get("risk_indicators") or []
                indicators = [
                    str(item) for item in raw_indicators[:self._INDICATOR_LIMIT]
                ]
                parts.append(f"""
{i}. Risk Level: {analysis.get('risk_level')} (Score: {analysis.get('risk_score')})
   Participants: {', '.join(participants)}
   Status: {analysis.get('status', 'pending')}
   Explanation: {explanation}
   Indicators: {', '.join(indicators)}
""")
            sections.append("".join(parts))

        # Review Actions
        review_actions = retrieved_data.get("review_actions")
        if review_actions:
            lines = ["RECENT REVIEW ACTIONS:\n"]
            lines.extend(
                f"- {action.get('action')} by {action.get('reviewer_name')} on {action.get('timestamp')}\n"
                for action in review_actions
            )
            sections.append("".join(lines))

        # Audit Logs
        audit_logs = retrieved_data.get("audit_logs")
        if audit_logs:
            lines = ["AUDIT LOGS:\n"]
            lines.extend(
                f"- {log.get('action')} by {log.get('actor_email')} at {log.get('timestamp')}\n"
                for log in audit_logs
            )
            sections.append("".join(lines))

        if not sections:
            return "No relevant data found in the system."

        return "\n".join(sections)

# Shared module-level instance. Constructing it here builds the Anthropic
# client (reading settings.ANTHROPIC_API_KEY) as soon as this module is imported.
rag_service = RAGService()
