"""
Module: services.ip_whitelist_service
Description: IP whitelist management for production environments
Author: Anderson H. Silva
Date: 2025-01-25
License: Proprietary - All rights reserved
"""

import ipaddress
from typing import List, Optional, Set, Dict, Any
from datetime import datetime, timezone
import json

from src.core import get_logger
from src.services.cache_service import cache_service
from src.core.config import settings
from src.models.base import BaseModel
from sqlalchemy import Column, String, Boolean, DateTime, Integer, JSON
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete

logger = get_logger(__name__)


def _normalize_ip(value: str) -> str:
    """Return *value* in the canonical form that ``add_ip`` persists.

    A value containing ``/`` is reduced to its network base address (the
    prefix length is stored in a separate column); anything else is parsed
    as a single address and re-rendered, which e.g. compresses IPv6.
    Unparseable input is returned unchanged so that lookups simply find
    no matching row instead of raising.
    """
    try:
        if "/" in value:
            network = ipaddress.ip_network(value, strict=False)
            return str(network.network_address)
        return str(ipaddress.ip_address(value))
    except ValueError:
        return value


class IPWhitelist(BaseModel):
    """IP whitelist entry model."""

    __tablename__ = "ip_whitelists"

    id = Column(String(64), primary_key=True)
    # NOTE(review): ``unique=True`` is global, while the service builds ids
    # as "<environment>:<ip>" and checks duplicates per environment. The
    # same address in two environments would pass the service check and
    # then violate this constraint - confirm which behaviour is intended.
    ip_address = Column(String(45), nullable=False, unique=True)  # IPv4 or IPv6
    description = Column(String(255))
    environment = Column(String(20), nullable=False, default="production")
    active = Column(Boolean, default=True)
    created_by = Column(String(255), nullable=False)
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    expires_at = Column(DateTime(timezone=True), nullable=True)
    meta_info = Column(JSON, default=dict)

    # CIDR support: ``ip_address`` holds the network base address and
    # ``cidr_prefix`` the prefix length.
    is_cidr = Column(Boolean, default=False)
    cidr_prefix = Column(Integer, nullable=True)

    def is_expired(self) -> bool:
        """Return True when ``expires_at`` is set and lies in the past (UTC)."""
        if not self.expires_at:
            return False
        return datetime.now(timezone.utc) > self.expires_at

    def matches(self, ip: str) -> bool:
        """Check if *ip* matches this whitelist entry.

        Inactive or expired entries never match. CIDR entries match any
        address inside the network. Exact entries match on string equality
        or, failing that, on equality of the parsed addresses, so that
        alternative spellings of the same IPv6 address still match.

        Returns:
            True on a match; False otherwise, including when a value
            cannot be parsed (an error is logged in that case).
        """
        if not self.active or self.is_expired():
            return False

        try:
            if self.is_cidr:
                # CIDR range check
                network = ipaddress.ip_network(
                    f"{self.ip_address}/{self.cidr_prefix}"
                )
                return ipaddress.ip_address(ip) in network

            # Exact match: cheap string comparison first, parsed second.
            if self.ip_address == ip:
                return True
            return ipaddress.ip_address(ip) == ipaddress.ip_address(
                self.ip_address
            )
        except ValueError:
            logger.error(f"Invalid IP address format: {ip}")
            return False


class IPWhitelistService:
    """Service for managing IP whitelists.

    Entries live in the ``ip_whitelists`` table. Lookups are accelerated by
    two layers: per-address results stored through ``cache_service`` and an
    in-process snapshot of the active entries of one environment. Both
    layers use a 300 second TTL and are dropped on every mutation.
    """

    def __init__(self):
        """Initialize IP whitelist service."""
        self._cache_key_prefix = "ip_whitelist"
        self._cache_ttl = 300  # 5 minutes
        self._whitelist_cache: Optional[Set[str]] = None
        self._cidr_cache: Optional[List[tuple]] = None
        self._last_cache_update: Optional[datetime] = None
        # Environment the in-process snapshot was loaded for. Without this
        # tag a snapshot of one environment would answer checks for another.
        self._cache_environment: Optional[str] = None

    async def add_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        created_by: str,
        description: Optional[str] = None,
        environment: str = "production",
        expires_at: Optional[datetime] = None,
        is_cidr: bool = False,
        meta_info: Optional[Dict[str, Any]] = None
    ) -> IPWhitelist:
        """Add IP address or CIDR range to whitelist.

        Args:
            session: Database session used for the lookup and the insert.
            ip_address: Single address, or a network in ``addr/prefix`` form.
            created_by: Identifier recorded as the creator of the entry.
            description: Optional free-text description.
            environment: Environment the entry applies to.
            expires_at: Optional expiry timestamp.
            is_cidr: Force CIDR parsing; also implied by a ``/`` in the input.
            meta_info: Optional JSON-serialisable metadata.

        Returns:
            The newly persisted entry.

        Raises:
            ValueError: If the input is not a valid address/network, or an
                entry with the same base address already exists in
                *environment*.
        """
        try:
            # Parse and validate IP/CIDR
            if is_cidr or "/" in ip_address:
                # strict=False accepts host bits and reduces to the base.
                network = ipaddress.ip_network(ip_address, strict=False)
                ip_str = str(network.network_address)
                cidr_prefix = network.prefixlen
                is_cidr = True
            else:
                # Validate single IP
                ip_obj = ipaddress.ip_address(ip_address)
                ip_str = str(ip_obj)
                cidr_prefix = None
                is_cidr = False
        except ValueError as e:
            logger.error(f"Invalid IP address format: {ip_address}")
            raise ValueError(f"Invalid IP address format: {ip_address}") from e

        # Check if already exists. The lookup keys on the base address only,
        # so two networks sharing a base address count as duplicates.
        existing = await session.execute(
            select(IPWhitelist).where(
                IPWhitelist.ip_address == ip_str,
                IPWhitelist.environment == environment
            )
        )
        if existing.scalar_one_or_none():
            raise ValueError(f"IP address already whitelisted: {ip_str}")

        # Create whitelist entry
        entry = IPWhitelist(
            id=f"{environment}:{ip_str}",
            ip_address=ip_str,
            description=description,
            environment=environment,
            created_by=created_by,
            expires_at=expires_at,
            is_cidr=is_cidr,
            cidr_prefix=cidr_prefix,
            meta_info=meta_info or {}
        )

        session.add(entry)
        await session.commit()

        # Invalidate cache
        await self._invalidate_cache()

        logger.info(
            "ip_whitelist_added",
            ip=ip_str,
            environment=environment,
            is_cidr=is_cidr,
            created_by=created_by
        )

        return entry

    async def remove_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        environment: str = "production"
    ) -> bool:
        """Remove IP from whitelist.

        The input is normalised the same way ``add_ip`` stores it, so the
        original ``addr/prefix`` notation or a non-canonical IPv6 spelling
        finds the stored row.

        Returns:
            True if at least one row was deleted, False otherwise.
        """
        ip_str = _normalize_ip(ip_address)

        result = await session.execute(
            delete(IPWhitelist).where(
                IPWhitelist.ip_address == ip_str,
                IPWhitelist.environment == environment
            )
        )
        await session.commit()

        if result.rowcount > 0:
            await self._invalidate_cache()
            logger.info(
                "ip_whitelist_removed",
                ip=ip_str,
                environment=environment
            )
            return True

        return False

    async def check_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        environment: str = "production"
    ) -> bool:
        """Check if IP is whitelisted.

        Consults the per-address cached result first, then the in-process
        snapshot for *environment* (reloading it when stale or loaded for a
        different environment). The outcome, positive or negative, is
        cached for ``_cache_ttl`` seconds.

        Returns:
            True if the address equals an exact entry or falls inside a
            non-expired CIDR entry; False otherwise. Unparseable input can
            only match an exact entry by literal string equality.
        """
        # Parse once; the canonical text is used for the cache key and the
        # exact-match lookup so equivalent spellings share one result.
        try:
            candidate = ipaddress.ip_address(ip_address)
        except ValueError:
            candidate = None
        lookup = ip_address if candidate is None else str(candidate)

        # Check cache first
        cache_key = f"{self._cache_key_prefix}:{environment}:check:{lookup}"
        cached = await cache_service.get(cache_key)
        if cached is not None:
            return cached

        # Load whitelist if needed
        await self._ensure_cache_loaded(session, environment)
        exact_ips = self._whitelist_cache or set()
        cidr_ranges = self._cidr_cache or []

        # Exact matches first, CIDR ranges second.
        allowed = lookup in exact_ips
        if not allowed and candidate is not None:
            allowed = self._in_any_range(candidate, cidr_ranges)

        await cache_service.set(cache_key, allowed, ttl=self._cache_ttl)
        return allowed

    @staticmethod
    def _in_any_range(candidate, cidr_ranges: List[tuple]) -> bool:
        """Return True if *candidate* lies in any non-expired cached range.

        *cidr_ranges* holds ``(base_address, prefix, expires_at)`` tuples.
        Entries that cannot be parsed as a network are skipped.
        """
        now = datetime.now(timezone.utc)
        for cidr_ip, prefix, expires_at in cidr_ranges:
            if expires_at and now > expires_at:
                continue
            try:
                network = ipaddress.ip_network(f"{cidr_ip}/{prefix}")
            except ValueError:
                continue
            if candidate in network:
                return True
        return False

    async def list_ips(
        self,
        session: AsyncSession,
        environment: str = "production",
        include_expired: bool = False
    ) -> List[IPWhitelist]:
        """List all whitelisted IPs of *environment*.

        Expired entries are filtered out unless *include_expired* is True.
        Inactive entries are not filtered.
        """
        query = select(IPWhitelist).where(
            IPWhitelist.environment == environment
        )

        if not include_expired:
            now = datetime.now(timezone.utc)
            query = query.where(
                (IPWhitelist.expires_at.is_(None)) |
                (IPWhitelist.expires_at > now)
            )

        result = await session.execute(query)
        return list(result.scalars().all())

    async def update_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        environment: str = "production",
        active: Optional[bool] = None,
        description: Optional[str] = None,
        expires_at: Optional[datetime] = None,
        meta_info: Optional[Dict[str, Any]] = None
    ) -> Optional[IPWhitelist]:
        """Update whitelist entry.

        Only arguments that are not None are applied; as a consequence an
        existing ``expires_at`` cannot be cleared through this method. The
        lookup address is normalised the same way ``add_ip`` stores it.

        Returns:
            The updated entry, or None if no entry matches.
        """
        ip_str = _normalize_ip(ip_address)

        result = await session.execute(
            select(IPWhitelist).where(
                IPWhitelist.ip_address == ip_str,
                IPWhitelist.environment == environment
            )
        )
        entry = result.scalar_one_or_none()

        if not entry:
            return None

        if active is not None:
            entry.active = active
        if description is not None:
            entry.description = description
        if expires_at is not None:
            entry.expires_at = expires_at
        if meta_info is not None:
            entry.meta_info = meta_info

        await session.commit()
        await self._invalidate_cache()

        logger.info(
            "ip_whitelist_updated",
            ip=ip_str,
            environment=environment,
            active=entry.active
        )

        return entry

    async def cleanup_expired(
        self,
        session: AsyncSession,
        environment: Optional[str] = None
    ) -> int:
        """Remove expired whitelist entries.

        Args:
            session: Database session used for the delete.
            environment: Restrict the cleanup to this environment; a falsy
                value cleans every environment.

        Returns:
            Number of rows deleted.
        """
        query = delete(IPWhitelist).where(
            IPWhitelist.expires_at < datetime.now(timezone.utc)
        )

        if environment:
            query = query.where(IPWhitelist.environment == environment)

        result = await session.execute(query)
        await session.commit()

        if result.rowcount > 0:
            await self._invalidate_cache()
            logger.info(
                "ip_whitelist_cleanup",
                removed=result.rowcount,
                environment=environment
            )

        return result.rowcount

    async def _ensure_cache_loaded(
        self,
        session: AsyncSession,
        environment: str
    ) -> None:
        """Ensure the in-process snapshot holds fresh data for *environment*.

        The snapshot is reused only if it was loaded for the same
        environment less than ``_cache_ttl`` seconds ago; otherwise the
        active, non-expired entries are reloaded from the database. Only
        one environment is held at a time, so alternating environments
        reload on every switch.
        """
        # Check if cache is still valid
        if (
            self._last_cache_update
            and self._cache_environment == environment
            and (
                datetime.now(timezone.utc) - self._last_cache_update
            ).total_seconds() < self._cache_ttl
        ):
            return

        # Load from database
        now = datetime.now(timezone.utc)
        result = await session.execute(
            select(IPWhitelist).where(
                IPWhitelist.environment == environment,
                IPWhitelist.active == True,  # noqa: E712 - SQL expression
                (IPWhitelist.expires_at.is_(None)) |
                (IPWhitelist.expires_at > now)
            )
        )
        entries = result.scalars().all()

        # Separate exact IPs and CIDR ranges. Build locally and publish at
        # the end so the attributes never expose a half-filled snapshot.
        exact_ips: Set[str] = set()
        cidr_ranges: List[tuple] = []
        for entry in entries:
            if entry.is_cidr:
                cidr_ranges.append((
                    entry.ip_address,
                    entry.cidr_prefix,
                    entry.expires_at
                ))
            else:
                exact_ips.add(entry.ip_address)

        self._whitelist_cache = exact_ips
        self._cidr_cache = cidr_ranges
        self._cache_environment = environment
        self._last_cache_update = datetime.now(timezone.utc)

        logger.debug(
            "ip_whitelist_cache_loaded",
            environment=environment,
            exact_ips=len(self._whitelist_cache),
            cidr_ranges=len(self._cidr_cache)
        )

    async def _invalidate_cache(self) -> None:
        """Invalidate the in-process snapshot and all cached check results."""
        self._whitelist_cache = None
        self._cidr_cache = None
        self._last_cache_update = None
        self._cache_environment = None

        # Clear Redis cache patterns
        pattern = f"{self._cache_key_prefix}:*"
        await cache_service.delete_pattern(pattern)

    def get_default_whitelist(self) -> List[str]:
        """Get default whitelist based on environment.

        Loopback names are always included; private networks are added when
        ``settings.is_development`` is set and the example provider ranges
        when ``settings.is_production`` is set.

        Note that ``"localhost"`` is a host name, not an IP literal:
        ``add_ip`` rejects it, so ``initialize_defaults`` skips it.
        """
        # Always allow localhost
        defaults = [
            "127.0.0.1",
            "::1",
            "localhost",
        ]

        # Development environment
        if settings.is_development:
            defaults.extend([
                "10.0.0.0/8",      # Private network
                "172.16.0.0/12",   # Private network
                "192.168.0.0/16"   # Private network
            ])

        # Production environment - add known services
        if settings.is_production:
            # Vercel IPs (example - would need real ranges)
            defaults.extend([
                "76.76.21.0/24",   # Vercel edge network (example)
                "76.223.0.0/16"    # Vercel edge network (example)
            ])

            # HuggingFace Spaces IPs (example - would need real ranges)
            defaults.extend([
                "34.0.0.0/8",      # Google Cloud (where HF runs)
                "35.0.0.0/8"       # Google Cloud
            ])

            # Monitoring services
            defaults.extend([
                "52.0.0.0/8"       # AWS (for monitoring)
            ])

        return defaults

    async def initialize_defaults(
        self,
        session: AsyncSession,
        created_by: str = "system"
    ) -> int:
        """Initialize default whitelist entries.

        Each default is added for ``settings.app_env``. Defaults that
        already exist or are not valid IP/CIDR literals are skipped and
        logged at debug level.

        Returns:
            Number of entries actually created.
        """
        defaults = self.get_default_whitelist()
        count = 0

        for ip in defaults:
            try:
                await self.add_ip(
                    session=session,
                    ip_address=ip,
                    created_by=created_by,
                    description="Default whitelist entry",
                    environment=settings.app_env,
                    is_cidr="/" in ip
                )
            except ValueError as exc:
                # Already exists or invalid
                logger.debug(
                    "ip_whitelist_default_skipped",
                    ip=ip,
                    reason=str(exc)
                )
                continue
            count += 1

        logger.info(
            "ip_whitelist_defaults_initialized",
            count=count,
            environment=settings.app_env
        )

        return count


# Global instance
ip_whitelist_service = IPWhitelistService()