#!/usr/bin/env python3

import pygrib
import numpy as np
from PIL import Image

GRIB_FILE = "/var/noaaport/data/grib/data/grids/gfs/3/gfs_3_2026060206.grb2"
OUTPUT_FILE = "wind10.png"

# Wind magnitude mapped to the ends of the 8-bit encoding range; larger
# magnitudes are clipped. Same units as the 10u/10v fields (presumably m/s --
# confirm against the GRIB metadata).
MAX_WIND = 50.0

print("Opening GRIB...")
grbs = pygrib.open(GRIB_FILE)
try:
    # Take the first message matching each shortName.
    u10 = grbs.select(shortName='10u')[0]
    v10 = grbs.select(shortName='10v')[0]

    # .values decodes each field into a NumPy array, so the data remains
    # usable after the file is closed below.
    u = u10.values
    v = v10.values
finally:
    # Release the file handle even if a field is missing or fails to decode.
    grbs.close()

print(f"Grid shape: {u.shape}")

# Scalar wind speed from the two horizontal components.
speed = np.sqrt(u**2 + v**2)

def encode_uv(data, max_wind):
    """
    Encode a signed wind component into 8-bit values.

    Maps -max_wind .. +max_wind linearly onto 0..255 (rounded to the nearest
    integer, so 0 encodes as 128). Values outside that range are clipped to
    0 or 255.

    Args:
        data: array-like or numpy masked array of component values, in the
            same units as ``max_wind``.
        max_wind: positive, finite magnitude mapped to the ends of the range.

    Returns:
        uint8 ndarray with the same shape as ``data``. Masked or NaN entries
        are encoded as calm wind (0, i.e. 128).

    Raises:
        ValueError: if ``max_wind`` is not a positive finite number.
    """
    # `not (0 < x < inf)` also rejects NaN, which would otherwise poison
    # every pixel.
    if not 0 < max_wind < np.inf:
        raise ValueError(f"max_wind must be positive and finite, got {max_wind!r}")
    # Turn masked cells into NaN so plain and masked arrays share one path.
    values = np.ma.asarray(data, dtype=np.float64).filled(np.nan)
    # Missing data -> calm wind. np.where allocates a new array, so the
    # caller's data (which `values` may alias) is never modified.
    values = np.where(np.isnan(values), 0.0, values)
    normalized = values / max_wind                # -1 .. 1
    encoded = (normalized * 0.5 + 0.5) * 255.0    # 0 .. 255
    # Round instead of truncating: astype() alone floors positive values,
    # biasing every pixel low by half a quantization step on average.
    return np.rint(np.clip(encoded, 0, 255)).astype(np.uint8)


def encode_speed(speed, max_wind):
    """
    Encode a wind speed into 8-bit values.

    Maps 0 .. max_wind linearly onto 0..255, rounded to the nearest integer.
    Values below 0 clip to 0; values above max_wind clip to 255.

    Args:
        speed: array-like or numpy masked array of speeds, in the same units
            as ``max_wind``.
        max_wind: positive, finite speed mapped to 255.

    Returns:
        uint8 ndarray with the same shape as ``speed``. Masked or NaN entries
        are encoded as 0.

    Raises:
        ValueError: if ``max_wind`` is not a positive finite number.
    """
    # `not (0 < x < inf)` also rejects NaN.
    if not 0 < max_wind < np.inf:
        raise ValueError(f"max_wind must be positive and finite, got {max_wind!r}")
    # Turn masked cells into NaN so plain and masked arrays share one path.
    values = np.ma.asarray(speed, dtype=np.float64).filled(np.nan)
    # Missing data -> 0; np.where returns a new array, leaving the input intact.
    values = np.where(np.isnan(values), 0.0, values)
    encoded = values / max_wind * 255.0
    # Round instead of truncating to avoid a systematic half-step low bias.
    return np.rint(np.clip(encoded, 0, 255)).astype(np.uint8)

print("Encoding channels...")

# Channel layout: R = U component, G = V component, B = wind speed,
# A = constant 255 (fully opaque).
r = encode_uv(u, MAX_WIND)
g = encode_uv(v, MAX_WIND)
b = encode_speed(speed, MAX_WIND)
a = np.full_like(r, 255, dtype=np.uint8)

# Stack the four uint8 channels along a new last axis -> shape (*u.shape, 4).
rgba = np.stack([r, g, b, a], axis=-1)

print("Saving PNG...")

# Pillow infers RGBA from a uint8 array whose last axis has length 4; the
# explicit `mode=` argument to fromarray is deprecated in recent Pillow.
img = Image.fromarray(rgba)
img.save(OUTPUT_FILE)

print(f"Done. Saved: {OUTPUT_FILE}")
