"""Authentication utilities for Aera"""
import hashlib  # SHA-256 pre-hash applied to passwords before bcrypt
from datetime import datetime, timedelta, timezone
from typing import Optional, List
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
 
# Project-local schemas, ORM model, settings, and database session dependency
from schemas import TokenData, UserRole
from models import User
from config import settings
from database import get_db
 
# Password-hashing context shared by verify_password / get_password_hash;
# bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Bearer-token dependency consumed by get_current_user; tokenUrl names the
# login endpoint path.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
 
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against its stored bcrypt hash.

    The plaintext is first reduced to its SHA-256 hex digest (a fixed
    64-character string) so the value handed to bcrypt always stays under
    bcrypt's 72-byte input limit. The stored hash must have been produced
    from the same digest form, as ``get_password_hash`` does.

    Args:
        plain_password: The candidate password in plaintext.
        hashed_password: The stored hash to compare against.

    Returns:
        Whatever ``pwd_context.verify`` reports for the digest and the
        stored hash.
    """
    sha = hashlib.sha256()
    sha.update(plain_password.encode("utf-8"))
    return pwd_context.verify(sha.hexdigest(), hashed_password)
 
def get_password_hash(password: str) -> str:
    """Produce the hash to store for a new password.

    The password is UTF-8 encoded and reduced to its SHA-256 hex digest
    (64 characters) before being hashed by ``pwd_context``, which keeps the
    input below bcrypt's 72-byte limit regardless of password length.
    ``verify_password`` applies the same pre-hash when checking.

    Args:
        password: The plaintext password.

    Returns:
        The hash string produced by ``pwd_context.hash`` for the digest.
    """
    utf8_bytes = password.encode("utf-8")
    prehashed = hashlib.sha256(utf8_bytes).hexdigest()
    return pwd_context.hash(prehashed)
 
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The dict is copied, so the
            caller's object is not mutated.
        expires_delta: Lifetime of the token. ``None`` selects the default
            of ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any other
            value, including ``timedelta(0)``, is honoured as given.

    Returns:
        The token encoded with ``settings.JWT_SECRET`` and
        ``settings.JWT_ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, so a truthiness
    # test would silently swap an explicit zero lifetime for the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + expires_delta
    # "exp" and "type" are written last so they override same-named keys
    # that may be present in the caller's data.
    to_encode.update({"exp": expire, "type": "access"})
    return jwt.encode(to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
 
def create_refresh_token(data: dict) -> str:
    """Create a signed JWT refresh token.

    Args:
        data: Claims to embed in the token. The caller's dict is left
            untouched; a new claims dict is built from it.

    Returns:
        The token encoded with ``settings.JWT_SECRET`` and
        ``settings.JWT_ALGORITHM``. Its ``"exp"`` claim is set
        ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days from now (UTC) and its
        ``"type"`` claim is ``"refresh"``.
    """
    lifetime = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    # Later keys win, so "exp" and "type" override same-named keys in data.
    claims = {
        **data,
        "exp": datetime.now(timezone.utc) + lifetime,
        "type": "refresh",
    }
    return jwt.encode(claims, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
 
def decode_token(token: str, token_type: str = "access") -> Optional[TokenData]:
    """Decode a JWT and extract the identity claims it carries.

    Args:
        token: The encoded JWT.
        token_type: Expected value of the token's ``"type"`` claim.

    Returns:
        A ``TokenData`` built from the ``"sub"``, ``"email"`` and ``"role"``
        claims, or ``None`` when decoding raises ``JWTError``, the ``"type"``
        claim differs from ``token_type``, or the ``"sub"`` claim is missing
        or null.
    """
    try:
        claims = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
    except JWTError:
        return None

    if claims.get("type") != token_type:
        return None

    subject = claims.get("sub")
    if subject is None:
        return None

    return TokenData(
        user_id=subject,
        email=claims.get("email"),
        role=claims.get("role"),
    )
 
# --- DATABASE DEPENDENCY INJECTION ---
 
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db)
) -> User:
    """Resolve the request's bearer token to the matching ``User`` row.

    Args:
        token: Bearer token supplied by ``oauth2_scheme``.
        db: Async database session supplied by ``get_db``.

    Returns:
        The first ``User`` whose ``id`` equals the token's ``user_id``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``decode_token`` rejects the token or no matching user exists.
    """
    # Both failure paths deliberately raise the same error.
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    claims = decode_token(token)
    if claims is None:
        raise unauthorized

    query = select(User).where(User.id == claims.user_id)
    rows = await db.execute(query)
    found = rows.scalars().first()
    if found is None:
        raise unauthorized
    return found
 
async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """Return the authenticated user only if the account is active.

    Args:
        current_user: The user resolved by ``get_current_user``.

    Returns:
        ``current_user`` unchanged when its ``is_active`` attribute is truthy.

    Raises:
        HTTPException: 400 with detail ``"Inactive user"`` otherwise.
    """
    # current_user is an ORM model instance, so attribute access is used.
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
 
def require_role(allowed_roles: List[UserRole]):
    """Build a dependency that admits only users holding an allowed role.

    Args:
        allowed_roles: Roles permitted to proceed. Their ``.value``
            attributes are captured once, when this factory is called.

    Returns:
        An async dependency callable. It returns the active user when
        ``current_user.role`` is among the allowed values and raises a 403
        ``HTTPException`` with detail ``"Insufficient permissions"``
        otherwise.
    """
    # Snapshot the permitted values once instead of rebuilding the list on
    # every request; this also keeps the check correct when allowed_roles is
    # a one-shot iterable. A tuple (not a set) preserves the ==-based
    # membership test of the original list, which matters if the stored role
    # compares equal to a value without sharing its hash.
    allowed_values = tuple(r.value for r in allowed_roles)

    async def role_checker(current_user: User = Depends(get_current_active_user)) -> User:
        # Compare the user's role against the Enum values, not the members.
        if current_user.role not in allowed_values:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Insufficient permissions"
            )
        return current_user
    return role_checker
