dane w headerach i inne funkcje

This commit is contained in:
Mateusz Gruszczyński
2025-10-09 16:40:56 +02:00
parent cb109b63ae
commit eb137c87b0
7 changed files with 161 additions and 76 deletions

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, Request, Depends, HTTPException, status from fastapi import APIRouter, Request, Depends, HTTPException, status, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from .deps import get_geo from .deps import get_geo
from .config import settings from .config import settings
@@ -11,37 +11,47 @@ router = APIRouter()
security = HTTPBasic() security = HTTPBasic()
VENDOR_SINGLE_IP_HEADERS = [ VENDOR_SINGLE_IP_HEADERS = [
"cf-connecting-ip", # Cloudflare "cf-connecting-ip", # Cloudflare
"true-client-ip", # Akamai/F5 "true-client-ip", # Akamai/F5
"x-cluster-client-ip", # niektóre load balancery "x-cluster-client-ip", # niektóre load balancery
"x-real-ip", # klasyk (nginx/traefik) "x-real-ip", # klasyk (nginx/traefik)
] ]
def _check_admin(creds: HTTPBasicCredentials): def _check_admin(creds: HTTPBasicCredentials):
user = settings.admin_user user = settings.admin_user
pwd = settings.admin_pass pwd = settings.admin_pass
if not user or not pwd: if not user or not pwd:
raise HTTPException(status_code=403, detail='admin credentials not configured') raise HTTPException(status_code=403, detail="admin credentials not configured")
# constant-time compare # constant-time compare
if not (secrets.compare_digest(creds.username, user) and secrets.compare_digest(creds.password, pwd)): if not (
raise HTTPException(status_code=401, detail='invalid credentials', headers={"WWW-Authenticate":"Basic"}) secrets.compare_digest(creds.username, user)
and secrets.compare_digest(creds.password, pwd)
):
raise HTTPException(
status_code=401,
detail="invalid credentials",
headers={"WWW-Authenticate": "Basic"},
)
return True return True
def _normalize_ip_str(ip_raw: str) -> str | None: def _normalize_ip_str(ip_raw: str) -> str | None:
"""Usuń port, whitespace i ewentualne cudzysłowy""" """Usuń port, whitespace i ewentualne cudzysłowy"""
if not ip_raw: if not ip_raw:
return None return None
ip_raw = ip_raw.strip().strip('"').strip("'") ip_raw = ip_raw.strip().strip('"').strip("'")
# usuń port, np. 1.2.3.4:5678 # usuń port, np. 1.2.3.4:5678
if ':' in ip_raw and ip_raw.count(':') == 1: if ":" in ip_raw and ip_raw.count(":") == 1:
# prawdopodobnie IPv4:port # prawdopodobnie IPv4:port
ip_raw = ip_raw.split(':')[0] ip_raw = ip_raw.split(":")[0]
# Pozostaw kwestie IPv6 z %zone # Pozostaw kwestie IPv6 z %zone
return ip_raw return ip_raw
def _is_ip_trusted(ip_str: str) -> bool: def _is_ip_trusted(ip_str: str) -> bool:
try: try:
ip = ipaddress.ip_address(ip_str.split('%')[0]) ip = ipaddress.ip_address(ip_str.split("%")[0])
except Exception: except Exception:
return False return False
for net in settings.trusted_proxies: for net in settings.trusted_proxies:
@@ -52,17 +62,32 @@ def _is_ip_trusted(ip_str: str) -> bool:
continue continue
return False return False
def _extract_from_forwarded(header_value: str) -> list[str]: def _extract_from_forwarded(header_value: str) -> list[str]:
# Forwarded: for=192.0.2.43, for="[2001:db8:cafe::17]";proto=http;by=... # Forwarded: for=192.0.2.43, for="[2001:db8:cafe::17]";proto=http;by=...
ips = [] ips = []
parts = re.split(r',\s*(?=[fF]or=)', header_value) parts = re.split(r",\s*(?=[fF]or=)", header_value)
for part in parts: for part in parts:
m = re.search(r'for=(?P<val>"[^"]+"|[^;,\s]+)', part) m = re.search(r'for=(?P<val>"[^"]+"|[^;,\s]+)', part)
if m: if m:
val = m.group('val').strip('"').strip("'") val = m.group("val").strip('"').strip("'")
ips.append(val) ips.append(val)
return ips return ips
def geo_headers(data: dict) -> dict:
h = {}
country = data.get("country", {}).get("name") if data.get("country") else None
city = data.get("city")
ip_val = data.get("ip")
if ip_val and country:
h["X-IP-ADDRESS"] = ip_val
h["X-COUNTRY"] = country
if city:
h["X-CITY"] = city
return h
def get_client_ip(request: Request) -> str: def get_client_ip(request: Request) -> str:
""" """
Zwraca IP klienta biorąc pod uwagę: Zwraca IP klienta biorąc pod uwagę:
@@ -127,48 +152,44 @@ def get_client_ip(request: Request) -> str:
try: try:
host = request.client.host host = request.client.host
if host: if host:
return host.split('%')[0] if '%' in host else host return host.split("%")[0] if "%" in host else host
except Exception: except Exception:
pass pass
return "0.0.0.0" return "0.0.0.0"
@router.get('/ip')
@router.get("/ip")
async def my_ip(request: Request, geo=Depends(get_geo)): async def my_ip(request: Request, geo=Depends(get_geo)):
ip = get_client_ip(request) ip = get_client_ip(request)
# handle IPv6 mapped IPv4 like ::ffff:1.2.3.4 data = geo.lookup(ip)
try: return Response(
ip = ip.split('%')[0] content=data.__str__(), media_type="application/json", headers=geo_headers(data)
except Exception: )
pass
return geo.lookup(ip)
@router.get('/ip/{ip_address}')
@router.get("/ip/{ip_address}")
async def ip_lookup(ip_address: str, geo=Depends(get_geo)): async def ip_lookup(ip_address: str, geo=Depends(get_geo)):
# validate IP data = geo.lookup(ip_address)
try: return Response(
# allow zone index for IPv6 and strip it for validation content=data.__str__(), media_type="application/json", headers=geo_headers(data)
if '%' in ip_address: )
addr = ip_address.split('%')[0]
else:
addr = ip_address
ipaddress.ip_address(addr)
except Exception:
raise HTTPException(status_code=400, detail='invalid IP address')
return geo.lookup(ip_address)
@router.post('/reload')
@router.post("/reload")
async def reload(creds: HTTPBasicCredentials = Depends(security)): async def reload(creds: HTTPBasicCredentials = Depends(security)):
_check_admin(creds) _check_admin(creds)
provider = reload_provider() provider = reload_provider()
return {'reloaded': True, 'provider': type(provider).__name__} return {"reloaded": True, "provider": type(provider).__name__}
@router.get('/health')
@router.get("/health")
async def health(): async def health():
return {'status':'ok'} return {"status": "ok"}
#from fastapi import Request
#@router.get("/_debug/headers") # from fastapi import Request
#async def debug_headers(request: Request):
# return {"headers": dict(request.headers)} # @router.get("/_debug/headers")
# async def debug_headers(request: Request):
# return {"headers": dict(request.headers)}

View File

@@ -5,6 +5,7 @@ import ipaddress
load_dotenv() load_dotenv()
def _parse_trusted_proxies(raw: str): def _parse_trusted_proxies(raw: str):
# raw: comma-separated list of IPs or CIDR ranges # raw: comma-separated list of IPs or CIDR ranges
items = [p.strip() for p in (raw or "").split(",") if p.strip()] items = [p.strip() for p in (raw or "").split(",") if p.strip()]
@@ -16,47 +17,54 @@ def _parse_trusted_proxies(raw: str):
else: else:
# treat single IP as /32 or /128 network # treat single IP as /32 or /128 network
ip = ipaddress.ip_address(p) ip = ipaddress.ip_address(p)
nets.append(ipaddress.ip_network(ip.exploded + ("/32" if ip.version == 4 else "/128"))) nets.append(
ipaddress.ip_network(
ip.exploded + ("/32" if ip.version == 4 else "/128")
)
)
except Exception: except Exception:
# ignoruj błędne wpisy # ignoruj błędne wpisy
continue continue
return nets return nets
class Settings(BaseSettings): class Settings(BaseSettings):
geo_provider: str = os.getenv('GEO_PROVIDER', 'maxmind') geo_provider: str = os.getenv("GEO_PROVIDER", "maxmind")
# MaxMind # MaxMind
maxmind_account_id: str | None = os.getenv('MAXMIND_ACCOUNT_ID') maxmind_account_id: str | None = os.getenv("MAXMIND_ACCOUNT_ID")
maxmind_license_key: str | None = os.getenv('MAXMIND_LICENSE_KEY') maxmind_license_key: str | None = os.getenv("MAXMIND_LICENSE_KEY")
maxmind_db_name: str = os.getenv('MAXMIND_DB_NAME', 'GeoLite2-City') maxmind_db_name: str = os.getenv("MAXMIND_DB_NAME", "GeoLite2-City")
maxmind_db_path: str = os.getenv('MAXMIND_DB_PATH', '/data/GeoLite2-City.mmdb') maxmind_db_path: str = os.getenv("MAXMIND_DB_PATH", "/data/GeoLite2-City.mmdb")
maxmind_download_url_template: str | None = os.getenv( maxmind_download_url_template: str | None = os.getenv(
'MAXMIND_DOWNLOAD_URL_TEMPLATE', "MAXMIND_DOWNLOAD_URL_TEMPLATE",
'https://download.maxmind.com/app/geoip_download?edition_id={DBNAME}&license_key={LICENSE_KEY}&suffix=tar.gz' "https://download.maxmind.com/app/geoip_download?edition_id={DBNAME}&license_key={LICENSE_KEY}&suffix=tar.gz",
) )
maxmind_direct_db_url: str | None = os.getenv('MAXMIND_DIRECT_DB_URL') maxmind_direct_db_url: str | None = os.getenv("MAXMIND_DIRECT_DB_URL")
maxmind_github_repo: str | None = os.getenv('MAXMIND_GITHUB_REPO') maxmind_github_repo: str | None = os.getenv("MAXMIND_GITHUB_REPO")
github_token: str | None = os.getenv('GITHUB_TOKEN') github_token: str | None = os.getenv("GITHUB_TOKEN")
# IP2Location # IP2Location
ip2location_download_url: str | None = os.getenv('IP2LOCATION_DOWNLOAD_URL') ip2location_download_url: str | None = os.getenv("IP2LOCATION_DOWNLOAD_URL")
ip2location_db_path: str = os.getenv('IP2LOCATION_DB_PATH', '/data/IP2LOCATION.BIN') ip2location_db_path: str = os.getenv("IP2LOCATION_DB_PATH", "/data/IP2LOCATION.BIN")
update_interval_seconds: int = int(os.getenv('UPDATE_INTERVAL_SECONDS', '86400')) update_interval_seconds: int = int(os.getenv("UPDATE_INTERVAL_SECONDS", "86400"))
host: str = os.getenv('HOST', '0.0.0.0') host: str = os.getenv("HOST", "0.0.0.0")
port: int = int(os.getenv('PORT', '8000')) port: int = int(os.getenv("PORT", "8000"))
log_level: str = os.getenv('LOG_LEVEL', 'info') log_level: str = os.getenv("LOG_LEVEL", "info")
admin_user: str | None = os.getenv('ADMIN_USER') admin_user: str | None = os.getenv("ADMIN_USER")
admin_pass: str | None = os.getenv('ADMIN_PASS') admin_pass: str | None = os.getenv("ADMIN_PASS")
cache_maxsize: int = int(os.getenv('CACHE_MAXSIZE', '4096')) cache_maxsize: int = int(os.getenv("CACHE_MAXSIZE", "4096"))
# Nowe: lista zaufanych proxy (CIDR lub IP), oddzielone przecinkami # Nowe: lista zaufanych proxy (CIDR lub IP), oddzielone przecinkami
# Przykład: "127.0.0.1,10.0.0.0/8,192.168.1.5" # Przykład: "127.0.0.1,10.0.0.0/8,192.168.1.5"
_trusted_proxies_raw: str | None = os.getenv('TRUSTED_PROXIES', '') _trusted_proxies_raw: str | None = os.getenv("TRUSTED_PROXIES", "")
@property @property
def trusted_proxies(self): def trusted_proxies(self):
return _parse_trusted_proxies(self._trusted_proxies_raw) return _parse_trusted_proxies(self._trusted_proxies_raw)
settings = Settings() settings = Settings()

View File

@@ -1,6 +1,7 @@
from functools import lru_cache from functools import lru_cache
from .geo import get_provider_instance from .geo import get_provider_instance
@lru_cache() @lru_cache()
def get_geo(): def get_geo():
return get_provider_instance() return get_provider_instance()

View File

@@ -8,6 +8,7 @@ from .config import settings
try: try:
import geoip2.database import geoip2.database
from geoip2.errors import AddressNotFoundError from geoip2.errors import AddressNotFoundError
try: try:
# geoip2<5 # geoip2<5
from geoip2.errors import InvalidDatabaseError # type: ignore from geoip2.errors import InvalidDatabaseError # type: ignore
@@ -17,8 +18,10 @@ try:
except Exception as e: except Exception as e:
print("Import geoip2 failed:", e) print("Import geoip2 failed:", e)
geoip2 = None geoip2 = None
# awaryjne aliasy, aby kod dalej działał # awaryjne aliasy, aby kod dalej działał
class _TmpErr(Exception): ... class _TmpErr(Exception): ...
AddressNotFoundError = _TmpErr AddressNotFoundError = _TmpErr
InvalidDatabaseError = _TmpErr InvalidDatabaseError = _TmpErr
@@ -67,8 +70,10 @@ class MaxMindGeo(GeoLookupBase):
def _detect_db_type(self): def _detect_db_type(self):
"""Próbuje określić typ bazy na podstawie metadanych, nazwy lub próbnych zapytań.""" """Próbuje określić typ bazy na podstawie metadanych, nazwy lub próbnych zapytań."""
t = (getattr(self._reader, "metadata", None) t = (
and getattr(self._reader.metadata, "database_type", "")) or "" getattr(self._reader, "metadata", None)
and getattr(self._reader.metadata, "database_type", "")
) or ""
if t: if t:
return t.lower() return t.lower()
@@ -80,7 +85,7 @@ class MaxMindGeo(GeoLookupBase):
probes = [ probes = [
("city", self._reader.city), ("city", self._reader.city),
("country", self._reader.country), ("country", self._reader.country),
("asn", self._reader.asn) ("asn", self._reader.asn),
] ]
test_ip = "1.1.1.1" test_ip = "1.1.1.1"
for key, fn in probes: for key, fn in probes:
@@ -107,7 +112,9 @@ class MaxMindGeo(GeoLookupBase):
pass pass
self._reader = geoip2.database.Reader(self.db_path) self._reader = geoip2.database.Reader(self.db_path)
self._db_type = self._detect_db_type() self._db_type = self._detect_db_type()
print(f"[MaxMindGeo] opened {self.db_path} type={self._db_type or 'unknown'}") print(
f"[MaxMindGeo] opened {self.db_path} type={self._db_type or 'unknown'}"
)
def _lookup_inner(self, ip: str): def _lookup_inner(self, ip: str):
t = (self._db_type or "").lower() t = (self._db_type or "").lower()
@@ -117,7 +124,9 @@ class MaxMindGeo(GeoLookupBase):
"ip": ip, "ip": ip,
"asn": { "asn": {
"number": getattr(rec, "autonomous_system_number", None), "number": getattr(rec, "autonomous_system_number", None),
"organization": getattr(rec, "autonomous_system_organization", None), "organization": getattr(
rec, "autonomous_system_organization", None
),
}, },
"database_type": self._db_type, "database_type": self._db_type,
} }
@@ -145,7 +154,9 @@ class MaxMindGeo(GeoLookupBase):
"continent": getattr(rec.continent, "name", None), "continent": getattr(rec.continent, "name", None),
"database_type": self._db_type, "database_type": self._db_type,
} }
raise RuntimeError(f"Nieobsługiwany / niewykryty typ bazy: {self._db_type} (plik: {self.db_path})") raise RuntimeError(
f"Nieobsługiwany / niewykryty typ bazy: {self._db_type} (plik: {self.db_path})"
)
def lookup(self, ip: str): def lookup(self, ip: str):
if not self.is_valid_ip(ip): if not self.is_valid_ip(ip):
@@ -213,8 +224,12 @@ _provider_lock = threading.RLock()
def _create_provider(): def _create_provider():
provider = settings.geo_provider.lower() provider = settings.geo_provider.lower()
if provider == "ip2location": if provider == "ip2location":
return IP2LocationGeo(db_path=settings.ip2location_db_path, cache_maxsize=settings.cache_maxsize) return IP2LocationGeo(
return MaxMindGeo(db_path=settings.maxmind_db_path, cache_maxsize=settings.cache_maxsize) db_path=settings.ip2location_db_path, cache_maxsize=settings.cache_maxsize
)
return MaxMindGeo(
db_path=settings.maxmind_db_path, cache_maxsize=settings.cache_maxsize
)
def get_provider_instance(): def get_provider_instance():

View File

@@ -1,5 +1,6 @@
import logging import logging
class IgnoreHealthAndFavicon(logging.Filter): class IgnoreHealthAndFavicon(logging.Filter):
def __init__(self, name: str = ""): def __init__(self, name: str = ""):
super().__init__(name) super().__init__(name)

View File

@@ -1,22 +1,61 @@
from fastapi import FastAPI, Response from fastapi import FastAPI, Response
from .api import router from fastapi.middleware.base import BaseHTTPMiddleware
from fastapi.responses import Response, PlainTextResponse
from .deps import get_geo
from .api import get_client_ip, router
from .config import settings from .config import settings
import uvicorn import uvicorn
app = FastAPI(title='IP Geo API') app = FastAPI(title="IP Geo API")
app.include_router(router) app.include_router(router)
async def add_geo_headers(request, call_next):
ip = get_client_ip(request)
geo = get_geo()
data = geo.lookup(ip)
response: Response = await call_next(request)
country = data.get("country", {}).get("name") if data.get("country") else None
city = data.get("city")
ip_val = data.get("ip")
if ip_val and country:
response.headers["X-IP-ADDRESS"] = ip_val
response.headers["X-COUNTRY"] = country
if city:
response.headers["X-CITY"] = city
return response
app.add_middleware(BaseHTTPMiddleware, dispatch=add_geo_headers)
@app.get("/favicon.ico") @app.get("/favicon.ico")
async def favicon(): async def favicon():
return Response(status_code=204) return Response(status_code=204)
if __name__ == '__main__':
@app.get("/")
async def root(request: Request):
ua = request.headers.get("user-agent", "").lower()
ip = get_client_ip(request)
if any(x in ua for x in ["mozilla", "chrome", "safari", "edge", "firefox"]):
return Response(status_code=404)
return PlainTextResponse(ip)
if __name__ == "__main__":
uvicorn.run( uvicorn.run(
'app.main:app', "app.main:app",
host=settings.host, host=settings.host,
port=settings.port, port=settings.port,
log_level=settings.log_level, log_level=settings.log_level,
proxy_headers=True, proxy_headers=True,
forwarded_allow_ips="*", forwarded_allow_ips="*",
# access_log=True # access_log=True
) )

View File

@@ -1,8 +1,8 @@
fastapi fastapi
uvicorn[standard] uvicorn[standard]
geoip2 geoip2
python-dotenv python - dotenv
requests requests
IP2Location IP2Location
pydantic pydantic
pydantic-settings pydantic - settings