#!/usr/bin/env python3

import pygrib
import numpy as np
from PIL import Image
import os
import json

# -----------------------------
# CONFIG
# -----------------------------
# Input GRIB2 file and output directory. Either can be overridden through the
# environment (WIND_GRIB_FILE / WIND_OUTPUT_DIR) so the script does not have to
# be edited for every model run; the defaults are the original hard-coded
# values. An empty environment value falls back to the default.
GRIB_FILE = (
    os.environ.get("WIND_GRIB_FILE")
    or "/var/noaaport/data/grib/data/grids/gfs/3/gfs_3_2026060218.grb2"
)
OUTPUT_DIR = os.environ.get("WIND_OUTPUT_DIR") or "./wind_pngs"
MANIFEST_FILE = os.path.join(OUTPUT_DIR, "manifest.json")

# Magnitude that maps to the ends of the 0..255 byte range, in the same units
# as the decoded U/V values (presumably m/s -- confirm against the GRIB file).
# Deliberately NOT overridable: it must match the decoder side (Script 1).
MAX_WIND = 50.0  # fixed global scaling (same as Script 1)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# -----------------------------
# ENCODING (IDENTICAL TO SCRIPT 1 LOGIC)
# -----------------------------
def encode_uv(data):
    """
    Quantize a signed wind component to one unsigned byte per value.

    The interval -MAX_WIND..+MAX_WIND is mapped linearly onto 0..255
    (-MAX_WIND -> 0, +MAX_WIND -> 255); values outside it saturate at the
    ends. The scaled float is truncated, not rounded, by the uint8 cast,
    so a component of exactly 0 encodes as 127 (from 127.5). One byte step
    corresponds to 2 * MAX_WIND / 255. NaN inputs are not handled here;
    their uint8 cast is undefined in NumPy.
    """
    scaled = (data / MAX_WIND * 0.5 + 0.5) * 255.0
    return np.clip(scaled, 0, 255).astype(np.uint8)

def encode_speed(speed):
    """
    Quantize a wind speed to one unsigned byte per value.

    0 maps to 0 and MAX_WIND maps to 255; larger speeds saturate at 255
    and any negative input is clamped to 0. The scaled float is truncated
    (not rounded) by the uint8 cast.
    """
    scaled = speed / MAX_WIND * 255.0
    return np.clip(scaled, 0, 255).astype(np.uint8)

# -----------------------------
# FILE NAMING
# -----------------------------
def make_filename(level_type, level):
    """
    Return the PNG filename used for one wind level.

    `level_type` is compared case-insensitively. Known level types get a
    short descriptive name; anything else falls back to
    ``wind_<lowercased level_type>_<level>.png``.
    """
    kind = level_type.lower()

    # Level types whose filename does not depend on the level value.
    level_free_names = {
        "tropopause": "wind_tropopause.png",
        "maxwind": "wind_maxwind.png",
    }
    if kind in level_free_names:
        return level_free_names[kind]

    if kind == "heightaboveground":
        return f"wind_{level}m.png"
    elif kind == "isobaricinhpa":
        return f"wind_{level}hPa.png"
    elif kind == "pressurefromgroundlayer":
        # NOTE(review): this level type is presumably a pressure-difference
        # layer above ground, not a height, so the "m" in the name looks
        # mislabeled. Kept unchanged because consumers may rely on it.
        return f"wind_{level}m_agl.png"
    else:
        return f"wind_{kind}_{level}.png"

# -----------------------------
# KEY NORMALIZATION (10m fix only)
# -----------------------------
def get_wind_key(g):
    """
    Return the ``(typeOfLevel, level)`` key used to pair U and V messages.

    Messages whose shortName is ``10u`` or ``10v`` are always keyed as
    ``("heightAboveGround", 10)``, ignoring their own level attributes;
    every other message is keyed by its own ``typeOfLevel`` and ``level``.
    """
    is_ten_metre = g.shortName in ("10u", "10v")
    return ("heightAboveGround", 10) if is_ten_metre else (g.typeOfLevel, g.level)

# -----------------------------
# LOAD GRIB
# -----------------------------
print("Opening GRIB...")
grbs = pygrib.open(GRIB_FILE)

# -----------------------------
# COLLECT U/V FIELDS
# Eagerly read .values while the file cursor is on each message.
# Storing the message object and calling .values later causes pygrib
# to re-seek, and the 10u/10v messages end up returning data from
# a different position in the file.
# -----------------------------
print("Collecting wind fields...")

u_fields = {}  # (typeOfLevel, level) -> U component values
v_fields = {}  # (typeOfLevel, level) -> V component values

try:
    for msg in grbs:
        if msg.shortName in ("u", "10u"):
            fields = u_fields
        elif msg.shortName in ("v", "10v"):
            fields = v_fields
        else:
            continue

        key = get_wind_key(msg)
        if key in fields:
            # Keep the original last-message-wins behaviour, but make the
            # collision visible instead of overwriting silently.
            print(f"Warning: duplicate {msg.shortName} field for {key}, keeping the later one")
        fields[key] = msg.values   # read NOW, not later
finally:
    # Release the GRIB file handle even if reading a message raises.
    grbs.close()

# Only levels with both components can be encoded; report the others
# instead of dropping them silently.
for key in sorted(set(u_fields) ^ set(v_fields)):
    print(f"Warning: skipping {key}, only one of U/V was found")

levels = sorted(set(u_fields) & set(v_fields))
print(f"Found {len(levels)} wind levels")
if not levels:
    print(f"Warning: no paired U/V wind fields found in {GRIB_FILE}; manifest will be empty")

# -----------------------------
# PROCESS LEVELS
# -----------------------------
manifest = {
    "maxWind": MAX_WIND,
    "levels": []
}

for level_type, level in levels:
    print(f"Processing: {level_type} {level}")

    # Already numpy arrays — no .values call needed
    u = u_fields[(level_type, level)]
    v = v_fields[(level_type, level)]

    # Compute wind speed
    speed = np.sqrt(u**2 + v**2)

    # Encode (SAME LOGIC AS SCRIPT 1): R = u, G = v, B = speed, A = opaque.
    # Spelled-out channel names avoid reusing `g`, which named GRIB messages.
    red = encode_uv(u)
    green = encode_uv(v)
    blue = encode_speed(speed)
    alpha = np.full_like(red, 255, dtype=np.uint8)

    rgba = np.stack([red, green, blue, alpha], axis=-1)

    # Save PNG
    filename = make_filename(level_type, level)
    output_path = os.path.join(OUTPUT_DIR, filename)

    # Pillow infers "RGBA" from a uint8 array whose last axis has length 4;
    # recent Pillow releases deprecate passing mode= to fromarray().
    Image.fromarray(rgba).save(output_path)

    # Manifest entry
    manifest["levels"].append({
        "levelType": level_type,
        "level": level,
        "file": filename
    })

# -----------------------------
# SAVE MANIFEST
# Write to a temporary file in the same directory, then rename it into
# place, so a concurrent reader never sees a half-written manifest.
# -----------------------------
tmp_manifest = MANIFEST_FILE + ".tmp"
with open(tmp_manifest, "w", encoding="utf-8") as f:
    json.dump(manifest, f, indent=2)
os.replace(tmp_manifest, MANIFEST_FILE)

print("Done.")
print(f"Saved to: {OUTPUT_DIR}")
print(f"Manifest: {MANIFEST_FILE}")
