Files
hosts_daemon/app.py
Mateusz Gruszczyński 04d5a514b4 fixy
2025-09-05 09:48:41 +02:00

403 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import os, time, logging, platform, psutil, configparser, subprocess, shutil
from pathlib import Path
from flask import Flask, request, jsonify, abort
from flask_sslify import SSLify
from datetime import datetime, timezone
# Flask application object; every route and request hook in this file is registered on it.
app = Flask(__name__)
# Wrap the app with the flask_sslify extension.
# NOTE(review): presumably this redirects plain-HTTP requests to HTTPS - confirm against flask_sslify docs.
sslify = SSLify(app)
# --- PATHS / ENVIRONMENT ---
# PATH handed to spawned commands when the config provides no usable "exec_path".
DEFAULT_EXEC_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
# Each location can be overridden through an environment variable; the defaults
# are relative paths, so they resolve against the process working directory.
LOG_DIR = os.environ.get("HOSTS_DAEMON_LOG_DIR", "logs")
TOKEN_FILE_PATH = os.environ.get("HOSTS_DAEMON_TOKEN_FILE", "daemon_token.txt")
CONFIG_PATH = os.environ.get("HOSTS_DAEMON_CONFIG", "config.ini")
# Create the log directory up front so the file log handler can open its file.
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, "daemon.log")
# --- LOGGING ---
# Root logger at INFO level, written both to LOG_FILE and to the console stream.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.FileHandler(LOG_FILE), logging.StreamHandler()]
)
# Named logger used by every function in this module.
logger = logging.getLogger("hosts_daemon")
# --- CONFIG ---
# Parsed INI configuration; a missing or unreadable file is tolerated and only
# logged, in which case every lookup falls back to its default.
cfg = configparser.ConfigParser()
# ConfigParser.read() returns the list of files it managed to parse (empty on failure).
read_ok = cfg.read(CONFIG_PATH)
if not read_ok:
    logger.warning(f"Brak pliku konfiguracyjnego {CONFIG_PATH} lub nieczytelny.")
# --- METRICS ---
# Module-level request counters, updated by the request hooks and the /hosts
# handlers and reported by the /metrics endpoint.
metrics = {
    "total_requests": 0,   # number of requests seen by the before-request hook
    "total_time": 0.0,     # summed handling time in seconds
    "endpoints": {},       # request path -> {"count": int, "total_time": float}
    "hosts_get": 0,        # authorized GET /hosts calls
    "hosts_post": 0,       # authorized POST /hosts calls
}
# ------------------
# FUNKCJE POMOCNICZE
# ------------------
def getCfg(key: str, default=None):
    """Look up ``key`` in the ``[daemon]`` config section.

    Returns ``default`` when the option (or the whole section) is absent.
    """
    section = "daemon"
    return cfg.get(section, key, fallback=default)
# Timeout passed to subprocess.run() for each reload command (seconds);
# read from the "reload_timeout" config key, falling back to 5 when the key is
# missing or empty. A non-numeric value raises ValueError at import time.
RELOAD_TIMEOUT = int(getCfg("reload_timeout", "5") or 5)
def readTokenFromFile(path: str):
    """Read an API token from the file at ``path``.

    Returns the file content with surrounding whitespace stripped, or ``None``
    when the path is not a regular file, the file is empty/blank, or it cannot
    be read or decoded as UTF-8 (the failure is logged, never raised).
    """
    token_path = Path(path)
    try:
        if token_path.is_file():
            text = token_path.read_text(encoding="utf-8").strip()
            # An empty or whitespace-only file counts as "no token".
            return text or None
    except (OSError, UnicodeDecodeError) as exc:
        # OSError already covers PermissionError and IsADirectoryError.
        logger.error(f"Nie udało się odczytać pliku tokenu '{path}': {exc}")
    return None
def listServices():
    """Return the service names listed in the comma-separated 'services' config key.

    Surrounding whitespace is removed and empty items are dropped, so a missing
    or blank key yields an empty list.
    """
    raw = getCfg("services", "")
    trimmed = (item.strip() for item in raw.split(","))
    return [item for item in trimmed if item]
def serviceCommand(name: str):
    """Return the reload command for service ``name``.

    A ``command`` option in the ``[service:<name>]`` section takes priority;
    otherwise ``systemctl reload <name>`` is used.
    """
    section = f"service:{name}"
    has_override = cfg.has_section(section) and cfg.has_option(section, "command")
    if has_override:
        return cfg.get(section, "command")
    return f"systemctl reload {name}"
def runCmd(cmd: str):
    """Run ``cmd`` through the shell and return ``(returncode, stdout, stderr)``.

    The child inherits the current environment with PATH replaced by the
    ``exec_path`` config value (or DEFAULT_EXEC_PATH when that is unset/empty).
    Output streams are captured as text and stripped. If the command exceeds
    RELOAD_TIMEOUT seconds the result is ``(124, "", "Timeout")``.
    """
    logger.info(f"Exec: {cmd}")
    child_env = dict(os.environ)
    child_env["PATH"] = getCfg("exec_path", DEFAULT_EXEC_PATH) or DEFAULT_EXEC_PATH
    try:
        proc = subprocess.run(
            cmd,
            shell=True,
            capture_output=True,
            text=True,
            timeout=RELOAD_TIMEOUT,
            env=child_env,
        )
    except subprocess.TimeoutExpired:
        # 124 is reported as the return code for a timed-out command.
        return 124, "", "Timeout"
    stdout = (proc.stdout or "").strip()
    stderr = (proc.stderr or "").strip()
    return proc.returncode, stdout, stderr
def reloadServices():
    """Reload every configured service and return the list of per-service results.

    Each result is the dict produced by reloadService(). An empty list is
    returned (and logged) when no services are configured.
    """
    services = listServices()
    if not services:
        logger.info("Brak skonfigurowanych serwisów do przeładowania.")
        return []

    outcomes = []
    for svc in services:
        outcome = reloadService(svc)
        logger.info(f"Reload {svc}: action={outcome['action']} rc={outcome['rc']}")
        stderr_text = outcome["stderr"]
        if stderr_text:
            logger.debug(f"{svc} stderr: {stderr_text}")
        outcomes.append(outcome)
    return outcomes
def maskToken(token: str | None) -> str:
if not token:
return ""
if len(token) <= 8:
return "*" * len(token)
return token[:4] + "*" * (len(token) - 8) + token[-4:]
def commandCandidates(name: str):
    """Return the ordered list of shell commands to try for reloading ``name``.

    If ``[service:<name>]`` defines ``command``, that single command is the
    only candidate. Otherwise candidates are collected, in order, for every
    tool found on PATH: systemctl, service, rc-service (reload then restart
    each), ``pkill -HUP``, ``kill -HUP`` on the configured ``pid_file`` (when
    that file exists), and finally ``kill -HUP`` on the oldest ``pgrep`` match.
    """
    section = f"service:{name}"
    if cfg.has_section(section) and cfg.has_option(section, "command"):
        return [cfg.get(section, "command")]

    # (required binary, commands contributed when it is available), in priority order.
    probes = (
        ("systemctl", [f"systemctl reload {name}", f"systemctl restart {name}"]),
        ("service", [f"service {name} reload", f"service {name} restart"]),
        ("rc-service", [f"rc-service {name} reload", f"rc-service {name} restart"]),
        ("pkill", [f"pkill -HUP {name}"]),
    )
    candidates = []
    for binary, commands in probes:
        if shutil.which(binary):
            candidates.extend(commands)

    pid_file = cfg.get(section, "pid_file", fallback=None)
    if pid_file and os.path.isfile(pid_file) and shutil.which("kill"):
        candidates.append(f"kill -HUP $(cat {pid_file})")
    if shutil.which("pgrep") and shutil.which("kill"):
        candidates.append(f"kill -HUP $(pgrep -o {name})")
    return candidates
def reloadService(name: str):
    """Try each candidate command for ``name`` until one succeeds.

    Returns a dict with keys ``service``, ``action``, ``rc``, ``stdout`` and
    ``stderr``. The first command exiting with 0 wins. Commands that are
    missing (rc 127 or "not found" in stderr) are skipped. If nothing
    succeeds, the result of the last command that actually ran is returned,
    or a synthetic ``no-cmd`` / rc 127 result when none ran at all.
    """
    last_failure = None
    for command in commandCandidates(name):
        rc, out, err = runCmd(command)
        missing = rc == 127 or "not found" in (err or "").lower()
        if missing:
            # 127 = command not found; move on to the next candidate.
            continue
        attempt = {"service": name, "action": command, "rc": rc, "stdout": out, "stderr": err}
        if rc == 0:
            return attempt
        # The command exists but failed: remember it and keep trying.
        last_failure = attempt

    if last_failure is not None:
        return last_failure
    return {"service": name, "action": "no-cmd", "rc": 127, "stdout": "", "stderr": "no candidate"}
# ------------------------------------------------------
# LOAD THE API TOKEN
# ------------------------------------------------------
# Precedence: token file, then the HOSTS_DAEMON_API_TOKEN environment variable,
# then a built-in default value.
file_token = readTokenFromFile(TOKEN_FILE_PATH)
if file_token:
    API_TOKEN = file_token
    logger.info(f"API_TOKEN wczytany z pliku: {TOKEN_FILE_PATH}")
else:
    env_token = os.environ.get("HOSTS_DAEMON_API_TOKEN")
    if env_token:
        API_TOKEN = env_token
        logger.info("API_TOKEN wczytany ze zmiennej środowiskowej HOSTS_DAEMON_API_TOKEN.")
    else:
        # The default is kept only for backward compatibility. It is a value
        # hard-coded in the source, so anyone can know it: report its use at
        # WARNING level and never write the secret itself to the log.
        API_TOKEN = "superSecretTokenABC123"
        logger.warning(
            "Brak tokenu w pliku i w zmiennej środowiskowej - używany jest domyślny, "
            "niebezpieczny API_TOKEN. Ustaw HOSTS_DAEMON_API_TOKEN lub plik tokenu."
        )
def requireAuth():
    """Enforce authorization via the ``Authorization`` request header.

    The header value must equal API_TOKEN exactly (no scheme prefix is
    stripped). On a missing or wrong value the request is aborted with
    HTTP 401. Only a masked form of the supplied value is logged.
    """
    import hmac

    token = request.headers.get("Authorization")
    logger.info(f"requireAuth() -> Nagłówek Authorization: {maskToken(token)}")
    # Compare as bytes: compare_digest() rejects non-ASCII str operands, and a
    # constant-time comparison avoids leaking how much of the secret matched.
    supplied = (token or "").encode("utf-8")
    expected = API_TOKEN.encode("utf-8")
    if token is None or not hmac.compare_digest(supplied, expected):
        logger.warning("Nieprawidłowy token w nagłówku Authorization. Oczekiwano innego ciągu znaków.")
        abort(401, description="Unauthorized")
def validateHostsSyntax(hosts_content):
    """Validate the text of a hosts file.

    Every non-blank, non-comment line must consist of a valid IPv4/IPv6
    address followed by at least one hostname, and the same (address,
    hostname) pair may not appear twice. Text from a ``#`` to the end of the
    line is treated as a comment and ignored, so trailing comments are
    neither parsed as hostnames nor reported as duplicates.

    Returns ``None`` when the content is valid, otherwise a message describing
    the first problem found (line numbers are 1-based). Non-string input is
    reported as an error instead of raising.
    """
    import ipaddress

    if not isinstance(hosts_content, str):
        return "Zawartość hosts musi być tekstem (string)."

    seen = set()
    for i, line in enumerate(hosts_content.splitlines(), start=1):
        # Drop the comment part first; whole-line comments and blank lines
        # both end up empty and are skipped.
        entry = line.split("#", 1)[0].strip()
        if not entry:
            continue
        parts = entry.split()
        if len(parts) < 2:
            return f"Linia {i}: Za mało elementów, wymagane IP oraz co najmniej jeden hostname."
        ip_addr, hostnames = parts[0], parts[1:]
        try:
            ipaddress.ip_address(ip_addr)
        except ValueError:
            return f"Linia {i}: '{ip_addr}' nie jest poprawnym adresem IP"
        for hn in hostnames:
            key = (ip_addr, hn)
            if key in seen:
                return f"Linia {i}: duplikat wpisu {ip_addr} -> {hn}"
            seen.add(key)
    return None
def writeHostsAtomic(new_content: str, path: str = "/etc/hosts") -> dict:
    """Replace the file at ``path`` with ``new_content`` atomically.

    Steps:
    - if the ``backup_path`` config key is set, copy the current file there as
      ``hosts.<UTC timestamp>.bak`` (a failed copy is logged, not fatal);
    - write the new content to a temporary file in the target directory,
      set its mode to 644, flush and fsync it;
    - rename the temporary file over ``path`` with os.replace().

    The mode is applied to the temporary file before the rename so the target
    never becomes visible with the restrictive mode the temp file is created
    with. If anything fails before the rename completes, the temporary file is
    removed and the exception propagates.

    Returns ``{"path": path, "backup": <backup file path or None>}``.
    """
    from tempfile import NamedTemporaryFile

    info = {"path": path, "backup": None}
    backup_dir = getCfg("backup_path", None)

    # Backup copy (best effort).
    if backup_dir:
        os.makedirs(backup_dir, exist_ok=True)
        ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        dest_path = os.path.join(backup_dir, f"hosts.{ts}.bak")
        try:
            shutil.copy2(path, dest_path)
            info["backup"] = dest_path
        except OSError as e:
            logger.warning(f"Backup nieudany: {e}")

    # Atomic write: the temp file lives next to the target so the rename stays
    # within one directory.
    dir_name = os.path.dirname(path) or "."
    tmp_name = None
    replaced = False
    try:
        with NamedTemporaryFile("w", dir=dir_name, delete=False, encoding="utf-8") as tmp:
            tmp_name = tmp.name
            tmp.write(new_content)
            tmp.flush()
            try:
                os.chmod(tmp_name, 0o644)
            except OSError as e:
                logger.warning(f"Nie udało się ustawić chmod 644 na {path}: {e}")
            os.fsync(tmp.fileno())
        os.replace(tmp_name, path)
        replaced = True
    finally:
        if not replaced and tmp_name is not None:
            # Do not leave a stray temp file behind; the original error (if
            # any) is what the caller needs to see, so cleanup errors are ignored.
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
    return info
def computeUnifiedDiff(old_text: str, new_text: str, fromfile="/etc/hosts(old)", tofile="/etc/hosts(new)") -> str:
    """Return a unified diff (3 lines of context) between two texts.

    ``fromfile`` / ``tofile`` are the labels used in the diff header. The
    result is an empty string when the texts are identical.
    """
    from difflib import unified_diff

    before = old_text.splitlines(keepends=True)
    after = new_text.splitlines(keepends=True)
    pieces = unified_diff(before, after, fromfile=fromfile, tofile=tofile, n=3)
    return "".join(pieces)
# ------------------
# HOOKS LOGOWANIA / METRYK
# ------------------
@app.before_request
def before_request_logging():
    """Log the incoming request and update the request counters.

    Stores the start timestamp on the request object so the after-request
    hook can compute the handling time. The Authorization header is logged
    only in masked form.
    """
    request.start_time = time.time()
    path = request.path
    masked_auth = maskToken(request.headers.get('Authorization'))
    logger.info(f"Request from {request.remote_addr} to {path} [{request.method}], Auth: {masked_auth}")

    metrics["total_requests"] += 1
    # Create the per-path entry on first sight, then count this request.
    stats = metrics["endpoints"].setdefault(path, {"count": 0, "total_time": 0.0})
    stats["count"] += 1
@app.after_request
def after_request_logging(response):
    """Record the handling time of the finished request and log its status.

    Adds the elapsed time to the global total and to the per-path entry (when
    one exists), then returns ``response`` unchanged.
    """
    duration = time.time() - request.start_time
    path = request.path
    metrics["total_time"] += duration
    stats = metrics["endpoints"].get(path)
    if stats is not None:
        stats["total_time"] += duration
    logger.info(f"Completed {path} in {duration:.3f} sec with status {response.status_code}")
    return response
# ------------------
# ENDPOINTY
# ------------------
@app.route('/', methods=['GET'])
def root_index():
    """Return a short JSON banner pointing at the useful endpoints."""
    banner = {"info": "hosts_daemon is running. Try /health or /hosts"}
    return jsonify(banner), 200
@app.route('/hosts', methods=['GET'])
def get_hosts():
    """Return the current content of /etc/hosts as ``{"hosts": <text>}``.

    Requires a valid Authorization header. Any failure while reading the file
    is logged and reported as HTTP 500 with the error text.
    """
    requireAuth()
    metrics["hosts_get"] += 1
    try:
        # Read as UTF-8 explicitly, matching how the POST handler reads and
        # how the file is written, instead of depending on the locale default.
        with open('/etc/hosts', 'r', encoding='utf-8') as f:
            content = f.read()
        logger.info(f"/hosts GET successful from {request.remote_addr}")
        return jsonify({"hosts": content})
    except Exception as e:
        logger.error(f"/hosts GET error: {str(e)}")
        return jsonify({"error": str(e)}), 500
@app.route('/hosts', methods=['POST'])
def updateHosts():
    """Replace /etc/hosts with the text supplied in the JSON body.

    Expects ``{"hosts": "<file content>"}`` and a valid Authorization header.
    The content is validated first; on success the file is written atomically,
    a unified diff against the previous content is logged, and the configured
    services are reloaded.

    Responses:
    - 200 with ``message``, ``backup`` (backup file path or null) and
      ``reload`` (list of per-service results);
    - 400 when the body is not a JSON object with a string ``hosts`` value or
      the content fails validation;
    - 500 when writing the file or reloading raises.
    """
    requireAuth()
    metrics["hosts_post"] += 1
    data = request.get_json()
    # The body must be a JSON object; a bare string or list would otherwise
    # slip through the membership test and fail later with an unhandled error.
    if not isinstance(data, dict) or "hosts" not in data:
        logger.warning(f"/hosts POST: missing 'hosts' key from {request.remote_addr}")
        return jsonify({"error": "Invalid request, missing 'hosts' key"}), 400
    newContent = data["hosts"]
    if not isinstance(newContent, str):
        logger.warning(f"/hosts POST: 'hosts' is not a string from {request.remote_addr}")
        return jsonify({"error": "Invalid request, 'hosts' must be a string"}), 400
    errorMsg = validateHostsSyntax(newContent)
    if errorMsg:
        logger.error(f"/hosts POST validation error: {errorMsg}")
        return jsonify({"error": errorMsg}), 400
    try:
        # Previous content is only needed for the (optional) diff log, so a
        # read failure is not fatal.
        try:
            with open('/etc/hosts', 'r', encoding='utf-8') as f:
                oldContent = f.read()
        except Exception:
            oldContent = ""
        # writeHostsAtomic(new_content, path) takes no backup-directory
        # argument; it reads the backup location from the "backup_path"
        # config key itself.
        writeInfo = writeHostsAtomic(newContent, "/etc/hosts")
        logger.info(f"/etc/hosts zapisano atomowo. backup={writeInfo['backup']}")
        if oldContent:
            diff = computeUnifiedDiff(oldContent, newContent)
            if diff:
                logger.info("Diff /etc/hosts:\n" + diff)
        reloadResults = reloadServices()
        return jsonify({
            "message": "File updated successfully",
            "backup": writeInfo["backup"],
            "reload": reloadResults
        })
    except Exception as e:
        logger.error(f"/hosts POST error: {str(e)}")
        return jsonify({"error": str(e)}), 500
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe; no token required.

    Returns status "ok", the current UTC time in ISO format, and the time
    elapsed since psutil.boot_time() formatted as "<seconds> seconds".
    """
    seconds_up = time.time() - psutil.boot_time()
    timestamp = datetime.now(timezone.utc).isoformat()
    logger.info(f"/health check from {request.remote_addr}")
    payload = {
        "status": "ok",
        "time": timestamp,
        "uptime": f"{seconds_up:.1f} seconds",
    }
    return jsonify(payload), 200
@app.route('/metrics', methods=['GET'])
def metrics_endpoint():
    """Report request statistics as JSON; no token required.

    Includes the total request count, the overall average response time,
    per-path counts and averages, and the GET/POST counters for /hosts.
    Averages are 0.0 when the corresponding count is zero.
    """
    def average(total, count):
        return total / count if count > 0 else 0.0

    per_path = {
        path: {
            "count": stats["count"],
            "avg_time": average(stats["total_time"], stats["count"]),
        }
        for path, stats in metrics["endpoints"].items()
    }
    payload = {
        "total_requests": metrics["total_requests"],
        "avg_response_time": average(metrics["total_time"], metrics["total_requests"]),
        "endpoints": per_path,
        "hosts_get": metrics.get("hosts_get", 0),
        "hosts_post": metrics.get("hosts_post", 0),
    }
    logger.info(f"/metrics accessed by {request.remote_addr}")
    return jsonify(payload), 200
@app.route('/system-info', methods=['GET'])
def system_info():
    """Return a JSON snapshot of host resource usage.

    Reports CPU percentage (sampled over 0.1 s), memory and root-filesystem
    totals/usage, the platform string, and seconds since psutil.boot_time().
    This handler performs no authorization check.
    """
    cpu = psutil.cpu_percent(interval=0.1)
    memory = psutil.virtual_memory()
    root_disk = psutil.disk_usage('/')
    platform_name = platform.platform()
    seconds_up = time.time() - psutil.boot_time()

    snapshot = {
        "cpu_percent": cpu,
        "memory_total": memory.total,
        "memory_used": memory.used,
        "memory_percent": memory.percent,
        "disk_total": root_disk.total,
        "disk_used": root_disk.used,
        "disk_percent": root_disk.percent,
        "platform": platform_name,
        "uptime_seconds": seconds_up,
    }
    logger.info(f"/system-info accessed by {request.remote_addr}")
    return jsonify(snapshot), 200
if __name__ == '__main__':
    # Script entry point: log the effective settings, then serve on all
    # interfaces, port 8000, using the certificate/key pair from ./ssl.
    startup_lines = (
        "Uruchamiam hosts_daemon nasłuch na porcie 8000 (HTTPS).",
        f"LOG_DIR: {LOG_DIR}",
        f"TOKEN_FILE_PATH: {TOKEN_FILE_PATH}",
    )
    for startup_line in startup_lines:
        logger.info(startup_line)
    tls_files = ('ssl/hosts_daemon.crt', 'ssl/hosts_daemon.key')
    app.run(host='0.0.0.0', port=8000, ssl_context=tls_files)