import ipaddress
import subprocess
import argparse
import csv
import json
import time
import threading
import itertools
from concurrent.futures import ThreadPoolExecutor, as_completed

# OIDs queried on every host when --oids is not given on the command line.
# Keys become the column names in the console table and the CSV header;
# values are the numeric OIDs passed to snmpget.
# The commented-out entries are optional extras - uncomment them to also
# query those values (each extra OID adds one snmpget call per host).
default_oids = {
    "sysName": "1.3.6.1.2.1.1.5.0",
    #"sysDescr": "1.3.6.1.2.1.1.1.0",
    #"sysLocation": "1.3.6.1.2.1.1.6.0",
    #"sysContact": "1.3.6.1.2.1.1.4.0"
}

def get_oid_value(ip, community, oid, timeout=1):
    """Fetch a single OID from *ip* using the external ``snmpget`` tool (SNMP v2c).

    Args:
        ip: Target host; converted with ``str()`` before being passed on.
        community: SNMP community string (passed as ``-c``).
        oid: Numeric OID to query.
        timeout: Seconds to wait for the ``snmpget`` process before it is killed.

    Returns:
        The text following the first ``STRING:`` marker in snmpget's output
        (this also matches ``Hex-STRING:`` values), stripped of surrounding
        whitespace; ``"Timeout"`` if the process exceeded *timeout*; otherwise
        ``"Brak danych"`` ("no data") - non-zero exit status or a non-string value.
    """
    try:
        result = subprocess.run(
            ["snmpget", "-v2c", "-c", community, str(ip), oid],
            capture_output=True,
            text=True,
            # Undecodable bytes in the output must not raise
            # UnicodeDecodeError inside a worker thread and abort the scan.
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return "Timeout"

    if result.returncode == 0:
        # Split on the FIRST marker only: split(...)[-1] truncated values that
        # themselves contain "STRING:", and a bare "STRING" without the colon
        # used to return the whole raw output line.
        _, marker, value = result.stdout.partition("STRING:")
        if marker:
            return value.strip()
    return "Brak danych"

def query_all_oids(ip, community, oids, timeout=1):
    """Query every OID in *oids* on a single host.

    Args:
        ip: Host address (anything accepted by ``str()``).
        community: SNMP community string.
        oids: Mapping of column name -> OID, queried in iteration order.
        timeout: Per-OID timeout in seconds, forwarded to ``get_oid_value``.

    Returns:
        ``(str(ip), {name: value})`` when at least one OID produced real data;
        ``None`` when every value is a placeholder (``"Timeout"`` or
        ``"Brak danych"``) or *oids* is empty.
    """
    placeholders = ("Timeout", "Brak danych")
    answers = {}
    for column, oid in oids.items():
        answers[column] = get_oid_value(ip, community, oid, timeout)
    if all(answer in placeholders for answer in answers.values()):
        return None
    return str(ip), answers

def scan_subnet(subnet, community, oids, counter, total, lock):
    """Query every usable host address of *subnet*, one host at a time.

    Args:
        subnet: Network object whose ``hosts()`` are scanned.
        community: SNMP community string forwarded to ``query_all_oids``.
        oids: Mapping of column name -> OID forwarded to ``query_all_oids``.
        counter: One-element list shared between workers as a mutable
            progress counter.
        total: Progress denominator shown on the status line.
        lock: Lock guarding ``counter`` and the status-line print.

    Returns:
        List of ``(ip_str, values)`` tuples for hosts that returned data.
    """
    def report_progress():
        # Increment and print under the lock so concurrent workers neither
        # lose updates nor garble the single status line.
        with lock:
            counter[0] += 1
            print(f"\rSkanowanie IP: {counter[0]}/{total}", end="", flush=True)

    responders = []
    for address in subnet.hosts():
        answer = query_all_oids(address, community, oids, timeout=1)
        if answer:
            responders.append(answer)
        report_progress()
    return responders

def split_subnet(supernet, new_prefix=24):
    """Split an IPv4 CIDR string into subnets with prefix length *new_prefix*.

    Args:
        supernet: IPv4 network in CIDR notation, e.g. ``"172.16.0.0/16"``.
            Parsing is strict, so host bits must be zero.
        new_prefix: Target prefix length of the resulting pieces.

    Returns:
        List of ``ipaddress.IPv4Network`` objects. A network that is already
        as small as, or smaller than, *new_prefix* is returned unchanged as a
        one-element list. On invalid input an error is printed and ``[]`` is
        returned.
    """
    try:
        network = ipaddress.IPv4Network(supernet)
        if network.prefixlen >= new_prefix:
            # subnets() raises ValueError when new_prefix is shorter than the
            # network's own prefix, which previously dropped e.g. a /28 from
            # the scan entirely; scan such a network as-is instead.
            return [network]
        return list(network.subnets(new_prefix=new_prefix))
    except ValueError as e:
        print(f"Błąd podsieci: {e}")
        return []

def read_subnets_from_file(filename):
    """Read CIDR subnet strings from a text file, one entry per line.

    Blank lines are skipped. Everything after a ``#`` is treated as a comment;
    a CIDR never contains ``#``, so both whole-line and trailing comments are
    allowed. A UTF-8 byte-order mark at the start of the file is ignored.

    Args:
        filename: Path of the file to read.

    Returns:
        List of stripped, non-empty entries in file order. If the file cannot
        be opened or decoded, an error is printed and the entries read so far
        (possibly none) are returned.
    """
    subnets = []
    try:
        # "utf-8-sig" drops a leading BOM (written by some Windows editors).
        # With plain "utf-8" it stayed glued to the first entry, because
        # str.strip() does not remove U+FEFF, and that subnet failed to parse.
        with open(filename, "r", encoding="utf-8-sig") as f:
            for line in f:
                entry = line.split("#", 1)[0].strip()
                if entry:
                    subnets.append(entry)
    except (OSError, UnicodeDecodeError) as e:
        print(f"Błąd podczas wczytywania pliku podsieci: {e}")
    return subnets

def save_to_csv(results, oids, filename="wyniki.csv"):
    """Write scan results to *filename* as CSV.

    The header row is ``IP`` followed by the keys of *oids*, in order. Each
    result contributes one row; a column missing from a host's values is
    written as an empty cell.

    Args:
        results: Iterable of ``(ip_str, values_dict)`` tuples.
        oids: Mapping whose keys define the data columns.
        filename: Output path; an existing file is overwritten.
    """
    columns = list(oids)
    with open(filename, "w", newline="", encoding="utf-8") as handle:
        out = csv.writer(handle)
        out.writerow(["IP", *columns])
        out.writerows(
            [address, *(data.get(column, "") for column in columns)]
            for address, data in results
        )

def main():
    """Command-line entry point: parse arguments, scan subnets, export to CSV.

    Subnets come from positional CIDR arguments and/or ``--subnet-file``.
    Each one is split into ``--prefix``-sized pieces, and the pieces are
    scanned in parallel by ``--workers`` threads (one piece per task). Hosts
    that answered are printed as their piece finishes. At the end they are
    written to ``--file``, sorted by IP address.

    Invalid ``--oids`` or ``--workers`` values abort through
    ``parser.error`` (usage message, exit status 2) before any scanning starts.
    """
    parser = argparse.ArgumentParser(description="Skaner SNMP z podziałem podsieci i eksportem do CSV.")
    parser.add_argument("subnets", nargs="*", help="Podsieci w formacie CIDR (np. 172.16.0.0/16)")
    parser.add_argument("-s", "--subnet-file", help="Plik tekstowy z listą podsieci CIDR (jedna na linię)")
    parser.add_argument("-c", "--community", default="public", help="SNMP community (domyślnie: public)")
    parser.add_argument("-o", "--oids", help="OID-y w formacie JSON, np. '{\"sysDescr\":\"1.3.6.1.2.1.1.1.0\"}'")
    parser.add_argument("-w", "--workers", type=int, default=10, help="Liczba równoległych wątków (domyślnie: 10)")
    parser.add_argument("-p", "--prefix", type=int, default=24, help="Wielkość podsieci do podziału (domyślnie: 24)")
    parser.add_argument("-f", "--file", default="wyniki.csv", help="Nazwa pliku CSV do zapisu (domyślnie: wyniki.csv)")
    args = parser.parse_args()

    if args.workers < 1:
        # ThreadPoolExecutor would otherwise fail with a raw ValueError traceback.
        parser.error("--workers musi być liczbą >= 1")

    if args.oids:
        try:
            oids = json.loads(args.oids)
        except json.JSONDecodeError as e:
            parser.error(f"Niepoprawny JSON w --oids: {e}")
        # The values go straight into snmpget's argv, so they must be strings.
        # A non-dict (or non-string OID) would otherwise only blow up later,
        # inside the worker threads, after the whole scan had been queued.
        if not (isinstance(oids, dict) and oids
                and all(isinstance(oid, str) for oid in oids.values())):
            parser.error("--oids musi być niepustym obiektem JSON {\"nazwa\": \"OID\"}")
    else:
        oids = default_oids

    all_input_subnets = []
    if args.subnet_file:
        all_input_subnets.extend(read_subnets_from_file(args.subnet_file))
    if args.subnets:
        all_input_subnets.extend(args.subnets)

    if not all_input_subnets:
        print("Brak podsieci do przeskanowania (ani z linii poleceń, ani z pliku).")
        return

    subnets_to_scan = []
    for subnet in all_input_subnets:
        subnets_to_scan.extend(split_subnet(subnet, new_prefix=args.prefix))
    # Drop exact duplicates (the same CIDR listed twice, or present both in
    # the file and on the command line) so no host is scanned or reported
    # twice. dict.fromkeys keeps the first-seen order.
    subnets_to_scan = list(dict.fromkeys(subnets_to_scan))

    # Total number of hosts: the denominator of the progress line.
    total_ips = sum(1 for subnet in subnets_to_scan for _ in subnet.hosts())
    counter = [0]  # mutable cell shared with workers, guarded by `lock`
    lock = threading.Lock()

    print(f"{'IP':<15} " + " ".join([f"{name:<25}" for name in oids.keys()]))
    print("-" * (15 + 26 * len(oids)))

    all_results = []
    with ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = [
            executor.submit(scan_subnet, subnet, args.community, oids, counter, total_ips, lock)
            for subnet in subnets_to_scan
        ]
        for future in as_completed(futures):
            for ip_str, values in future.result():
                all_results.append((ip_str, values))
                print(f"\n{ip_str:<15} " + " ".join([f"{values[name]:<25}" for name in oids.keys()]))

    print("\nSkanowanie zakończone.")
    # as_completed() yields in completion order, which differs between runs;
    # sort numerically by address so the CSV is deterministic.
    all_results.sort(key=lambda item: ipaddress.ip_address(item[0]))
    save_to_csv(all_results, oids, filename=args.file)

# Run the scanner only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()