import os, json, math, heapq, requests, urllib3
try:
    import msgpack
except ImportError:
    print("[!] Zainstaluj msgpack: pip3 install msgpack")
    exit(1)
urllib3.disable_warnings()

# Pterodactyl server volume; both the squaremap web root and the world save
# live under it, so the server UUID is spelled out only once.
SERVER_VOLUME = "/var/lib/pterodactyl/volumes/cc7ccc03-03bd-4bde-8488-3057d3803420"
# squaremap web root: stations.json is read from here and the generated map
# data is written back here.
BASE_PATH = f"{SERVER_VOLUME}/squaremap/web"
# MTR save folder; its "platforms" and "sidings" sub-folders hold msgpack records.
MTR_SAVES_PATH = f"{SERVER_VOLUME}/world/mtr/minecraft/overworld"
STATIONS_FILE = f"{BASE_PATH}/stations.json"  # input: stations and routes
OUTPUT_FILE = f"{BASE_PATH}/final_map_data.json"  # generated output
# MTR map API endpoints.
API_BASE_URL = "https://mtr.ciapongi.szablix.pl/mtr/api/map"
API_ROUTES_URL = f"{API_BASE_URL}/stations-and-routes"
API_RAILS_URL = f"{API_BASE_URL}/rails"

def normalize(s):
    """Return a loose comparison key for *s*.

    The value is converted with ``str()``, trimmed of surrounding whitespace,
    lower-cased, and every space character is removed, so ``" Linia 1 "`` and
    ``"linia1"`` produce the same key.
    """
    text = str(s).strip()
    lowered = text.lower()
    return "".join(lowered.split(" "))
def dist_2d(p1, p2):
    """Return the Euclidean distance between the first two coordinates of *p1* and *p2*.

    Extra coordinates (if any) are ignored. ``math.hypot`` is used instead of
    squaring by hand so very large or very small differences neither overflow
    nor underflow.
    """
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])

def fetch_json(url, timeout=15):
    """GET *url* and return its decoded JSON body, or ``{}`` on failure.

    TLS certificate verification is disabled (``verify=False``); urllib3
    warnings are disabled at module level. Network errors, non-200 responses
    and undecodable JSON are reported on stdout rather than silently
    swallowed, but still yield ``{}`` so callers can continue.

    Args:
        url: Endpoint to fetch.
        timeout: Seconds passed to ``requests.get``.

    Returns:
        The parsed JSON value (whatever type the server sent), or ``{}``.
    """
    try:
        r = requests.get(url, verify=False, timeout=timeout)
    except requests.RequestException as e:
        print(f"[!] Błąd pobierania {url}: {e}")
        return {}
    if r.status_code != 200:
        print(f"[!] {url} zwrócił HTTP {r.status_code}")
        return {}
    try:
        return r.json()
    except ValueError as e:  # JSON decode errors subclass ValueError
        print(f"[!] Niepoprawny JSON z {url}: {e}")
        return {}

def load_msgpack_data(folder_name):
    """Map rail endpoint pairs to names read from MTR msgpack save files.

    Recursively walks ``MTR_SAVES_PATH/folder_name`` and unpacks every file
    whose name does not start with ``"."``. For each record that has truthy
    ``position1`` and ``position2`` values, the key is the sorted pair of
    ``(x, y, z)`` tuples (missing coordinates default to 0). The value is the
    record's ``name``, or ``"Nieznany"`` ("unknown") when absent. ``main``
    builds the same key shape for API rails to match platforms/sidings to them.

    Args:
        folder_name: Sub-folder of the MTR save directory, e.g. ``"platforms"``.

    Returns:
        dict mapping ``((x1, y1, z1), (x2, y2, z2))`` to a name. Empty when the
        folder does not exist. Unreadable or malformed files are skipped and
        reported once as a count.
    """
    mapping = {}
    full_path = os.path.join(MTR_SAVES_PATH, folder_name)
    if not os.path.exists(full_path):
        return mapping
    skipped = 0
    for root, _, files in os.walk(full_path):
        for file in files:
            if file.startswith("."):
                continue
            try:
                with open(os.path.join(root, file), "rb") as f:
                    data = msgpack.unpackb(f.read(), raw=False)
                p1, p2 = data.get("position1", {}), data.get("position2", {})
                name = data.get("name", "Nieznany")
                if p1 and p2:
                    k = tuple(sorted([(p1.get('x', 0), p1.get('y', 0), p1.get('z', 0)),
                                      (p2.get('x', 0), p2.get('y', 0), p2.get('z', 0))]))
                    mapping[k] = name
            except Exception:
                # Best effort: one corrupt or foreign file must not abort the
                # export, but (unlike a bare ``except:``) Ctrl-C still works
                # and the failure is counted instead of vanishing.
                skipped += 1
    if skipped:
        print(f"[!] Pominięto {skipped} nieczytelnych plików w {full_path}")
    return mapping

def dijkstra_on_rails(rail_dict, adj, start_rid, end_rid, default_length=10):
    """Find the cheapest chain of rail ids from *start_rid* to *end_rid*.

    Dijkstra over the rail adjacency graph. Moving onto a rail costs that
    rail's ``'length'`` from *rail_dict*. *default_length* is used when the rail
    is unknown or its length is missing/``None``. The start rail costs nothing.

    Args:
        rail_dict: rail id -> rail record (only ``'length'`` is read).
        adj: rail id -> iterable of neighbouring rail ids.
        start_rid: rail to start from.
        end_rid: rail to reach.
        default_length: cost assumed for rails without a known length.

    Returns:
        List of rail ids from start to end inclusive (``[start_rid]`` when both
        are equal), or ``None`` when *end_rid* is unreachable.
    """
    if start_rid == end_rid:
        return [start_rid]
    queue = [(0, start_rid)]
    came_from = {start_rid: None}
    cost_so_far = {start_rid: 0}

    while queue:
        current_cost, curr = heapq.heappop(queue)
        if curr == end_rid:
            break
        if current_cost > cost_so_far[curr]:
            # Stale heap entry: a cheaper route to curr was already expanded.
            continue
        for nxt in adj.get(curr, []):
            length = rail_dict.get(nxt, {}).get('length')
            if length is None:
                length = default_length
            new_cost = current_cost + length
            if nxt not in cost_so_far or new_cost < cost_so_far[nxt]:
                cost_so_far[nxt] = new_cost
                came_from[nxt] = curr
                heapq.heappush(queue, (new_cost, nxt))

    if end_rid not in came_from:
        return None
    path = []
    node = end_rid
    while node is not None:
        path.append(node)
        node = came_from[node]
    path.reverse()
    return path

def _curve_xz(rail):
    """Return *rail*'s curve points as ``[x, z]`` pairs sorted by ``progress``."""
    ordered = sorted(rail.get('curvePoints', []), key=lambda c: c['progress'])
    return [[c['x'], c['z']] for c in ordered]


def _stitch_rail_path(rail_dict, route_rids):
    """Join the curves of consecutive rails into one continuous ``[x, z]`` polyline.

    Rail ids missing from *rail_dict* and rails without curve points are
    skipped. Each rail after the first is reversed when its last point is
    nearer the previous rail's end than its first point. The first rail is
    oriented so that its end faces the second rail. A rail's first point is
    dropped when it is closer than 2.0 to the last point already on the path,
    to avoid doubled joints.
    """
    curves = []
    for rid in route_rids:
        rail = rail_dict.get(rid)
        pts = _curve_xz(rail) if rail else []
        if pts:
            curves.append(pts)

    if len(curves) > 1:
        # Nothing precedes the first rail, so orient it by the second one;
        # otherwise a rail stored "backwards" draws out to its far end and back.
        nxt = curves[1]

        def gap(p):
            return min(dist_2d(p, nxt[0]), dist_2d(p, nxt[-1]))

        if gap(curves[0][0]) < gap(curves[0][-1]):
            curves[0].reverse()

    path, prev_end = [], None
    for pts in curves:
        if prev_end is not None and dist_2d(prev_end, pts[-1]) < dist_2d(prev_end, pts[0]):
            pts.reverse()
        if path and dist_2d(path[-1], pts[0]) < 2.0:
            path.extend(pts[1:])
        else:
            path.extend(pts)
        prev_end = pts[-1]
    return path


def main():
    """Build ``OUTPUT_FILE`` from STATIONS_FILE, the MTR saves and the MTR map API.

    Steps:
    1. Read STATIONS_FILE.
    2. Load platform/siding names from the msgpack saves.
    3. Fetch routes and rails from the API.
    4. Classify every rail as platform, siding or other infrastructure.
    5. Trace each route that matches an API route along the rail graph.
    6. Write everything to ``OUTPUT_FILE``.
    """
    print("=== GENERATOR V20 (ZOPTYMALIZOWANE PERONY I CZYSTE TRASY) ===")

    # Read the local input first: without it nothing can be exported, so
    # there is no point walking the saves or calling the API.
    try:
        with open(STATIONS_FILE, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        print(f"[!] Nie udało się wczytać {STATIONS_FILE}: {e}")
        return

    platform_data = load_msgpack_data("platforms")
    siding_data = load_msgpack_data("sidings")

    routes_json = fetch_json(API_ROUTES_URL)
    api_routes = routes_json.get('data', {}).get('routes', []) if isinstance(routes_json, dict) else []
    api_routes_map = {normalize(r.get('name', '')): r for r in api_routes}

    rails_json = fetch_json(API_RAILS_URL)
    if isinstance(rails_json, dict):
        rails_list = rails_json.get('data', {}).get('rails', [])
    else:
        rails_list = rails_json or []

    GRID_SIZE = 32  # spatial-hash cell size, same units as the x/z coordinates
    rail_dict, adj, grid = {}, {}, {}
    platforms_out, sidings_out, infra_out = [], [], []
    nodes_set = set()

    for r in rails_list:
        rid = r['id']
        rail_dict[rid] = r
        adj[rid] = set(r.get('connectedRails', []))

        pts = _curve_xz(r)
        if not pts:
            continue

        # Both curve endpoints are exported as "nodes", rounded to 2 decimals.
        nodes_set.add((round(pts[0][0], 2), round(pts[0][1], 2)))
        nodes_set.add((round(pts[-1][0], 2), round(pts[-1][1], 2)))

        for x, z in pts:
            grid.setdefault((int(x // GRID_SIZE), int(z // GRID_SIZE)), set()).add(rid)

        # Same endpoint-pair key as load_msgpack_data builds, so saved
        # platform/siding records can be matched to API rails (with names).
        p1, p2 = r.get('position1', {}), r.get('position2', {})
        k = tuple(sorted([(p1.get('x', 0), p1.get('y', 0), p1.get('z', 0)),
                          (p2.get('x', 0), p2.get('y', 0), p2.get('z', 0))]))
        if k in platform_data:
            platforms_out.append({"path": pts, "name": platform_data[k]})
        elif k in siding_data:
            sidings_out.append({"path": pts, "name": siding_data[k]})
        else:
            infra_out.append(pts)

    # Add reverse edges so the graph can be walked both ways even when a link
    # is listed in connectedRails on one side only.
    for r in rails_list:
        for connected in r.get('connectedRails', []):
            if connected in adj:
                adj[connected].add(r['id'])

    def find_closest_rail(pt):
        """Return ``(rail_id, distance)`` of the curve point nearest *pt*.

        Only rails in the 3x3 grid cells around *pt* are searched, so
        ``(None, inf)`` comes back when no rail is that close.
        """
        gx, gz = int(pt[0] // GRID_SIZE), int(pt[1] // GRID_SIZE)
        cands = set()
        for dx in (-1, 0, 1):
            for dz in (-1, 0, 1):
                cands.update(grid.get((gx + dx, gz + dz), ()))
        best_rail, min_d = None, float('inf')
        for rid in cands:
            for c in rail_dict[rid].get('curvePoints', []):
                d = dist_2d(pt, (c['x'], c['z']))
                if d < min_d:
                    best_rail, min_d = rid, d
        return best_rail, min_d

    final_routes = []
    for route in data.get('routes', []):
        norm_name = normalize(route.get('name', ''))
        api_data = api_routes_map.get(norm_name)
        if not api_data and norm_name:
            # Substring fallback. Skipped for an empty name: '' is a substring
            # of every key and would grab an arbitrary route.
            api_data = next((v for k, v in api_routes_map.items() if norm_name in k), None)

        if not api_data:
            route.setdefault('stations', [])
            final_routes.append(route)
            continue

        route['id'] = api_data.get('id')
        stations = route['stations'] = api_data.get('stations', [])

        # Snap each station to a rail once; interior stations are shared by
        # two consecutive legs.
        station_rails = ([find_closest_rail((s['x'], s['z']))[0] for s in stations]
                         if len(stations) > 1 else [])
        route_rids = []
        for r1, r2 in zip(station_rails, station_rails[1:]):
            if r1 is None or r2 is None:
                continue
            sub_path = dijkstra_on_rails(rail_dict, adj, r1, r2)
            if not sub_path:
                continue
            # Consecutive legs share the rail at the station between them.
            if route_rids and route_rids[-1] == sub_path[0]:
                route_rids.extend(sub_path[1:])
            else:
                route_rids.extend(sub_path)

        detailed = _stitch_rail_path(rail_dict, route_rids) if route_rids else []
        # Fall back to the API's own path when tracing produced nothing.
        route['path'] = detailed or [[p['x'], p['z']] for p in api_data.get('path', [])]
        final_routes.append(route)

    out = {
        "stations": data.get('stations', []),
        "routes": final_routes,
        "platforms": platforms_out,
        "sidings": sidings_out,
        "infrastructure": infra_out,
        "nodes": sorted(nodes_set),  # sorted for stable, diffable output
    }
    # Write to a sibling temp file and swap it in atomically, so a concurrent
    # reader (presumably the web map) never sees a half-written JSON file.
    tmp_file = f"{OUTPUT_FILE}.tmp"
    with open(tmp_file, 'w', encoding='utf-8') as f:
        json.dump(out, f)
    os.replace(tmp_file, OUTPUT_FILE)

    print("✅ GOTOWE! Wyeksportowano zaktualizowane perony i czyste trasy.")


if __name__ == "__main__":
    main()
