ERA5 Forecast Inference#

In this notebook, we will use the model trained on ERA5 data from January 2018 to forecast ERA5 data for January 2019. We will combine the tools we have learned so far in this training: xarray and XGBoost.

Import libraries#

import numpy as np
import xarray as xr
import xgboost as xgb
from matplotlib import pyplot as plt

Define constants#

# Location of the ERA5 Zarr store on Google Cloud Storage.
ZARR_URL = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"

# Saved XGBoost model that is loaded for inference.
MODEL_PATH = "../chapter4/era5_forecast/models/xgb_model_old.json"

# Lag offsets (in steps of the daily series) used as input features: 1..6.
LAGS = [*range(1, 7)]

# Forecast window, plus an earlier data start so that the first forecast day
# already has every lag available.
FORECAST_START = "2019-01-01"
FORECAST_END = "2019-01-31"
START_DATE = "2018-12-26"  # six days before FORECAST_START, one per lag

Initialize the model#

# Create the booster and load the saved model from MODEL_PATH in one step.
booster = xgb.Booster(model_file=MODEL_PATH)

Function to create lagged features#

def create_lagged_features(da, lags):
    """Build a dataset of lagged copies of ``da`` plus the unshifted target.

    Args:
        da: Array with a ``time`` dimension.
        lags: Iterable of shift sizes, in steps along ``time``.

    Returns:
        The result of ``xr.merge`` over one variable per lag, named
        ``lag_<n>`` and holding ``da`` shifted by ``n`` steps along ``time``,
        followed by ``da`` itself renamed to ``target``.
    """
    variables = []
    for lag in lags:
        shifted = da.shift(time=lag)
        variables.append(shifted.rename(f"lag_{lag}"))
    # The unshifted series goes last, as the value to be predicted.
    variables.append(da.rename("target"))
    return xr.merge(variables)

Prepare the data#

# Open the ERA5 store with an anonymous token, one time step per chunk.
ds = xr.open_zarr(
    ZARR_URL, chunks={"time": 1}, storage_options={"token": "anon"}  # type: ignore
)

# Keep 2 m temperature from START_DATE (lag warm-up) through FORECAST_END.
window = slice(START_DATE, FORECAST_END)
temp = ds["2m_temperature"].sel(time=window)

# Average into daily means and materialise the result with compute().
daily = temp.resample(time="1D").mean().compute()

Perform inference#

# One-step-ahead forecast for every day in the forecast window: the features
# for a day are the daily means of the preceding days listed in LAGS.
predictions = []
dates = np.array(daily.sel(time=slice(FORECAST_START, FORECAST_END)).time)

# Loop invariants, computed once rather than on every iteration.
time_values = daily.time.values
max_lag = max(LAGS)
grid_shape = (daily.sizes["latitude"], daily.sizes["longitude"])
grid_coords = {"latitude": daily.latitude, "longitude": daily.longitude}

for date in dates:
    # Integer position of this day along the time axis of `daily`.
    i = int(np.flatnonzero(time_values == date)[0])
    if i < max_lag:
        continue  # skip if not enough lags

    # Feature columns follow the order of LAGS (lag 1 first).
    lagged_stack = [daily.isel(time=i - lag) for lag in LAGS]
    X_pred = xr.concat(lagged_stack, dim="feature")
    # Flatten the grid so that every grid point is one sample: (sample, feature).
    X_pred = X_pred.stack(sample=("latitude", "longitude")).transpose(
        "sample", "feature"
    )
    dmatrix = xgb.DMatrix(X_pred.data)
    y_pred = booster.predict(dmatrix)

    # Reshape to spatial grid. `sample` was stacked as (latitude, longitude),
    # so reshaping to (n_latitude, n_longitude) restores the grid layout.
    pred_grid = y_pred.reshape(grid_shape)
    da = xr.DataArray(
        pred_grid,
        coords=grid_coords,
        dims=["latitude", "longitude"],
        name="prediction",
    )
    # Tag the grid with its forecast date so the days can be concatenated later.
    da = da.expand_dims(time=[date])
    predictions.append(da)

Compute error metrics#

# Combine the per-day prediction grids and compare them with the observed
# daily means over the same forecast window.
pred_stack = xr.concat(predictions, dim="time")
truth = daily.sel(time=slice(FORECAST_START, FORECAST_END))
error = pred_stack - truth

# Both metrics are averaged over the grid, leaving one value per day.
spatial_dims = ["latitude", "longitude"]
# Mean absolute error: take |error| BEFORE averaging. Averaging first lets
# positive and negative errors cancel and gives the magnitude of the bias.
mae = np.abs(error).mean(dim=spatial_dims)
rmse = np.sqrt((error**2).mean(dim=spatial_dims))

Plot the results#

# Side-by-side maps for a single forecast day: observed field, predicted
# field, and their difference.
example_date = "2019-01-20"
truth_day, pred_day, error_day = (
    field.sel(time=example_date) for field in (truth, pred_stack, error)
)

fig, axs = plt.subplots(1, 3, figsize=(18, 5))

# Truth and prediction share fixed colour limits so the two maps are directly
# comparable; the error map keeps the limits chosen by the plotting call.
temperature_range = {"vmin": 200, "vmax": 320}
panels = (
    ("Ground Truth", truth_day, temperature_range),
    ("Prediction", pred_day, temperature_range),
    ("Error", error_day, {}),
)
for ax, (label, field, limits) in zip(axs, panels):
    field.plot(ax=ax, cmap="coolwarm", **limits)
    ax.set_title(f"{label}: {example_date}")

plt.tight_layout()
plt.show()
../_images/cc56e0fe9b528bac0aa7d99e7ddf1f0c10aa367693ea001595a4d95c1dc02707.png
# Plot the `mae` and `rmse` series (one value per forecast day) together.
# The first call passes figsize, so it sets up the figure; the second call is
# expected to draw onto the same axes so both curves share one plot.
mae.plot(label="MAE", figsize=(10, 4))
rmse.plot(label="RMSE")
# The legend uses the labels given above.
plt.legend()
plt.title("Forecast Error Over Time")
Text(0.5, 1.0, 'Forecast Error Over Time')
../_images/1d7f5b43bd6380f48c09badfcdaacbf0cb3897bb532002dbadfa64749a1a9f2f.png