# sankey.py
import io, json, sys, argparse
from pathlib import Path
import numpy as np
import pandas as pd

# Optional worksheet holding per-node layout (x/y) and color overrides.
POSITIONS_SHEET = "nodes_positions"

# Fallback colors, used when a node or flow has no usable color configured.
DEFAULT_NODE_COLOR = "#CCCCCC"
DEFAULT_FLOW_COLOR = "#009EDB"

# Fully transparent fill, applied to nodes flagged as invisible.
TRANSPARENT = "rgba(0,0,0,0)"

def read_excel_locked(filepath, **pd_kwargs):
    """Read an Excel sheet from an in-memory copy of *filepath*.

    The file's bytes are slurped and the handle is closed before pandas
    parses anything. Presumably this lets a workbook that is currently open
    (locked) in another program still be read; confirm that is the intent.

    Extra keyword arguments go straight to ``pd.read_excel``; the engine is
    always ``openpyxl``.
    """
    with open(filepath, "rb") as handle:
        payload = handle.read()
    in_memory = io.BytesIO(payload)
    return pd.read_excel(in_memory, engine="openpyxl", **pd_kwargs)

def _valid_color(val) -> bool:
    return isinstance(val, str) and val.strip() != ""

def add_alpha(hex_color: str, alpha=1.0) -> str:
    """Convert a hex color to a CSS ``rgba(r,g,b,alpha)`` string.

    Accepts ``#RRGGBB`` or the shorthand ``#RGB``. The leading ``#`` and any
    surrounding whitespace are optional.

    Args:
        hex_color: the color to convert.
        alpha: opacity value, inserted into the output verbatim.

    Returns:
        A string such as ``"rgba(0,158,219,0.25)"``.

    Raises:
        ValueError: if *hex_color* is not exactly 3 or 6 hexadecimal digits.
            Previously, some malformed inputs silently produced wrong colors.
    """
    digits = hex_color.strip().lstrip("#")
    if len(digits) == 3:
        # Expand shorthand: "abc" -> "aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    # Validate explicitly: int(..., 16) alone would accept signs/whitespace.
    if len(digits) != 6 or any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError(f"expected a #RGB or #RRGGBB color, got {hex_color!r}")
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"

def _config_cell(cfg, row, default):
    """Return the text in column B of *row* of the config sheet, or *default*.

    *default* is used when the cell does not exist or is empty (NaN).
    """
    try:
        val = cfg.iloc[row, 1]
    except IndexError:
        return default
    return default if pd.isna(val) else str(val)


def _read_config(excel_path):
    """Return ``(title, subtitle)`` from the optional header-less 'config' sheet.

    The title is cell B1 and the subtitle is B2. Each falls back to its
    default independently, so a missing subtitle no longer discards the title.
    """
    try:
        cfg = read_excel_locked(excel_path, sheet_name="config", header=None)
    except Exception:
        # Best-effort by design: no readable config sheet means default titles.
        return "Sankey Diagram", ""
    return _config_cell(cfg, 0, "Sankey Diagram"), _config_cell(cfg, 1, "")


def _read_flows(excel_path):
    """Load the 'data' sheet and normalize it into one row per flow.

    The result has columns:
      * ``source`` / ``target``: stripped string node names.
      * ``value``: float, with zeros replaced by 0.001.
      * ``display_value``: the original numeric value.
      * ``is_zero``: whether the original value was 0.

    Rows with a missing or non-numeric value, or a blank endpoint, are dropped.

    Raises:
        KeyError: if the sheet lacks a source, target or value column.
    """
    data = read_excel_locked(excel_path, sheet_name="data", header=0)
    # Header matching is case- and whitespace-insensitive; str() tolerates
    # non-text headers.
    data.columns = [str(c).strip().lower() for c in data.columns]
    if "flow value" in data.columns:
        data = data.rename(columns={"flow value": "value"})
    missing = [c for c in ("source", "target", "value") if c not in data.columns]
    if missing:
        raise KeyError(f"'data' sheet is missing required column(s): {', '.join(missing)}")

    data["value"] = pd.to_numeric(data["value"], errors="coerce")
    data = data.dropna(subset=["source", "target", "value"]).copy()
    # Normalize node names exactly like the positions sheet does (astype(str)
    # + strip). Otherwise padded or numeric names never match their rows there.
    for col in ("source", "target"):
        data[col] = data[col].astype(str).str.strip()
    data = data[(data["source"] != "") & (data["target"] != "")].copy()

    data["display_value"] = data["value"]
    data["is_zero"] = data["value"] == 0
    # Zero-valued flows are kept but sized as 0.001; the real value stays in
    # display_value.
    data["value"] = data["value"].where(~data["is_zero"], 0.001).astype(float)
    return data


def _read_positions(excel_path, valid_nodes):
    """Return ``(nodes_order, pos_df)`` with node order, coordinates and colors.

    When the positions sheet exists:
      * its rows (restricted to *valid_nodes*, first occurrence wins) set the
        order;
      * nodes that have flows but no row are appended in sorted order with no
        fixed x/y.

    Otherwise all nodes are sorted and spread evenly over x in [0.1, 0.9] at
    y=0.5.

    ``pos_df`` always has these columns: node, x, y, node_color,
    incoming_flow_color, outgoing_flow_color.
    """
    pos_df = None
    try:
        with open(excel_path, "rb") as fh:
            xl_buf = io.BytesIO(fh.read())
        xl = pd.ExcelFile(xl_buf, engine="openpyxl")
    except Exception:
        # Best-effort: an unreadable workbook is treated as "no positions sheet".
        xl = None
    if xl is not None:
        with xl:  # close the workbook once the sheet has been extracted
            if POSITIONS_SHEET in xl.sheet_names:
                pos_df = xl.parse(sheet_name=POSITIONS_SHEET)

    color_cols = ("node_color", "incoming_flow_color", "outgoing_flow_color")
    if pos_df is None:
        nodes_order = sorted(valid_nodes)
        pos_df = pd.DataFrame({
            "node": nodes_order,
            "x": np.linspace(0.1, 0.9, num=len(nodes_order)),
            "y": [0.5] * len(nodes_order),
            **{col: [""] * len(nodes_order) for col in color_cols},
        })
        return nodes_order, pos_df

    pos_df["node"] = pos_df["node"].astype(str).str.strip()
    pos_df = pos_df[pos_df["node"].isin(valid_nodes)]
    # Duplicate rows would otherwise emit the same node twice.
    pos_df = pos_df.drop_duplicates(subset="node", keep="first").copy()
    for col in color_cols:
        if col not in pos_df.columns:
            pos_df[col] = ""
    if "x" not in pos_df.columns:
        pos_df["x"] = np.linspace(0.1, 0.9, num=len(pos_df))
    if "y" not in pos_df.columns:
        pos_df["y"] = 0.5

    nodes_order = pos_df["node"].tolist()
    # Nodes missing from the sheet used to be dropped while their links still
    # referenced them. Keep them, without a fixed position.
    nodes_order += sorted(set(valid_nodes) - set(nodes_order))
    return nodes_order, pos_df


def _coord(val):
    """Return *val* as a float, or None if it is missing, NaN or infinite.

    NaN and Infinity are not valid JSON, so they must not reach the output.
    """
    if val is None or pd.isna(val):
        return None
    num = float(val)
    return num if np.isfinite(num) else None


def build_sankey_from_excel(excel_path: str):
    """Build a JSON-serializable Sankey description from an Excel workbook.

    Sheets read:
      * ``config`` (optional, no header): B1 is the title and B2 the subtitle.
      * ``data`` (required): ``source``, ``target`` and ``value`` (or
        ``flow value``) columns.
      * ``nodes_positions`` (optional): a ``node`` column, plus optional
        ``x``, ``y``, ``node_color``, ``incoming_flow_color`` and
        ``outgoing_flow_color``.

    Returns:
        ``{"meta": ..., "nodes": [...], "links": [...]}``. Every link endpoint
        appears in ``nodes``. Unknown coordinates are emitted as None.

    Raises:
        OSError: if the file cannot be opened.
        KeyError: if the data sheet lacks required columns.
        Other pandas/openpyxl errors: for a missing or malformed 'data' sheet.
    """
    config_title, config_sub = _read_config(excel_path)
    data = _read_flows(excel_path)
    valid_nodes = set(data["source"]) | set(data["target"])
    nodes_order, pos_df = _read_positions(excel_path, valid_nodes)

    pos_x = dict(zip(pos_df["node"], pos_df["x"]))
    pos_y = dict(zip(pos_df["node"], pos_df["y"]))
    node_color_map = dict(zip(pos_df["node"], pos_df["node_color"]))
    incoming_colors = dict(zip(pos_df["node"], pos_df["incoming_flow_color"]))
    outgoing_colors = dict(zip(pos_df["node"], pos_df["outgoing_flow_color"]))

    # Per-node sums and link counts in one grouped pass, instead of
    # re-filtering the whole table once per node.
    in_groups = data.groupby("target")["display_value"]
    out_groups = data.groupby("source")["display_value"]
    in_sum, in_size = in_groups.sum(), in_groups.size()
    out_sum, out_size = out_groups.sum(), out_groups.size()

    invisible_nodes = set()  # nothing populates this yet; every node is visible
    nodes = []
    for n in nodes_order:
        # Label total = larger of inflow and outflow; 0.001 when both are 0.
        total = float(max(in_sum.get(n, 0), out_sum.get(n, 0)) or 0.001)
        if n in invisible_nodes:
            label, color = " ", TRANSPARENT
        else:
            label = f"{n} ({total:,.0f})"
            color = node_color_map.get(n, "")
            if not _valid_color(color):
                color = DEFAULT_NODE_COLOR
        nodes.append({
            "id": n, "label": label, "color": color,
            "x": _coord(pos_x.get(n)),
            "y": _coord(pos_y.get(n)),
            "inCount": int(in_size.get(n, 0)),
            "outCount": int(out_size.get(n, 0)),
            "invisible": n in invisible_nodes,
        })

    zero_flow_color = add_alpha(DEFAULT_FLOW_COLOR, 0.25)
    links = []
    for src, tgt, value, shown, is_zero in zip(
        data["source"], data["target"], data["value"],
        data["display_value"], data["is_zero"],
    ):
        # Color precedence: faint default for zero flows, then the source's
        # outgoing color, then the target's incoming color, then the default.
        if is_zero:
            color = zero_flow_color
        elif _valid_color(outgoing_colors.get(src, "")):
            color = outgoing_colors[src]
        elif _valid_color(incoming_colors.get(tgt, "")):
            color = incoming_colors[tgt]
        else:
            color = DEFAULT_FLOW_COLOR
        links.append({
            "source": src, "target": tgt,
            "value": float(value),
            "displayValue": float(shown),
            "color": color, "is_zero": bool(is_zero),
        })

    meta = {"title": config_title, "subtitle": config_sub,
            "defaults": {"nodeWidth": 15, "nodePadding": 30}}
    return {"meta": meta, "nodes": nodes, "links": links}

def _write_output(path, text):
    """Write *text* to *path* as UTF-8, creating missing parent directories."""
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(text, encoding="utf-8")


def main():
    """Command-line entry point: build the Sankey data and write it to disk.

    Always writes pretty-printed JSON to ``--out-json``. With ``--out-js`` it
    also writes a JS file that assigns the same data to ``window.SANKEY_DATA``.

    Exits with status 1 and a message on stderr if the workbook cannot be
    processed or an output file cannot be written.
    """
    ap = argparse.ArgumentParser(description="Build Sankey JSON/JS from Excel.")
    ap.add_argument("-i", "--excel", default="./data/sankey_data.xlsx",
                    help="Path to Excel input file.")
    ap.add_argument("-o", "--out-json", default="sankey.json",
                    help="Output JSON file name.")
    ap.add_argument("--out-js", default=None,
                    help="Optional output JS file name (writes window.SANKEY_DATA).")
    args = ap.parse_args()

    try:
        out_dict = build_sankey_from_excel(args.excel)
    except Exception as e:
        print(f"Failed to build Sankey from '{args.excel}': {e}", file=sys.stderr)
        sys.exit(1)

    # Report write failures (e.g. permissions) the same way as build failures,
    # rather than with a traceback.
    try:
        _write_output(args.out_json, json.dumps(out_dict, ensure_ascii=False, indent=2))
        print(f"Wrote {args.out_json} with {len(out_dict['nodes'])} nodes and {len(out_dict['links'])} links.")

        if args.out_js:
            _write_output(
                args.out_js,
                "window.SANKEY_DATA = " + json.dumps(out_dict, ensure_ascii=False) + ";",
            )
            print(f"Also wrote {args.out_js} (for file:// usage).")
    except OSError as e:
        print(f"Failed to write output: {e}", file=sys.stderr)
        sys.exit(1)

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
