72 lines
2.0 KiB
Python
72 lines
2.0 KiB
Python
from fastapi import FastAPI, Request, Response
|
|
from fastapi.responses import JSONResponse, PlainTextResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from .deps import get_geo
|
|
from .api import get_client_ip, router, geo_headers
|
|
from .config import settings
|
|
import uvicorn
|
|
|
|
# The ASGI application object; every endpoint declared on the package's
# API router is mounted onto it.
app = FastAPI(
    title="IP Geo API",
)
app.include_router(router)
|
|
|
|
|
|
async def add_geo_headers(request, call_next):
    """Attach geolocation headers for the client IP to every response.

    The client IP is resolved with ``get_client_ip`` and looked up via
    ``get_geo().lookup`` before the request is handled. After the downstream
    app has produced its response, the pairs returned by ``geo_headers(data)``
    are copied onto it.

    Names and values are sanitised first. Starlette encodes headers as
    latin-1, and the HTTP server also rejects control characters and
    leading/trailing whitespace in values, and non-token characters in
    names. A malformed geo value must not turn a normal response into a
    server error.

    Args:
        request: The incoming request.
        call_next: Callable supplied by ``BaseHTTPMiddleware`` that runs the
            rest of the application and returns its response.

    Returns:
        The downstream response with the geo headers added.
    """
    import string
    import unicodedata

    # ASCII control characters (C0 range and DEL) are valid latin-1 but are
    # forbidden in HTTP field values (CR/LF would even allow header
    # injection). Map them to spaces instead of letting the server choke.
    control_to_space = {code: " " for code in (*range(0x20), 0x7F)}
    # RFC 9110 "tchar": the only characters permitted in a header field name.
    token_chars = frozenset(
        string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~"
    )

    def _ascii(s: str) -> str:
        """Return *s* in a form that can be sent as an HTTP header.

        Control characters become spaces and surrounding whitespace is
        removed. Text representable in latin-1 (the encoding Starlette uses
        for headers) is kept as is. Anything else is NFKD-decomposed and
        reduced to ASCII, dropping characters with no ASCII base. If nothing
        is left after that reduction, ``"?"`` is returned.
        """
        s = s.translate(control_to_space).strip()
        try:
            s.encode("latin-1")
            return s
        except UnicodeEncodeError:
            ascii_text = (
                unicodedata.normalize("NFKD", s)
                .encode("ascii", "ignore")
                .decode("ascii")
                .strip()
            )
            return ascii_text or "?"

    ip = get_client_ip(request)
    geo = get_geo()
    data = geo.lookup(ip)

    response: Response = await call_next(request)

    for key, value in (geo_headers(data) or {}).items():
        name = _ascii(str(key))
        if not name or not token_chars.issuperset(name):
            # Not a legal header name (e.g. empty or the "?" fallback);
            # sending it would make the server reject the whole response.
            continue
        # Sanitised value is guaranteed to survive Starlette's latin-1
        # header encoding and the server's field-value validation.
        response.headers[name] = _ascii(str(value))

    return response
|
|
|
|
|
|
# Wrap every HTTP request with add_geo_headers, using Starlette's
# BaseHTTPMiddleware adapter for a plain dispatch coroutine.
app.add_middleware(
    BaseHTTPMiddleware,
    dispatch=add_geo_headers,
)
|
|
|
|
|
|
# Browsers probe /favicon.ico automatically; reply with an empty
# 204 No Content. A comment is used instead of a docstring so the OpenAPI
# description is unchanged.
@app.get("/favicon.ico")
async def favicon():
    no_content = Response(status_code=204)
    return no_content
|
|
|
|
|
|
# Lower-case substrings whose presence in the User-Agent marks a web browser.
_BROWSER_MARKERS = ("mozilla", "chrome", "safari", "edge", "firefox")


# Echo the caller's IP as plain text ("<ip>\n") to non-browser clients such
# as curl. Requests whose User-Agent looks like a browser get a 404 that
# mimics a missing route. HEAD requests get the same status with an empty
# body. (Comments rather than a docstring keep the OpenAPI output unchanged.)
@app.api_route("/", methods=["GET", "HEAD"])
async def root(request: Request):
    user_agent = request.headers.get("user-agent", "").lower()
    client_ip = get_client_ip(request).strip()
    is_head = request.method == "HEAD"

    looks_like_browser = any(marker in user_agent for marker in _BROWSER_MARKERS)
    if looks_like_browser:
        return (
            Response(status_code=404)
            if is_head
            else JSONResponse({"detail": "Not Found"}, status_code=404)
        )

    return Response(status_code=200) if is_head else PlainTextResponse(client_ip + "\n")
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point. The module uses relative imports, so start it
    # as a package module (``python -m app.main``), not as a file path.
    import os

    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        log_level=settings.log_level,
        # Let uvicorn take the client address from X-Forwarded-* headers.
        proxy_headers=True,
        # Peers trusted to send those headers. "*" trusts every client, so
        # anyone can spoof the IP this service reports and geolocates.
        # The default stays "*" for backward compatibility, but operators can
        # restrict it (e.g. to the reverse proxy's address) via uvicorn's
        # standard FORWARDED_ALLOW_IPS environment variable.
        forwarded_allow_ips=os.environ.get("FORWARDED_ALLOW_IPS", "*"),
    )
|