dodatkowe poprawki i funkcje

This commit is contained in:
Mateusz Gruszczyński
2025-09-05 09:18:32 +02:00
parent 981f1e366d
commit 1f96a6e299
4 changed files with 188 additions and 66 deletions

241
app.py
View File

@@ -1,63 +1,136 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os, time, logging, platform, psutil, configparser, subprocess, shutil
import time from pathlib import Path
import logging
import difflib
import platform
import psutil
from flask import Flask, request, jsonify, abort from flask import Flask, request, jsonify, abort
from flask_sslify import SSLify from flask_sslify import SSLify
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone
#from croniter import croniter
app = Flask(__name__) app = Flask(__name__)
sslify = SSLify(app) sslify = SSLify(app)
# # --- ŚCIEŻKI / ŚRODOWISKO ---
# DYNAMICZNE ŚCIEŻKI USTAWIANE PRZEZ ZMIENNE ŚRODOWISKOWE
# ------------------------------------------------------
# - Jeśli zmienna nie jest ustawiona, używamy wartości domyślnej.
#
LOG_DIR = os.environ.get("HOSTS_DAEMON_LOG_DIR", "logs") LOG_DIR = os.environ.get("HOSTS_DAEMON_LOG_DIR", "logs")
TOKEN_FILE_PATH = os.environ.get("HOSTS_DAEMON_TOKEN_FILE", "daemon_token.txt") TOKEN_FILE_PATH = os.environ.get("HOSTS_DAEMON_TOKEN_FILE", "daemon_token.txt")
CONFIG_PATH = os.environ.get("HOSTS_DAEMON_CONFIG", "config.ini")
def read_token_from_file(path):
"""Odczytuje token z pliku i zwraca jego zawartość (strip),
albo None, jeśli plik nie istnieje lub jest pusty."""
if os.path.isfile(path):
try:
with open(path, 'r') as f:
content = f.read().strip()
if content:
return content
except Exception as e:
logger.error(f"Nie udało się odczytać pliku tokenu: {str(e)}")
return None
# Na tym etapie nie mamy jeszcze loggera, więc jego konfiguracja będzie poniżej
#
# KONFIGURACJA LOGOWANIA
# ------------------------------------------------------
# Upewniamy się, że katalog logów istnieje
os.makedirs(LOG_DIR, exist_ok=True) os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, "daemon.log") LOG_FILE = os.path.join(LOG_DIR, "daemon.log")
# --- LOGGING ---
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s", format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[ handlers=[logging.FileHandler(LOG_FILE), logging.StreamHandler()]
logging.FileHandler(LOG_FILE),
logging.StreamHandler()
]
) )
logger = logging.getLogger("hosts_daemon") logger = logging.getLogger("hosts_daemon")
# --- CONFIG ---
cfg = configparser.ConfigParser()
read_ok = cfg.read(CONFIG_PATH)
if not read_ok:
logger.warning(f"Brak pliku konfiguracyjnego {CONFIG_PATH} lub nieczytelny.")
# --- METRYKI ---
metrics = {
"total_requests": 0,
"total_time": 0.0,
"endpoints": {},
"hosts_get": 0,
"hosts_post": 0,
}
# ------------------
# FUNKCJE POMOCNICZE
# ------------------
def getCfg(key: str, default=None):
"""Pobiera wartość z sekcji [daemon] lub zwraca domyślną."""
return cfg.get("daemon", key, fallback=default)
RELOAD_TIMEOUT = int(getCfg("reload_timeout", "5") or 5)
def readTokenFromFile(path: str):
p = Path(path)
try:
if not p.is_file():
return None
content = p.read_text(encoding="utf-8").strip()
return content if content else None
except (PermissionError, IsADirectoryError, OSError, UnicodeDecodeError) as e:
logger.error(f"Nie udało się odczytać pliku tokenu '{path}': {e}")
return None
def listServices():
"""Zwraca listę serwisów do przeładowania z klucza 'services'."""
raw = getCfg("services", "")
return [s.strip() for s in raw.split(",") if s.strip()]
def serviceCommand(name: str):
"""Zwraca komendę dla danego serwisu (sekcja [service:<name>] ma priorytet)."""
sect = f"service:{name}"
if cfg.has_section(sect) and cfg.has_option(sect, "command"):
return cfg.get(sect, "command")
return f"systemctl reload {name}"
def runCmd(cmd: str):
"""Uruchamia komendę w shellu z timeoutem i logowaniem."""
logger.info(f"Exec: {cmd}")
try:
out = subprocess.run(
cmd, shell=True, capture_output=True, text=True, timeout=RELOAD_TIMEOUT
)
stdout = (out.stdout or "").strip()
stderr = (out.stderr or "").strip()
return out.returncode, stdout, stderr
except subprocess.TimeoutExpired:
return 124, "", "Timeout"
def reloadService(name: str):
"""Próbuje reload; gdy nie powiedzie się standardowe 'systemctl reload', robi restart."""
cmd = serviceCommand(name)
rc, out, err = runCmd(cmd)
if rc == 0:
return {"service": name, "action": "reload", "rc": rc, "stdout": out, "stderr": err}
if cmd.startswith("systemctl reload"):
rc2, out2, err2 = runCmd(f"systemctl restart {name}")
return {
"service": name,
"action": "restart" if rc2 == 0 else "reload->restart_failed",
"rc": rc2,
"stdout": out2,
"stderr": err2 or err,
}
return {"service": name, "action": "custom_failed", "rc": rc, "stdout": out, "stderr": err}
def reloadServices():
"""Przeładowuje wszystkie serwisy z konfiguracji i zwraca listę wyników."""
svcs = listServices()
if not svcs:
logger.info("Brak skonfigurowanych serwisów do przeładowania.")
return []
results = []
for s in svcs:
res = reloadService(s)
logger.info(f"Reload {s}: action={res['action']} rc={res['rc']}")
if res["stderr"]:
logger.debug(f"{s} stderr: {res['stderr']}")
results.append(res)
return results
def maskToken(token: str | None) -> str:
if not token:
return ""
if len(token) <= 8:
return "*" * len(token)
return token[:4] + "*" * (len(token) - 8) + token[-4:]
# #
# WYCZYTUJEMY TOKEN # WYCZYTUJEMY TOKEN
# ------------------------------------------------------ # ------------------------------------------------------
file_token = read_token_from_file(TOKEN_FILE_PATH) file_token = readTokenFromFile(TOKEN_FILE_PATH)
if file_token: if file_token:
API_TOKEN = file_token API_TOKEN = file_token
logger.info(f"API_TOKEN wczytany z pliku: {TOKEN_FILE_PATH}") logger.info(f"API_TOKEN wczytany z pliku: {TOKEN_FILE_PATH}")
@@ -70,25 +143,11 @@ else:
API_TOKEN = "superSecretTokenABC123" API_TOKEN = "superSecretTokenABC123"
logger.info("API_TOKEN ustawiony na wartość domyślną: superSecretTokenABC123") logger.info("API_TOKEN ustawiony na wartość domyślną: superSecretTokenABC123")
# Globalne metryki
metrics = {
"total_requests": 0,
"total_time": 0.0,
"endpoints": {},
"hosts_get": 0,
"hosts_post": 0,
}
# ------------------
# FUNKCJE POMOCNICZE
# ------------------
def require_auth(): def require_auth():
"""Wymusza autoryzację przy pomocy nagłówka Authorization, """Wymusza autoryzację przy pomocy nagłówka Authorization,
który powinien zawierać API_TOKEN.""" który powinien zawierać API_TOKEN."""
token = request.headers.get("Authorization") token = request.headers.get("Authorization")
logger.info(f"require_auth() -> Nagłówek Authorization: {token}") logger.info(f"requireAuth() -> Nagłówek Authorization: {maskToken(token)}")
if token != API_TOKEN: if token != API_TOKEN:
logger.warning("Nieprawidłowy token w nagłówku Authorization. Oczekiwano innego ciągu znaków.") logger.warning("Nieprawidłowy token w nagłówku Authorization. Oczekiwano innego ciągu znaków.")
abort(401, description="Unauthorized") abort(401, description="Unauthorized")
@@ -123,6 +182,40 @@ def validate_hosts_syntax(hosts_content):
seen[key] = True seen[key] = True
return None return None
def writeHostsAtomic(new_content: str, path: str = "/etc/hosts", backup_dir: str | None = None) -> dict:
"""
Zapisuje plik atomowo:
- tworzy kopię zapasową (z timestampem) jeśli backup_dir podane,
- zapis do pliku tymczasowego + fsync + rename().
Zwraca info o backupie i ścieżkach.
"""
from tempfile import NamedTemporaryFile
info = {"path": path, "backup": None}
# kopia zapasowa
if backup_dir:
os.makedirs(backup_dir, exist_ok=True)
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
backup_path = os.path.join(backup_dir, f"hosts.{ts}.bak")
shutil.copy2(path, backup_path)
info["backup"] = backup_path
# zapis atomowy
dir_name = os.path.dirname(path) or "."
with NamedTemporaryFile("w", dir=dir_name, delete=False, encoding="utf-8") as tmp:
tmp.write(new_content)
tmp.flush()
os.fsync(tmp.fileno())
tmp_name = tmp.name
os.replace(tmp_name, path)
return info
def computeUnifiedDiff(old_text: str, new_text: str, fromfile="/etc/hosts(old)", tofile="/etc/hosts(new)") -> str:
import difflib
return "".join(difflib.unified_diff(
old_text.splitlines(keepends=True),
new_text.splitlines(keepends=True),
fromfile=fromfile, tofile=tofile, n=3
))
# ------------------ # ------------------
# HOOKS LOGOWANIA / METRYK # HOOKS LOGOWANIA / METRYK
@@ -132,7 +225,7 @@ def before_request_logging():
request.start_time = time.time() request.start_time = time.time()
client_ip = request.remote_addr client_ip = request.remote_addr
endpoint = request.path endpoint = request.path
logger.info(f"Request from {client_ip} to {endpoint} [{request.method}], Auth: {request.headers.get('Authorization')}") logger.info(f"Request from {client_ip} to {endpoint} [{request.method}], Auth: {maskToken(request.headers.get('Authorization'))}")
metrics["total_requests"] += 1 metrics["total_requests"] += 1
if endpoint not in metrics["endpoints"]: if endpoint not in metrics["endpoints"]:
metrics["endpoints"][endpoint] = {"count": 0, "total_time": 0.0} metrics["endpoints"][endpoint] = {"count": 0, "total_time": 0.0}
@@ -171,29 +264,49 @@ def get_hosts():
return jsonify({"error": str(e)}), 500 return jsonify({"error": str(e)}), 500
@app.route('/hosts', methods=['POST']) @app.route('/hosts', methods=['POST'])
def update_hosts(): def updateHosts():
require_auth() requireAuth()
metrics["hosts_post"] += 1 metrics["hosts_post"] += 1
data = request.get_json() data = request.get_json()
if not data or "hosts" not in data: if not data or "hosts" not in data:
logger.warning(f"/hosts POST: missing 'hosts' key from {request.remote_addr}") logger.warning(f"/hosts POST: missing 'hosts' key from {request.remote_addr}")
return jsonify({"error": "Invalid request, missing 'hosts' key"}), 400 return jsonify({"error": "Invalid request, missing 'hosts' key"}), 400
new_content = data["hosts"] newContent = data["hosts"]
error_msg = validate_hosts_syntax(new_content) errorMsg = validateHostsSyntax(newContent)
if error_msg: if errorMsg:
logger.error(f"/hosts POST validation error: {error_msg}") logger.error(f"/hosts POST validation error: {errorMsg}")
return jsonify({"error": error_msg}), 400 return jsonify({"error": errorMsg}), 400
try: try:
with open('/etc/hosts', 'w') as f: # diff (opcjonalny log)
f.write(new_content) try:
logger.info(f"/hosts POST updated by {request.remote_addr}") with open('/etc/hosts', 'r', encoding='utf-8') as f:
return jsonify({"message": "File updated successfully"}) oldContent = f.read()
except Exception:
oldContent = ""
writeInfo = writeHostsAtomic(newContent, "/etc/hosts", backup_dir=os.path.join(LOG_DIR, "backups"))
logger.info(f"/etc/hosts zapisano atomowo. backup={writeInfo['backup']}")
if oldContent:
diff = computeUnifiedDiff(oldContent, newContent)
if diff:
logger.info("Diff /etc/hosts:\n" + diff)
reloadResults = reloadServices()
return jsonify({
"message": "File updated successfully",
"backup": writeInfo["backup"],
"reload": reloadResults
})
except Exception as e: except Exception as e:
logger.error(f"/hosts POST error: {str(e)}") logger.error(f"/hosts POST error: {str(e)}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": str(e)}), 500
@app.route('/health', methods=['GET']) @app.route('/health', methods=['GET'])
def health(): def health():
# Endpoint nie wymaga tokenu # Endpoint nie wymaga tokenu

10
config.example.ini Normal file
View File

@@ -0,0 +1,10 @@
[daemon]
API_TOKEN = apitoken
services = dnsmasq
# opcjonalnie: globalny timeout (sekundy)
reload_timeout = 5
[service:dnsmasq]
command = systemctl reload dnsmasq

View File

@@ -1,2 +0,0 @@
[daemon]
API_TOKEN = superSecretTokenABC123

View File

@@ -3,3 +3,4 @@ Flask-SSLify
tzlocal tzlocal
gunicorn gunicorn
psutil psutil
croniter