#!/usr/bin/env python3
"""
EQUINOX Sensor Agent — lightweight single-file sensor.

Installs on Windows / Linux / macOS endpoints to register with the
EQUINOX security platform and stream telemetry.

Usage:
    python3 equinox-sensor.py              # Interactive setup
    python3 equinox-sensor.py --register   # Force re-registration
    python3 equinox-sensor.py --help       # Full usage
"""

import json
import logging
import os
import platform as plat
import signal
import socket
import subprocess
import sys
import time
import uuid
from pathlib import Path
from typing import Any, Dict, Optional

try:
    import psutil
except ImportError:
    psutil = None  # type: ignore

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Per-user state directory: holds the JSON config (including the tenant API
# key) and the agent's log file.
CONFIG_DIR = Path.home().joinpath(".equinox")
CONFIG_PATH = CONFIG_DIR.joinpath("config.json")
LOG_PATH = CONFIG_DIR.joinpath("sensor.log")
DEFAULT_API_URL = "https://equinoxsec.com"
CHECK_INTERVAL = 60  # default seconds between heartbeat/telemetry cycles

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
# Module logger used by every function below; its records propagate to the
# root handlers that setup_logging() installs via logging.basicConfig().
logger = logging.getLogger(name="equinox-sensor")

def setup_logging() -> None:
    """Configure root logging to write both to LOG_PATH and to stdout.

    The agent runs indefinitely and logs every API call, so the file handler
    rotates at ~5 MB and keeps three backups instead of growing without bound.
    UTF-8 is forced because several log messages contain non-ASCII characters
    (e.g. the em dash in the shutdown and registration messages). A
    locale-dependent default encoding may not be able to represent them.
    The state directory is created owner-only because it also holds the
    config file with the tenant API key.
    """
    from logging.handlers import RotatingFileHandler

    CONFIG_DIR.mkdir(mode=0o700, parents=True, exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            RotatingFileHandler(
                str(LOG_PATH),
                maxBytes=5 * 1024 * 1024,
                backupCount=3,
                encoding="utf-8",
            ),
            logging.StreamHandler(sys.stdout),
        ],
    )

# ---------------------------------------------------------------------------
# Config management
# ---------------------------------------------------------------------------

def load_config() -> Dict[str, Any]:
    """Load the saved sensor configuration from CONFIG_PATH.

    Returns:
        The parsed config dict, or {} when no config file exists.

    A file that is not valid UTF-8 JSON, or whose top level is not a JSON
    object, is also treated as missing (with a warning). Callers then fall
    back to interactive setup instead of crashing on every start. Raising
    here would also defeat ``--register``, because main() loads the config
    before it decides to prompt.
    """
    if not CONFIG_PATH.exists():
        return {}
    try:
        cfg = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        logger.warning("Config %s is unreadable (%s); ignoring it.", CONFIG_PATH, exc)
        return {}
    if not isinstance(cfg, dict):
        logger.warning("Config %s is not a JSON object; ignoring it.", CONFIG_PATH)
        return {}
    return cfg

def save_config(cfg: Dict[str, Any]) -> None:
    """Persist *cfg* to CONFIG_PATH as indented JSON, readable only by the owner.

    The config contains the tenant API key. The JSON is therefore written to a
    temporary file that ``tempfile.mkstemp`` creates owner-only (0600) from the
    start, and that file is then atomically moved over CONFIG_PATH with
    ``os.replace``. This closes the window in which a plain write followed by
    chmod leaves the key readable under the default umask. It also guarantees
    that a crash mid-write never leaves a truncated config for load_config()
    to trip over.

    Args:
        cfg: JSON-serialisable configuration mapping.
    """
    import tempfile

    text = json.dumps(cfg, indent=2)  # serialise first: fail before touching disk
    CONFIG_DIR.mkdir(mode=0o700, parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=str(CONFIG_DIR), prefix=".config-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(text)
        os.replace(tmp_name, CONFIG_PATH)
    except BaseException:
        # Best-effort removal of the temp file (it contains the key), then re-raise.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
    # Belt-and-braces: enforce owner-only permissions on the final path too.
    CONFIG_PATH.chmod(0o600)

def prompt_config() -> Dict[str, Any]:
    """Interactive first-run configuration.

    Prompts for the API base URL (default DEFAULT_API_URL) and the tenant API
    key, and re-asks until each is usable. The URL must carry an explicit
    http:// or https:// scheme. Without one, urllib.request.Request cannot be
    built and every API call the agent makes would fail. Plain http:// is
    accepted but warned about, because the API key is sent with every request.

    Returns:
        A new config dict with a freshly generated ``endpoint_id`` and
        ``registered`` set to False. The caller is responsible for saving it.
    """
    print("\n=== EQUINOX Sensor Agent Setup ===\n")

    while True:
        api_url = input(f"EQUINOX API URL [{DEFAULT_API_URL}]: ").strip() or DEFAULT_API_URL
        api_url = api_url.rstrip("/")
        # Checked after stripping "/" so a bare "https://" (no host) is rejected.
        if api_url.lower().startswith(("https://", "http://")):
            break
        print("  URL must start with https:// (or http://).")
    if api_url.lower().startswith("http://"):
        print("  WARNING: http:// sends the API key unencrypted; prefer https://.")

    api_key = ""
    while not api_key:
        api_key = input("Tenant API Key (from Equinox dashboard): ").strip()
        if not api_key:
            print("  API Key is required.")

    return {
        "api_url": api_url,
        "api_key": api_key,
        "check_interval": CHECK_INTERVAL,
        "endpoint_id": str(uuid.uuid4()),
        "registered": False,
    }

# ---------------------------------------------------------------------------
# System info collection
# ---------------------------------------------------------------------------

def get_hostname() -> str:
    """Return this machine's host name as reported by the OS socket layer."""
    name = socket.gethostname()
    return name

def get_ip_address() -> str:
    """Return this host's outbound IPv4 address, or "0.0.0.0" if it cannot be found.

    A UDP socket is "connected" to a non-routable private address. For UDP this
    only selects a default destination, so the OS chooses the local address it
    would route through, and getsockname() reports it. The socket is managed by
    ``with`` so it is closed even when connect() fails, e.g. on a host with no
    IPv4 route. Previously it leaked on that path.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(2)
            sock.connect(("10.255.255.255", 1))
            return sock.getsockname()[0]
    except OSError:
        # Covers socket errors and timeouts; best-effort fallback value.
        return "0.0.0.0"

def get_os_info() -> str:
    """Describe the operating system as ``"<system> <release> (<version>)"``."""
    system, release, version = plat.system(), plat.release(), plat.version()
    return "{} {} ({})".format(system, release, version)

def get_cpu_usage() -> float:
    """Return CPU utilisation in percent sampled over 0.5 s, or 0.0 when psutil is absent."""
    return psutil.cpu_percent(interval=0.5) if psutil else 0.0

def get_memory_usage() -> Dict[str, Any]:
    """Summarise virtual memory as GiB totals (2 dp) plus percent used.

    Returns all-zero figures when psutil is not installed.
    """
    if not psutil:
        return {"total_gb": 0, "used_gb": 0, "percent": 0}
    vm = psutil.virtual_memory()
    gib = 1024 ** 3
    return {
        "total_gb": round(vm.total / gib, 2),
        "used_gb": round(vm.used / gib, 2),
        "percent": vm.percent,
    }

def get_disk_usage() -> Dict[str, Any]:
    """Summarise usage of the "/" filesystem as GiB totals (2 dp) plus percent used.

    Returns all-zero figures when psutil is not installed.
    """
    if not psutil:
        return {"total_gb": 0, "used_gb": 0, "percent": 0}
    usage = psutil.disk_usage("/")
    gib = 1024 ** 3
    return {
        "total_gb": round(usage.total / gib, 2),
        "used_gb": round(usage.used / gib, 2),
        "percent": usage.percent,
    }

def get_top_processes(n: int = 20) -> list:
    """Return up to *n* process info dicts, highest CPU percent first.

    Each dict carries pid, name, cpu_percent, memory_percent, username and
    status. Missing CPU/memory figures are normalised to 0.0 so the sort key
    is always numeric. Returns [] when psutil is not installed.
    """
    if not psutil:
        return []
    wanted = ["pid", "name", "cpu_percent", "memory_percent", "username", "status"]
    rows = []
    for proc in psutil.process_iter(wanted):
        try:
            row = proc.info
            for key in ("cpu_percent", "memory_percent"):
                row[key] = row.get(key) or 0.0
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
        rows.append(row)
    return sorted(rows, key=lambda r: r["cpu_percent"], reverse=True)[:n]

def collect_system_info() -> Dict[str, Any]:
    """Gather a full telemetry snapshot: identity, resource usage, top processes.

    The collectors are called in a fixed order and the capture time (epoch
    seconds) is recorded last.
    """
    snapshot: Dict[str, Any] = {}
    snapshot["hostname"] = get_hostname()
    snapshot["ip_address"] = get_ip_address()
    snapshot["os"] = get_os_info()
    snapshot["cpu_percent"] = get_cpu_usage()
    snapshot["memory"] = get_memory_usage()
    snapshot["disk"] = get_disk_usage()
    snapshot["processes"] = get_top_processes()
    snapshot["timestamp"] = time.time()
    return snapshot

# ---------------------------------------------------------------------------
# API calls
# ---------------------------------------------------------------------------

def api_post(url: str, api_key: str, endpoint: str, payload: dict, timeout: int = 15) -> Optional[dict]:
    """POST *payload* as JSON to ``{url}/api/v1/sensor/{endpoint}``.

    The tenant key is sent in the ``X-API-Key`` header.

    Args:
        url: Base API URL, without a trailing slash.
        api_key: Tenant API key.
        endpoint: Sensor endpoint name appended to ``/api/v1/sensor/``.
        payload: JSON-serialisable request body.
        timeout: Timeout in seconds for blocking socket operations.

    Returns:
        The decoded JSON response ({} for an empty body) on success, or None
        on any failure. Every URL, transport, HTTP and decoding failure is
        logged and turned into None rather than raised. Callers run inside the
        agent's main loop, where an escaping exception would stop the agent.
    """
    import http.client
    import urllib.error
    import urllib.request

    full_url = f"{url}/api/v1/sensor/{endpoint}"
    data = json.dumps(payload).encode("utf-8")
    try:
        # Request() itself raises ValueError for a malformed or scheme-less URL.
        req = urllib.request.Request(
            full_url,
            data=data,
            headers={
                "Content-Type": "application/json",
                "X-API-Key": api_key,
            },
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            body = resp.read().decode("utf-8")
            logger.info("API %s -> %s", full_url, resp.status)
    except urllib.error.HTTPError as e:
        logger.warning("API POST %s failed: HTTP %d %s", full_url, e.code, e.reason)
        return None
    except urllib.error.URLError as e:
        logger.warning("API POST %s failed: %s", full_url, e.reason)
        return None
    except (OSError, http.client.HTTPException) as e:
        # A timeout or reset while reading the body surfaces as a bare OSError
        # (e.g. socket.timeout / ConnectionResetError) or an http.client error
        # such as IncompleteRead, not as URLError.
        logger.warning("API POST %s failed: %s", full_url, e)
        return None
    except ValueError as e:
        # Invalid URL or header value, or a non-UTF-8 body (UnicodeDecodeError).
        logger.warning("API POST %s failed: %s", full_url, e)
        return None

    if not body:
        return {}
    try:
        return json.loads(body)
    except json.JSONDecodeError:
        logger.warning("API POST %s: invalid JSON response", full_url)
        return None

def register(cfg: Dict[str, Any]) -> bool:
    """Register this endpoint with the Equinox backend.

    On success, the function:
      * sets ``cfg["registered"] = True``;
      * adopts the server-returned ``endpoint_id``, keeping the locally
        generated one if the server supplies none;
      * persists *cfg* with save_config();
      * returns True.

    Returns False on failure; the main loop retries on a later cycle.
    """
    logger.info("Registering with %s ...", cfg["api_url"])
    # Only identity fields are sent, so gather them directly. collect_system_info()
    # would also sample CPU for 0.5 s and walk every process for nothing.
    payload = {
        "agent_secret": cfg["api_key"],
        "hostname": get_hostname(),
        "ip_address": get_ip_address(),
        "os": get_os_info(),
        "agent_version": "1.0.0",
    }
    resp = api_post(cfg["api_url"], cfg["api_key"], "register", payload)
    # api_post returns None on failure and {} for a successful empty body.
    # Test against None, as the heartbeat/telemetry senders do: a truthiness
    # check would treat an empty 2xx response as a failed registration.
    if resp is None:
        logger.warning("Registration failed — will retry next cycle.")
        return False

    # Guard against a non-object JSON body or an explicit null endpoint_id.
    server_id = resp.get("endpoint_id") if isinstance(resp, dict) else None
    cfg["registered"] = True
    cfg["endpoint_id"] = server_id or cfg.get("endpoint_id")
    save_config(cfg)
    logger.info("Registration successful. Endpoint ID: %s", cfg["endpoint_id"])
    return True

def send_heartbeat(cfg: Dict[str, Any]) -> bool:
    """Post a keep-alive with the current IP address; True if the API call succeeded."""
    body = dict(
        endpoint_id=cfg["endpoint_id"],
        agent_secret=cfg["api_key"],
        ip_address=get_ip_address(),
        agent_version="1.0.0",
    )
    outcome = api_post(cfg["api_url"], cfg["api_key"], "heartbeat", body)
    return outcome is not None

def send_offline(cfg: Dict[str, Any]) -> None:
    """Best-effort notice to the backend that this endpoint is going offline.

    Called on shutdown. It posts ``status: "offline"`` to the heartbeat
    endpoint and logs whether the call succeeded; it no longer reports success
    unconditionally. A failure is only logged because the agent is exiting
    anyway.
    """
    payload = {
        "endpoint_id": cfg["endpoint_id"],
        "agent_secret": cfg["api_key"],
        "status": "offline",
    }
    resp = api_post(cfg["api_url"], cfg["api_key"], "heartbeat", payload)
    if resp is None:
        logger.warning("Failed to send offline status.")
    else:
        logger.info("Sent offline status.")

def send_telemetry(cfg: Dict[str, Any]) -> bool:
    """Collect a system snapshot and upload it; True if the API call succeeded."""
    snapshot = collect_system_info()
    body = {
        "endpoint_id": cfg["endpoint_id"],
        "agent_secret": cfg["api_key"],
        "telemetry": snapshot,
    }
    outcome = api_post(cfg["api_url"], cfg["api_key"], "telemetry", body)
    return outcome is not None

# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------

# Stop flag: raised by handle_signal(), polled by main()'s run loop.
_shutdown: bool = False

def handle_signal(signum: int, frame) -> None:
    """Request a graceful shutdown in response to SIGTERM/SIGINT.

    Only logs and sets the module-level ``_shutdown`` flag. The main loop
    notices the flag and performs the actual shutdown.
    """
    global _shutdown
    logger.info("Signal %d received — shutting down...", signum)
    _shutdown = True

def _sleep_until(deadline: float) -> None:
    """Sleep until time.monotonic() reaches *deadline* or a shutdown is requested.

    Sleeps in slices of at most one second so the ``_shutdown`` flag set by
    handle_signal() is noticed promptly, without cutting the cycle short.
    """
    while not _shutdown:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return
        time.sleep(min(remaining, 1.0))


def main() -> None:
    """Run the sensor agent: parse flags, load/create config, register, then loop.

    Each cycle sends a heartbeat and telemetry when registered, or retries
    registration otherwise, then waits for the rest of ``check_interval``.
    Previously the loop slept at most one second per iteration and so re-ran
    roughly every 1.5 s regardless of the interval. On SIGTERM/SIGINT the loop
    exits and, when registered, an offline status is sent.
    """
    # Handle --help first so it has no filesystem side effects.
    if "--help" in sys.argv or "-h" in sys.argv:
        print(__doc__)
        sys.exit(0)
    force_register = "--register" in sys.argv

    setup_logging()
    logger.info("EQUINOX Sensor Agent starting...")

    # Load or create config
    cfg = load_config()
    if not cfg or force_register:
        cfg = prompt_config()
        save_config(cfg)

    # Register signals
    signal.signal(signal.SIGTERM, handle_signal)
    signal.signal(signal.SIGINT, handle_signal)

    # Register with backend if not already
    if not cfg.get("registered") or force_register:
        if not register(cfg):
            logger.warning("Initial registration failed, continuing in offline mode.")

    interval = cfg.get("check_interval", CHECK_INTERVAL)
    # A zero/negative or non-numeric interval would busy-loop or crash; fall back.
    if isinstance(interval, bool) or not isinstance(interval, (int, float)) or interval <= 0:
        logger.warning("Invalid check_interval %r in config; using %d seconds.",
                       interval, CHECK_INTERVAL)
        interval = CHECK_INTERVAL

    logger.info("Agent running. Check interval: %d seconds", interval)
    while not _shutdown:
        # Monotonic clock: wall-clock jumps must not stretch or skip a cycle.
        cycle_start = time.monotonic()
        try:
            if cfg.get("registered"):
                send_heartbeat(cfg)
                send_telemetry(cfg)
            elif not _shutdown:
                logger.info("Not registered — attempting registration...")
                register(cfg)
        except Exception:
            # Top-level daemon boundary: log and keep running rather than die.
            logger.exception("Unexpected error during sensor cycle; retrying next cycle.")
        _sleep_until(cycle_start + interval)

    # Clean shutdown
    if cfg.get("registered"):
        send_offline(cfg)
    logger.info("EQUINOX Sensor Agent stopped.")


if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0.
    sys.exit(main())
