dane w headerach i inne funkcje
This commit is contained in:
99
app/api.py
99
app/api.py
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Request, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Request, Depends, HTTPException, status, Response
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from .deps import get_geo
|
||||
from .config import settings
|
||||
@@ -11,37 +11,47 @@ router = APIRouter()
|
||||
security = HTTPBasic()
|
||||
|
||||
VENDOR_SINGLE_IP_HEADERS = [
|
||||
"cf-connecting-ip", # Cloudflare
|
||||
"true-client-ip", # Akamai/F5
|
||||
"cf-connecting-ip", # Cloudflare
|
||||
"true-client-ip", # Akamai/F5
|
||||
"x-cluster-client-ip", # niektóre load balancery
|
||||
"x-real-ip", # klasyk (nginx/traefik)
|
||||
"x-real-ip", # klasyk (nginx/traefik)
|
||||
]
|
||||
|
||||
|
||||
def _check_admin(creds: HTTPBasicCredentials):
|
||||
user = settings.admin_user
|
||||
pwd = settings.admin_pass
|
||||
if not user or not pwd:
|
||||
raise HTTPException(status_code=403, detail='admin credentials not configured')
|
||||
raise HTTPException(status_code=403, detail="admin credentials not configured")
|
||||
# constant-time compare
|
||||
if not (secrets.compare_digest(creds.username, user) and secrets.compare_digest(creds.password, pwd)):
|
||||
raise HTTPException(status_code=401, detail='invalid credentials', headers={"WWW-Authenticate":"Basic"})
|
||||
if not (
|
||||
secrets.compare_digest(creds.username, user)
|
||||
and secrets.compare_digest(creds.password, pwd)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="invalid credentials",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _normalize_ip_str(ip_raw: str) -> str | None:
|
||||
"""Usuń port, whitespace i ewentualne cudzysłowy"""
|
||||
if not ip_raw:
|
||||
return None
|
||||
ip_raw = ip_raw.strip().strip('"').strip("'")
|
||||
# usuń port, np. 1.2.3.4:5678
|
||||
if ':' in ip_raw and ip_raw.count(':') == 1:
|
||||
if ":" in ip_raw and ip_raw.count(":") == 1:
|
||||
# prawdopodobnie IPv4:port
|
||||
ip_raw = ip_raw.split(':')[0]
|
||||
ip_raw = ip_raw.split(":")[0]
|
||||
# Pozostaw kwestie IPv6 z %zone
|
||||
return ip_raw
|
||||
|
||||
|
||||
def _is_ip_trusted(ip_str: str) -> bool:
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str.split('%')[0])
|
||||
ip = ipaddress.ip_address(ip_str.split("%")[0])
|
||||
except Exception:
|
||||
return False
|
||||
for net in settings.trusted_proxies:
|
||||
@@ -52,17 +62,32 @@ def _is_ip_trusted(ip_str: str) -> bool:
|
||||
continue
|
||||
return False
|
||||
|
||||
|
||||
def _extract_from_forwarded(header_value: str) -> list[str]:
|
||||
# Forwarded: for=192.0.2.43, for="[2001:db8:cafe::17]";proto=http;by=...
|
||||
ips = []
|
||||
parts = re.split(r',\s*(?=[fF]or=)', header_value)
|
||||
parts = re.split(r",\s*(?=[fF]or=)", header_value)
|
||||
for part in parts:
|
||||
m = re.search(r'for=(?P<val>"[^"]+"|[^;,\s]+)', part)
|
||||
if m:
|
||||
val = m.group('val').strip('"').strip("'")
|
||||
val = m.group("val").strip('"').strip("'")
|
||||
ips.append(val)
|
||||
return ips
|
||||
|
||||
|
||||
def geo_headers(data: dict) -> dict:
|
||||
h = {}
|
||||
country = data.get("country", {}).get("name") if data.get("country") else None
|
||||
city = data.get("city")
|
||||
ip_val = data.get("ip")
|
||||
if ip_val and country:
|
||||
h["X-IP-ADDRESS"] = ip_val
|
||||
h["X-COUNTRY"] = country
|
||||
if city:
|
||||
h["X-CITY"] = city
|
||||
return h
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str:
|
||||
"""
|
||||
Zwraca IP klienta biorąc pod uwagę:
|
||||
@@ -127,48 +152,44 @@ def get_client_ip(request: Request) -> str:
|
||||
try:
|
||||
host = request.client.host
|
||||
if host:
|
||||
return host.split('%')[0] if '%' in host else host
|
||||
return host.split("%")[0] if "%" in host else host
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "0.0.0.0"
|
||||
|
||||
@router.get('/ip')
|
||||
|
||||
@router.get("/ip")
|
||||
async def my_ip(request: Request, geo=Depends(get_geo)):
|
||||
ip = get_client_ip(request)
|
||||
# handle IPv6 mapped IPv4 like ::ffff:1.2.3.4
|
||||
try:
|
||||
ip = ip.split('%')[0]
|
||||
except Exception:
|
||||
pass
|
||||
return geo.lookup(ip)
|
||||
data = geo.lookup(ip)
|
||||
return Response(
|
||||
content=data.__str__(), media_type="application/json", headers=geo_headers(data)
|
||||
)
|
||||
|
||||
@router.get('/ip/{ip_address}')
|
||||
|
||||
@router.get("/ip/{ip_address}")
|
||||
async def ip_lookup(ip_address: str, geo=Depends(get_geo)):
|
||||
# validate IP
|
||||
try:
|
||||
# allow zone index for IPv6 and strip it for validation
|
||||
if '%' in ip_address:
|
||||
addr = ip_address.split('%')[0]
|
||||
else:
|
||||
addr = ip_address
|
||||
ipaddress.ip_address(addr)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail='invalid IP address')
|
||||
return geo.lookup(ip_address)
|
||||
data = geo.lookup(ip_address)
|
||||
return Response(
|
||||
content=data.__str__(), media_type="application/json", headers=geo_headers(data)
|
||||
)
|
||||
|
||||
@router.post('/reload')
|
||||
|
||||
@router.post("/reload")
|
||||
async def reload(creds: HTTPBasicCredentials = Depends(security)):
|
||||
_check_admin(creds)
|
||||
provider = reload_provider()
|
||||
return {'reloaded': True, 'provider': type(provider).__name__}
|
||||
return {"reloaded": True, "provider": type(provider).__name__}
|
||||
|
||||
@router.get('/health')
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
return {'status':'ok'}
|
||||
return {"status": "ok"}
|
||||
|
||||
#from fastapi import Request
|
||||
|
||||
#@router.get("/_debug/headers")
|
||||
#async def debug_headers(request: Request):
|
||||
# from fastapi import Request
|
||||
|
||||
# @router.get("/_debug/headers")
|
||||
# async def debug_headers(request: Request):
|
||||
# return {"headers": dict(request.headers)}
|
@@ -5,6 +5,7 @@ import ipaddress
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _parse_trusted_proxies(raw: str):
|
||||
# raw: comma-separated list of IPs or CIDR ranges
|
||||
items = [p.strip() for p in (raw or "").split(",") if p.strip()]
|
||||
@@ -16,47 +17,54 @@ def _parse_trusted_proxies(raw: str):
|
||||
else:
|
||||
# treat single IP as /32 or /128 network
|
||||
ip = ipaddress.ip_address(p)
|
||||
nets.append(ipaddress.ip_network(ip.exploded + ("/32" if ip.version == 4 else "/128")))
|
||||
nets.append(
|
||||
ipaddress.ip_network(
|
||||
ip.exploded + ("/32" if ip.version == 4 else "/128")
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# ignoruj błędne wpisy
|
||||
continue
|
||||
return nets
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
geo_provider: str = os.getenv('GEO_PROVIDER', 'maxmind')
|
||||
geo_provider: str = os.getenv("GEO_PROVIDER", "maxmind")
|
||||
|
||||
# MaxMind
|
||||
maxmind_account_id: str | None = os.getenv('MAXMIND_ACCOUNT_ID')
|
||||
maxmind_license_key: str | None = os.getenv('MAXMIND_LICENSE_KEY')
|
||||
maxmind_db_name: str = os.getenv('MAXMIND_DB_NAME', 'GeoLite2-City')
|
||||
maxmind_db_path: str = os.getenv('MAXMIND_DB_PATH', '/data/GeoLite2-City.mmdb')
|
||||
maxmind_account_id: str | None = os.getenv("MAXMIND_ACCOUNT_ID")
|
||||
maxmind_license_key: str | None = os.getenv("MAXMIND_LICENSE_KEY")
|
||||
maxmind_db_name: str = os.getenv("MAXMIND_DB_NAME", "GeoLite2-City")
|
||||
maxmind_db_path: str = os.getenv("MAXMIND_DB_PATH", "/data/GeoLite2-City.mmdb")
|
||||
maxmind_download_url_template: str | None = os.getenv(
|
||||
'MAXMIND_DOWNLOAD_URL_TEMPLATE',
|
||||
'https://download.maxmind.com/app/geoip_download?edition_id={DBNAME}&license_key={LICENSE_KEY}&suffix=tar.gz'
|
||||
"MAXMIND_DOWNLOAD_URL_TEMPLATE",
|
||||
"https://download.maxmind.com/app/geoip_download?edition_id={DBNAME}&license_key={LICENSE_KEY}&suffix=tar.gz",
|
||||
)
|
||||
maxmind_direct_db_url: str | None = os.getenv('MAXMIND_DIRECT_DB_URL')
|
||||
maxmind_github_repo: str | None = os.getenv('MAXMIND_GITHUB_REPO')
|
||||
github_token: str | None = os.getenv('GITHUB_TOKEN')
|
||||
maxmind_direct_db_url: str | None = os.getenv("MAXMIND_DIRECT_DB_URL")
|
||||
maxmind_github_repo: str | None = os.getenv("MAXMIND_GITHUB_REPO")
|
||||
github_token: str | None = os.getenv("GITHUB_TOKEN")
|
||||
|
||||
# IP2Location
|
||||
ip2location_download_url: str | None = os.getenv('IP2LOCATION_DOWNLOAD_URL')
|
||||
ip2location_db_path: str = os.getenv('IP2LOCATION_DB_PATH', '/data/IP2LOCATION.BIN')
|
||||
ip2location_download_url: str | None = os.getenv("IP2LOCATION_DOWNLOAD_URL")
|
||||
ip2location_db_path: str = os.getenv("IP2LOCATION_DB_PATH", "/data/IP2LOCATION.BIN")
|
||||
|
||||
update_interval_seconds: int = int(os.getenv('UPDATE_INTERVAL_SECONDS', '86400'))
|
||||
update_interval_seconds: int = int(os.getenv("UPDATE_INTERVAL_SECONDS", "86400"))
|
||||
|
||||
host: str = os.getenv('HOST', '0.0.0.0')
|
||||
port: int = int(os.getenv('PORT', '8000'))
|
||||
log_level: str = os.getenv('LOG_LEVEL', 'info')
|
||||
host: str = os.getenv("HOST", "0.0.0.0")
|
||||
port: int = int(os.getenv("PORT", "8000"))
|
||||
log_level: str = os.getenv("LOG_LEVEL", "info")
|
||||
|
||||
admin_user: str | None = os.getenv('ADMIN_USER')
|
||||
admin_pass: str | None = os.getenv('ADMIN_PASS')
|
||||
cache_maxsize: int = int(os.getenv('CACHE_MAXSIZE', '4096'))
|
||||
admin_user: str | None = os.getenv("ADMIN_USER")
|
||||
admin_pass: str | None = os.getenv("ADMIN_PASS")
|
||||
cache_maxsize: int = int(os.getenv("CACHE_MAXSIZE", "4096"))
|
||||
|
||||
# Nowe: lista zaufanych proxy (CIDR lub IP), oddzielone przecinkami
|
||||
# Przykład: "127.0.0.1,10.0.0.0/8,192.168.1.5"
|
||||
_trusted_proxies_raw: str | None = os.getenv('TRUSTED_PROXIES', '')
|
||||
_trusted_proxies_raw: str | None = os.getenv("TRUSTED_PROXIES", "")
|
||||
|
||||
@property
|
||||
def trusted_proxies(self):
|
||||
return _parse_trusted_proxies(self._trusted_proxies_raw)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from functools import lru_cache
|
||||
from .geo import get_provider_instance
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_geo():
|
||||
return get_provider_instance()
|
||||
|
31
app/geo.py
31
app/geo.py
@@ -8,6 +8,7 @@ from .config import settings
|
||||
try:
|
||||
import geoip2.database
|
||||
from geoip2.errors import AddressNotFoundError
|
||||
|
||||
try:
|
||||
# geoip2<5
|
||||
from geoip2.errors import InvalidDatabaseError # type: ignore
|
||||
@@ -17,8 +18,10 @@ try:
|
||||
except Exception as e:
|
||||
print("Import geoip2 failed:", e)
|
||||
geoip2 = None
|
||||
|
||||
# awaryjne aliasy, aby kod dalej działał
|
||||
class _TmpErr(Exception): ...
|
||||
|
||||
AddressNotFoundError = _TmpErr
|
||||
InvalidDatabaseError = _TmpErr
|
||||
|
||||
@@ -67,8 +70,10 @@ class MaxMindGeo(GeoLookupBase):
|
||||
|
||||
def _detect_db_type(self):
|
||||
"""Próbuje określić typ bazy na podstawie metadanych, nazwy lub próbnych zapytań."""
|
||||
t = (getattr(self._reader, "metadata", None)
|
||||
and getattr(self._reader.metadata, "database_type", "")) or ""
|
||||
t = (
|
||||
getattr(self._reader, "metadata", None)
|
||||
and getattr(self._reader.metadata, "database_type", "")
|
||||
) or ""
|
||||
if t:
|
||||
return t.lower()
|
||||
|
||||
@@ -80,7 +85,7 @@ class MaxMindGeo(GeoLookupBase):
|
||||
probes = [
|
||||
("city", self._reader.city),
|
||||
("country", self._reader.country),
|
||||
("asn", self._reader.asn)
|
||||
("asn", self._reader.asn),
|
||||
]
|
||||
test_ip = "1.1.1.1"
|
||||
for key, fn in probes:
|
||||
@@ -107,7 +112,9 @@ class MaxMindGeo(GeoLookupBase):
|
||||
pass
|
||||
self._reader = geoip2.database.Reader(self.db_path)
|
||||
self._db_type = self._detect_db_type()
|
||||
print(f"[MaxMindGeo] opened {self.db_path} type={self._db_type or 'unknown'}")
|
||||
print(
|
||||
f"[MaxMindGeo] opened {self.db_path} type={self._db_type or 'unknown'}"
|
||||
)
|
||||
|
||||
def _lookup_inner(self, ip: str):
|
||||
t = (self._db_type or "").lower()
|
||||
@@ -117,7 +124,9 @@ class MaxMindGeo(GeoLookupBase):
|
||||
"ip": ip,
|
||||
"asn": {
|
||||
"number": getattr(rec, "autonomous_system_number", None),
|
||||
"organization": getattr(rec, "autonomous_system_organization", None),
|
||||
"organization": getattr(
|
||||
rec, "autonomous_system_organization", None
|
||||
),
|
||||
},
|
||||
"database_type": self._db_type,
|
||||
}
|
||||
@@ -145,7 +154,9 @@ class MaxMindGeo(GeoLookupBase):
|
||||
"continent": getattr(rec.continent, "name", None),
|
||||
"database_type": self._db_type,
|
||||
}
|
||||
raise RuntimeError(f"Nieobsługiwany / niewykryty typ bazy: {self._db_type} (plik: {self.db_path})")
|
||||
raise RuntimeError(
|
||||
f"Nieobsługiwany / niewykryty typ bazy: {self._db_type} (plik: {self.db_path})"
|
||||
)
|
||||
|
||||
def lookup(self, ip: str):
|
||||
if not self.is_valid_ip(ip):
|
||||
@@ -213,8 +224,12 @@ _provider_lock = threading.RLock()
|
||||
def _create_provider():
|
||||
provider = settings.geo_provider.lower()
|
||||
if provider == "ip2location":
|
||||
return IP2LocationGeo(db_path=settings.ip2location_db_path, cache_maxsize=settings.cache_maxsize)
|
||||
return MaxMindGeo(db_path=settings.maxmind_db_path, cache_maxsize=settings.cache_maxsize)
|
||||
return IP2LocationGeo(
|
||||
db_path=settings.ip2location_db_path, cache_maxsize=settings.cache_maxsize
|
||||
)
|
||||
return MaxMindGeo(
|
||||
db_path=settings.maxmind_db_path, cache_maxsize=settings.cache_maxsize
|
||||
)
|
||||
|
||||
|
||||
def get_provider_instance():
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
|
||||
|
||||
class IgnoreHealthAndFavicon(logging.Filter):
|
||||
def __init__(self, name: str = ""):
|
||||
super().__init__(name)
|
||||
|
47
app/main.py
47
app/main.py
@@ -1,18 +1,57 @@
|
||||
from fastapi import FastAPI, Response
|
||||
from .api import router
|
||||
from fastapi.middleware.base import BaseHTTPMiddleware
|
||||
from fastapi.responses import Response, PlainTextResponse
|
||||
from .deps import get_geo
|
||||
from .api import get_client_ip, router
|
||||
from .config import settings
|
||||
import uvicorn
|
||||
|
||||
app = FastAPI(title='IP Geo API')
|
||||
app = FastAPI(title="IP Geo API")
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
async def add_geo_headers(request, call_next):
|
||||
ip = get_client_ip(request)
|
||||
geo = get_geo()
|
||||
data = geo.lookup(ip)
|
||||
|
||||
response: Response = await call_next(request)
|
||||
|
||||
country = data.get("country", {}).get("name") if data.get("country") else None
|
||||
city = data.get("city")
|
||||
ip_val = data.get("ip")
|
||||
|
||||
if ip_val and country:
|
||||
response.headers["X-IP-ADDRESS"] = ip_val
|
||||
response.headers["X-COUNTRY"] = country
|
||||
if city:
|
||||
response.headers["X-CITY"] = city
|
||||
|
||||
return response
|
||||
|
||||
|
||||
app.add_middleware(BaseHTTPMiddleware, dispatch=add_geo_headers)
|
||||
|
||||
|
||||
@app.get("/favicon.ico")
|
||||
async def favicon():
|
||||
return Response(status_code=204)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@app.get("/")
|
||||
async def root(request: Request):
|
||||
ua = request.headers.get("user-agent", "").lower()
|
||||
ip = get_client_ip(request)
|
||||
|
||||
if any(x in ua for x in ["mozilla", "chrome", "safari", "edge", "firefox"]):
|
||||
return Response(status_code=404)
|
||||
|
||||
return PlainTextResponse(ip)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
'app.main:app',
|
||||
"app.main:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
log_level=settings.log_level,
|
||||
|
@@ -1,8 +1,8 @@
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
geoip2
|
||||
python-dotenv
|
||||
python - dotenv
|
||||
requests
|
||||
IP2Location
|
||||
pydantic
|
||||
pydantic-settings
|
||||
pydantic - settings
|
||||
|
Reference in New Issue
Block a user