Files
hosts_daemon/app.py
2025-09-05 09:18:32 +02:00

371 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import os, time, logging, platform, psutil, configparser, subprocess, shutil
from pathlib import Path
from flask import Flask, request, jsonify, abort
from flask_sslify import SSLify
from datetime import datetime, timezone
# Flask application object; SSLify is attached to it
# (presumably to push plain-HTTP clients to HTTPS — see flask_sslify).
app = Flask(__name__)
sslify = SSLify(app)

# --- PATHS / ENVIRONMENT ---
# Every location can be overridden through an environment variable. The
# fallbacks are relative paths, i.e. resolved against the working directory.
LOG_DIR = os.environ.get("HOSTS_DAEMON_LOG_DIR", "logs")
TOKEN_FILE_PATH = os.environ.get("HOSTS_DAEMON_TOKEN_FILE", "daemon_token.txt")
CONFIG_PATH = os.environ.get("HOSTS_DAEMON_CONFIG", "config.ini")

# The log directory must exist before the file handler below opens daemon.log.
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, "daemon.log")
# --- LOGGING ---
# Records go both to <LOG_DIR>/daemon.log and to stderr (StreamHandler default).
# The file handler is opened as UTF-8 explicitly: many messages in this daemon
# are Polish (non-ASCII), and the locale encoding of a service environment may
# not be able to encode them.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(LOG_FILE, encoding="utf-8"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger("hosts_daemon")
# --- CONFIG ---
# Optional INI file. ConfigParser.read() silently skips files it cannot open and
# returns the list of files it actually parsed, so an empty list means "no
# configuration": every getCfg() lookup then falls back to its default.
# NOTE(review): read() uses the locale encoding and raises on malformed INI
# syntax, so a broken config file still aborts startup.
cfg = configparser.ConfigParser()
read_ok = cfg.read(CONFIG_PATH)  # names of the files that were parsed
if not read_ok:
    logger.warning(f"Brak pliku konfiguracyjnego {CONFIG_PATH} lub nieczytelny.")
# --- METRICS ---
# In-process counters updated by the request hooks and the /hosts handlers and
# reported by GET /metrics. "endpoints" maps an endpoint key to a
# {"count": int, "total_time": float} record.
metrics = dict(
    total_requests=0,
    total_time=0.0,  # sum of request durations, in seconds
    endpoints={},
    hosts_get=0,
    hosts_post=0,
)
# ------------------
# HELPER FUNCTIONS
# ------------------
def getCfg(key: str, default=None):
    """Return option *key* from the [daemon] section, or *default*.

    *default* is returned when either the section or the option is missing.
    """
    return cfg.get("daemon", key, fallback=default)


def _parseReloadTimeout(raw, default: int = 5) -> int:
    """Turn the raw ``reload_timeout`` option into a positive number of seconds.

    A missing or blank value yields *default*. An unparseable or non-positive
    value is logged as a warning and also yields *default*. Previously an
    invalid value aborted the daemon at import time. A value <= 0 would make
    subprocess.run() time out every reload command immediately.
    """
    if raw is None or str(raw).strip() == "":
        return default
    try:
        value = int(raw)
    except (TypeError, ValueError):
        logger.warning(f"Invalid reload_timeout={raw!r} in [daemon]; using {default}s.")
        return default
    if value <= 0:
        logger.warning(f"reload_timeout must be > 0 (got {value}); using {default}s.")
        return default
    return value


# Timeout (seconds) applied to each reload/restart command run by runCmd().
RELOAD_TIMEOUT = _parseReloadTimeout(getCfg("reload_timeout", "5"))
def readTokenFromFile(path: str):
    """Read the API token stored in *path*.

    Returns the file content with surrounding whitespace removed. Returns None
    in any of these cases:
    - *path* is not a regular file;
    - the content is blank;
    - reading or UTF-8 decoding fails (the failure is logged as an error).
    """
    token_path = Path(path)
    try:
        if not token_path.is_file():
            return None
        token = token_path.read_text(encoding="utf-8").strip()
    # PermissionError and IsADirectoryError are both subclasses of OSError.
    except (OSError, UnicodeDecodeError) as exc:
        logger.error(f"Nie udało się odczytać pliku tokenu '{path}': {exc}")
        return None
    return token or None
def listServices():
    """Return the service names from the comma-separated [daemon] ``services`` option.

    Whitespace around each name is stripped and empty entries are dropped.
    The result is an empty list when the option is absent.
    """
    names = []
    for chunk in getCfg("services", "").split(","):
        name = chunk.strip()
        if name:
            names.append(name)
    return names
def serviceCommand(name: str):
    """Return the shell command used to reload service *name*.

    An explicit ``command`` option in a ``[service:<name>]`` section wins.
    Otherwise the default ``systemctl reload <name>`` is used.
    NOTE(review): the value goes through ConfigParser's default interpolation,
    so a literal '%' in a command must be written as '%%'.
    """
    section = f"service:{name}"
    has_override = cfg.has_section(section) and cfg.has_option(section, "command")
    return cfg.get(section, "command") if has_override else f"systemctl reload {name}"
def runCmd(cmd: str):
    """Run *cmd* through the shell, bounded by RELOAD_TIMEOUT seconds.

    The command string is logged before execution. Returns a tuple
    ``(returncode, stdout, stderr)`` with both streams stripped. If the
    timeout expires, returns ``(124, "", "Timeout")`` instead.
    """
    logger.info(f"Exec: {cmd}")
    try:
        proc = subprocess.run(
            cmd,
            shell=True,  # commands come from the config file and may use shell syntax
            capture_output=True,
            text=True,
            timeout=RELOAD_TIMEOUT,
        )
    except subprocess.TimeoutExpired:
        return 124, "", "Timeout"
    return proc.returncode, (proc.stdout or "").strip(), (proc.stderr or "").strip()
def reloadService(name: str):
    """Reload one service and describe the outcome.

    Runs serviceCommand(name). If that fails and the command text starts with
    ``systemctl reload`` (the default form), ``systemctl restart <name>`` is
    tried instead.

    Returns a dict with these keys:
    - service
    - action: one of "reload", "restart", "reload->restart_failed", "custom_failed"
    - rc, stdout, stderr: from the last command executed; for the restart
      fallback, stderr falls back to the reload's stderr when the restart
      produced none.
    """
    command = serviceCommand(name)
    rc, out, err = runCmd(command)
    if rc == 0:
        action = "reload"
    elif not command.startswith("systemctl reload"):
        action = "custom_failed"
    else:
        rc, restart_out, restart_err = runCmd(f"systemctl restart {name}")
        action = "restart" if rc == 0 else "reload->restart_failed"
        out, err = restart_out, restart_err or err
    return {"service": name, "action": action, "rc": rc, "stdout": out, "stderr": err}
def reloadServices():
    """Reload every configured service, in configuration order.

    Each outcome is logged at INFO; a non-empty stderr is logged at DEBUG.
    Returns the list of reloadService() result dicts, or an empty list when no
    services are configured.
    """
    names = listServices()
    if not names:
        logger.info("Brak skonfigurowanych serwisów do przeładowania.")
        return []

    def _reload_and_log(svc):
        # Reload a single service and log its outcome.
        outcome = reloadService(svc)
        logger.info(f"Reload {svc}: action={outcome['action']} rc={outcome['rc']}")
        if outcome["stderr"]:
            logger.debug(f"{svc} stderr: {outcome['stderr']}")
        return outcome

    return [_reload_and_log(svc) for svc in names]
def maskToken(token: str | None) -> str:
if not token:
return ""
if len(token) <= 8:
return "*" * len(token)
return token[:4] + "*" * (len(token) - 8) + token[-4:]
# ------------------------------------------------------
# LOAD THE API TOKEN
# Precedence: token file (HOSTS_DAEMON_TOKEN_FILE) -> environment variable
# HOSTS_DAEMON_API_TOKEN -> built-in default.
# ------------------------------------------------------
file_token = readTokenFromFile(TOKEN_FILE_PATH)
if file_token:
    API_TOKEN = file_token
    logger.info(f"API_TOKEN wczytany z pliku: {TOKEN_FILE_PATH}")
else:
    env_token = os.environ.get("HOSTS_DAEMON_API_TOKEN")
    if env_token:
        API_TOKEN = env_token
        logger.info("API_TOKEN wczytany ze zmiennej środowiskowej HOSTS_DAEMON_API_TOKEN.")
    else:
        # The well-known built-in token is kept only for backward compatibility
        # with existing clients. Running on it is insecure, so warn loudly, and
        # never echo the secret value into the logs.
        API_TOKEN = "superSecretTokenABC123"
        logger.warning(
            "API_TOKEN ustawiony na wbudowaną wartość domyślną - to NIEBEZPIECZNE; "
            "ustaw plik tokenu lub HOSTS_DAEMON_API_TOKEN."
        )
def require_auth():
    """Abort with HTTP 401 unless the ``Authorization`` header equals API_TOKEN.

    The raw header value is compared; there is no ``Bearer`` prefix handling.
    The comparison uses hmac.compare_digest (constant time), so response timing
    does not reveal how much of a guessed token was correct. Both sides are
    encoded with ``surrogateescape`` so that tokens read from the environment
    cannot make the encoding step raise.
    """
    import hmac

    token = request.headers.get("Authorization")
    logger.info(f"requireAuth() -> Nagłówek Authorization: {maskToken(token)}")
    if token is None or not hmac.compare_digest(
        token.encode("utf-8", "surrogateescape"),
        API_TOKEN.encode("utf-8", "surrogateescape"),
    ):
        logger.warning("Nieprawidłowy token w nagłówku Authorization. Oczekiwano innego ciągu znaków.")
        abort(401, description="Unauthorized")
def validate_hosts_syntax(hosts_content):
    """Validate hosts-file text.

    Returns None when the text is valid. Otherwise returns an error message
    (Polish, with a 1-based line number) for the first problem found.

    Rules:
    - Blank lines are skipped. Everything from '#' to the end of a line is a
      comment (as in hosts(5)), so trailing comments are never mistaken for
      hostnames.
    - Every remaining line must be an IP address followed by at least one
      hostname.
    - The address must parse with ipaddress.ip_address (IPv4 or IPv6).
    - Each (address, hostname) pair may appear only once in the whole text.
    """
    import ipaddress

    seen = set()
    for i, line in enumerate(hosts_content.splitlines(), start=1):
        # Drop the comment part first. Otherwise '#' and the comment words would
        # be checked as hostnames, and lines such as "1.2.3.4 a # x" and
        # "1.2.3.4 b # x" would be reported as duplicates.
        line_strip = line.split('#', 1)[0].strip()
        if not line_strip:
            continue
        ip_addr, *hostnames = line_strip.split()
        if not hostnames:
            return f"Linia {i}: Za mało elementów, wymagane IP oraz co najmniej jeden hostname."
        try:
            ipaddress.ip_address(ip_addr)
        except ValueError:
            return f"Linia {i}: '{ip_addr}' nie jest poprawnym adresem IP"
        for hn in hostnames:
            key = (ip_addr, hn)
            if key in seen:
                return f"Linia {i}: duplikat wpisu {ip_addr} -> {hn}"
            seen.add(key)
    return None
def writeHostsAtomic(new_content: str, path: str = "/etc/hosts", backup_dir: str | None = None) -> dict:
"""
Zapisuje plik atomowo:
- tworzy kopię zapasową (z timestampem) jeśli backup_dir podane,
- zapis do pliku tymczasowego + fsync + rename().
Zwraca info o backupie i ścieżkach.
"""
from tempfile import NamedTemporaryFile
info = {"path": path, "backup": None}
# kopia zapasowa
if backup_dir:
os.makedirs(backup_dir, exist_ok=True)
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
backup_path = os.path.join(backup_dir, f"hosts.{ts}.bak")
shutil.copy2(path, backup_path)
info["backup"] = backup_path
# zapis atomowy
dir_name = os.path.dirname(path) or "."
with NamedTemporaryFile("w", dir=dir_name, delete=False, encoding="utf-8") as tmp:
tmp.write(new_content)
tmp.flush()
os.fsync(tmp.fileno())
tmp_name = tmp.name
os.replace(tmp_name, path)
return info
def computeUnifiedDiff(old_text: str, new_text: str, fromfile="/etc/hosts(old)", tofile="/etc/hosts(new)") -> str:
    """Return a unified diff (3 context lines) of *old_text* -> *new_text* as one string.

    The result is an empty string when the texts are identical.
    *fromfile* and *tofile* label the ``---`` and ``+++`` header lines.
    """
    from difflib import unified_diff

    before = old_text.splitlines(keepends=True)
    after = new_text.splitlines(keepends=True)
    hunks = unified_diff(before, after, fromfile=fromfile, tofile=tofile, n=3)
    return "".join(hunks)
# ------------------
# LOGGING / METRICS HOOKS
# ------------------
def _metricsEndpointKey():
    """Return the key under which the current request is counted in metrics["endpoints"].

    Requests that matched a route are keyed by the route's rule string. For the
    static routes registered in this file, that string is identical to
    request.path. Requests that matched no route (404/405) share one
    "<unmatched>" bucket, so clients probing arbitrary URLs cannot make
    metrics["endpoints"] grow without bound.
    """
    rule = request.url_rule
    return rule.rule if rule is not None else "<unmatched>"


@app.before_request
def before_request_logging():
    """Record the start time, log the request (token masked) and count it in the metrics."""
    # perf_counter is monotonic: durations stay correct if the wall clock is adjusted.
    request.start_time = time.perf_counter()
    client_ip = request.remote_addr
    endpoint = request.path
    logger.info(f"Request from {client_ip} to {endpoint} [{request.method}], Auth: {maskToken(request.headers.get('Authorization'))}")
    metrics["total_requests"] += 1
    ep_stats = metrics["endpoints"].setdefault(_metricsEndpointKey(), {"count": 0, "total_time": 0.0})
    ep_stats["count"] += 1


@app.after_request
def after_request_logging(response):
    """Add the request duration to the metrics, log completion and return *response* unchanged."""
    endpoint = request.path
    start = getattr(request, "start_time", None)
    if start is None:
        # before_request_logging did not run for this request: an earlier
        # before_request handler returned a response, which makes Flask skip the
        # remaining ones. There is no duration to record, and accessing
        # request.start_time directly would raise AttributeError here.
        logger.info(f"Completed {endpoint} with status {response.status_code} (no timing)")
        return response
    elapsed = time.perf_counter() - start
    metrics["total_time"] += elapsed
    ep_stats = metrics["endpoints"].get(_metricsEndpointKey())
    if ep_stats is not None:
        ep_stats["total_time"] += elapsed
    logger.info(f"Completed {endpoint} in {elapsed:.3f} sec with status {response.status_code}")
    return response
# ------------------
# ENDPOINTS
# ------------------
@app.route('/', methods=['GET'])
def root_index():
    """Unauthenticated landing route: confirms the daemon is up and points to /health and /hosts."""
    message = "hosts_daemon is running. Try /health or /hosts"
    return jsonify(info=message), 200
@app.route('/hosts', methods=['GET'])
def get_hosts():
    """Return the current /etc/hosts as ``{"hosts": <text>}`` (requires the API token).

    The file is decoded as UTF-8, the same encoding POST /hosts reads and writes
    it with, instead of depending on the process locale. Any read failure is
    logged and answered with HTTP 500 and ``{"error": <message>}``.
    """
    require_auth()
    metrics["hosts_get"] += 1
    try:
        with open('/etc/hosts', 'r', encoding='utf-8') as f:
            content = f.read()
    except Exception as e:  # request boundary: report any failure as JSON 500
        logger.error(f"/hosts GET error: {str(e)}")
        return jsonify({"error": str(e)}), 500
    logger.info(f"/hosts GET successful from {request.remote_addr}")
    return jsonify({"hosts": content})
@app.route('/hosts', methods=['POST'])
def updateHosts():
    """Replace /etc/hosts with the ``hosts`` string from the JSON body, then reload services.

    Flow:
    1. Authenticate.
    2. Validate the body shape and the hosts syntax (HTTP 400 with
       ``{"error": ...}`` on failure).
    3. Write the file atomically, with a backup under <LOG_DIR>/backups.
    4. Log a unified diff of the change.
    5. Reload the configured services.

    Success returns ``{"message", "backup", "reload"}``. A failure while
    writing or reloading returns HTTP 500 with ``{"error": ...}``.
    """
    # The helpers in this module are require_auth / validate_hosts_syntax; the
    # camelCase names previously used here do not exist and raised NameError.
    require_auth()
    metrics["hosts_post"] += 1
    # silent=True: a malformed or non-JSON body yields None and gets the JSON 400
    # below, instead of Flask raising its own error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "hosts" not in data:
        logger.warning(f"/hosts POST: missing 'hosts' key from {request.remote_addr}")
        return jsonify({"error": "Invalid request, missing 'hosts' key"}), 400
    newContent = data["hosts"]
    if not isinstance(newContent, str):
        logger.warning(f"/hosts POST: 'hosts' is not a string, from {request.remote_addr}")
        return jsonify({"error": "Invalid request, 'hosts' must be a string"}), 400
    errorMsg = validate_hosts_syntax(newContent)
    if errorMsg:
        logger.error(f"/hosts POST validation error: {errorMsg}")
        return jsonify({"error": errorMsg}), 400
    try:
        # The previous content is only needed for the diff in the log; failing to
        # read it is not fatal.
        try:
            with open('/etc/hosts', 'r', encoding='utf-8') as f:
                oldContent = f.read()
        except (OSError, UnicodeDecodeError):
            oldContent = ""
        writeInfo = writeHostsAtomic(newContent, "/etc/hosts", backup_dir=os.path.join(LOG_DIR, "backups"))
        logger.info(f"/etc/hosts zapisano atomowo. backup={writeInfo['backup']}")
        if oldContent:
            diff = computeUnifiedDiff(oldContent, newContent)
            if diff:
                logger.info("Diff /etc/hosts:\n" + diff)
        reloadResults = reloadServices()
        return jsonify({
            "message": "File updated successfully",
            "backup": writeInfo["backup"],
            "reload": reloadResults,
        })
    except Exception as e:  # request boundary: report any failure as JSON 500
        logger.error(f"/hosts POST error: {str(e)}")
        return jsonify({"error": str(e)}), 500
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe; no token required.

    Returns a JSON object with:
    - status: "ok";
    - time: the current UTC time in ISO-8601;
    - uptime: the *system* uptime (seconds since psutil.boot_time(), not the
      daemon's own uptime), formatted as "<seconds> seconds" with one decimal.
    """
    since_boot = time.time() - psutil.boot_time()
    now_iso = datetime.now(timezone.utc).isoformat()
    logger.info(f"/health check from {request.remote_addr}")
    return jsonify(status="ok", time=now_iso, uptime=f"{since_boot:.1f} seconds"), 200
@app.route('/metrics', methods=['GET'])
def metrics_endpoint():
    """Return request counters and average response times as JSON; no token required.

    Averages are in seconds and are 0.0 when nothing has been counted yet.
    """
    total_requests = metrics["total_requests"]
    avg_time = metrics["total_time"] / total_requests if total_requests > 0 else 0.0
    # Iterate over a snapshot. The before_request hook may insert keys while this
    # runs when requests are served concurrently (Flask's server is threaded by
    # default). Iterating the live dict could then raise "dictionary changed size
    # during iteration".
    ep_data = {}
    for ep, data in list(metrics["endpoints"].items()):
        count = data["count"]
        ep_avg = data["total_time"] / count if count > 0 else 0.0
        ep_data[ep] = {"count": count, "avg_time": ep_avg}
    response_data = {
        "total_requests": total_requests,
        "avg_response_time": avg_time,
        "endpoints": ep_data,
        "hosts_get": metrics.get("hosts_get", 0),
        "hosts_post": metrics.get("hosts_post", 0),
    }
    logger.info(f"/metrics accessed by {request.remote_addr}")
    return jsonify(response_data), 200
@app.route('/system-info', methods=['GET'])
def system_info():
    """Return a JSON snapshot of host resource usage.

    Fields:
    - CPU utilisation sampled over 0.1 s (the call blocks for that long);
    - RAM totals, usage and percentage;
    - root-filesystem totals, usage and percentage;
    - the platform string;
    - seconds since system boot.

    NOTE(review): unlike /hosts this endpoint never calls require_auth(), and
    unlike /health and /metrics there is no comment saying that is
    intentional. Confirm it should be public.
    """
    memory = psutil.virtual_memory()
    cpu_load = psutil.cpu_percent(interval=0.1)
    root_fs = psutil.disk_usage('/')
    info = {
        "cpu_percent": cpu_load,
        "memory_total": memory.total,
        "memory_used": memory.used,
        "memory_percent": memory.percent,
        "disk_total": root_fs.total,
        "disk_used": root_fs.used,
        "disk_percent": root_fs.percent,
        "platform": platform.platform(),
        "uptime_seconds": time.time() - psutil.boot_time(),
    }
    logger.info(f"/system-info accessed by {request.remote_addr}")
    return jsonify(info), 200
if __name__ == '__main__':
    # Direct execution: Flask's built-in server with TLS, on all interfaces, port 8000.
    # Certificate and key paths are relative to the working directory.
    tls_cert_and_key = ('ssl/hosts_daemon.crt', 'ssl/hosts_daemon.key')
    for startup_line in (
        "Uruchamiam hosts_daemon nasłuch na porcie 8000 (HTTPS).",
        f"LOG_DIR: {LOG_DIR}",
        f"TOKEN_FILE_PATH: {TOKEN_FILE_PATH}",
    ):
        logger.info(startup_line)
    app.run(host='0.0.0.0', port=8000, ssl_context=tls_cert_and_key)