Forecasting with NeuralGCM-0.7#

The same forecast as the forecast_quickstart notebook, with the flagship 0.7° deterministic model (NeuralGCM-0.7): a TL255 dynamical core with data on a 512×256 Gaussian grid — 16× the columns of the 2.8° model — and 31M parameters. On a recent GPU a 4-day forecast takes about two minutes eagerly; for longer rollouts use model.compile(state, forcings, cudagraphs=True) first.

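Needs network access for the model checkpoint (downloaded and cached on first use by pretrained.fetch_checkpoint, unless a local copy is already present) and for the ERA5 data (anonymous GCS), plus a GPU with ~6 GB of free memory.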

import pathlib

import numpy as np
import torch
import xarray

from dinosaur_torch import horizontal_interpolation
from dinosaur_torch import xarray_utils
import neuralgcm_torch as neuralgcm

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

Load the model#

As always in neuralgcm-torch, the original checkpoint is converted once into a plain torch.save file and cached:

from neuralgcm_torch import pretrained

# The checkpoint is downloaded from the Hugging Face Hub the first time it is
# requested and cached afterwards; an existing copy under the local
# checkpoints/ directory is reused instead.
converted_path = pretrained.fetch_checkpoint(
    'deterministic_0_7_deg',
    local_root='checkpoints',
)
model = neuralgcm.PressureLevelModel.from_checkpoint(
    converted_path,
    device=device,
)

# Display the grid size (longitudes, latitudes) and the total number of
# parameters; the bare tuple is the cell's displayed value.
n_parameters = sum(p.numel() for p in model.model.parameters())
(model.longitudes.size, model.latitudes.size, n_parameters)
(512, 256, 31123612)

Load and regrid ERA5#

# Public ARCO-ERA5 store, opened anonymously. chunks=None is passed to
# open_zarr unchanged from the original cell.
era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
full_era5 = xarray.open_zarr(
    era5_path,
    chunks=None,
    storage_options={'token': 'anon'},
)

demo_start_time = '2020-02-14'
demo_end_time = '2020-02-18'
data_inner_steps = 24  # keep every 24th hourly record

# Keep only the variables the model consumes (inputs + forcings).
required_variables = model.input_variables + model.forcing_variables

# Shift only the forcing variables by 24 hours relative to the inputs.
# NOTE(review): the shift direction is defined by selective_temporal_shift --
# confirm against that helper.
shifted_era5 = xarray_utils.selective_temporal_shift(
    full_era5[required_variables],
    variables=model.forcing_variables,
    time_shift='24 hours',
)

# Subsample the demo window, then load the selection into memory.
demo_window = slice(demo_start_time, demo_end_time, data_inner_steps)
sliced_era5 = shifted_era5.sel(time=demo_window).compute()

# Regrid from the ERA5 grid onto the model's data grid, then fill any
# remaining NaNs from the nearest valid values.
era5_grid = xarray_utils.grid_spec_from_dataset(full_era5)
regridder = horizontal_interpolation.ConservativeRegridder(
    era5_grid,
    model.data_grid,
    skipna=True,
    device=device,
)
eval_era5 = xarray_utils.fill_nan_with_nearest(
    xarray_utils.regrid_horizontal(sliced_era5, regridder)
)

Make the forecast#

inner_steps = 24  # hours between saved model outputs
outer_steps = 4 * 24 // inner_steps  # number of saved outputs (4 daily snapshots)
timedelta = np.timedelta64(inner_steps, 'h')
# Lead time of each saved output in hours. The first entry is 0 because
# unroll() is called with start_with_input=True below, so the saved leads
# are 0, 24, 48 and 72 h.
times = inner_steps * np.arange(outer_steps)

# Encode the first ERA5 snapshot into the model's initial state.
first_snapshot = eval_era5.isel(time=0)
inputs = model.inputs_from_xarray(first_snapshot)
input_forcings = model.forcings_from_xarray(first_snapshot)
initial_state = model.encode(inputs, input_forcings, rng=42)

# use persistence for forcing variables (SST and sea ice cover): only the
# first time step of the forcings is handed to the rollout.
all_forcings = model.forcings_from_xarray(eval_era5.head(time=1))

final_state, predictions = model.unroll(
    initial_state,
    all_forcings,
    steps=outer_steps,
    timedelta=timedelta,
    start_with_input=True,
)
predictions_ds = model.data_to_xarray(predictions, times=times)

Compare forecast to ERA5#

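At 0.7° the forecast resolves much finer structures than the 2.8° model of the quickstart — compare the specific-humidity fields below with the quickstart’s. Skill at the last saved snapshot (the fourth daily output, i.e. a 72-hour lead, because the first output is the initial condition; labelled “day-4” in the printout) is measured as area-weighted T850 RMSE against ERA5 and compared with a persistence baseline: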

def _weighted_rmse(candidate, reference, weights):
    """Return the weighted root-mean-square difference as a Python float."""
    squared_error = (candidate - reference) ** 2
    return float(np.sqrt(squared_error.weighted(weights).mean()))


# ERA5 at the same times as the saved forecast outputs.
era5_at_output_times = (
    eval_era5
    .thin(time=(inner_steps // data_inner_steps))
    .isel(time=slice(outer_steps))
)
target_trajectory = model.inputs_from_xarray(era5_at_output_times)
target_data_ds = model.data_to_xarray(target_trajectory, times=times)

# Stack truth and forecast along a new 'model' dimension.
combined_ds = xarray.concat([target_data_ds, predictions_ds], 'model')
combined_ds.coords['model'] = ['ERA5', 'NeuralGCM-0.7']

# cos(latitude) weights for the area-weighted mean.
w = np.cos(np.deg2rad(combined_ds.latitude))

# 850 hPa temperature at the last saved snapshot.
# NOTE: `times` starts at 0 h, so time=-1 is the fourth daily output,
# i.e. a lead of (outer_steps - 1) * inner_steps = 72 h.
t850 = combined_ds.sel(level=850).temperature.isel(time=-1)
era5_t850_final = t850.sel(model='ERA5')
rmse = _weighted_rmse(t850.sel(model='NeuralGCM-0.7'), era5_t850_final, w)

# Persistence baseline: the initial ERA5 field held fixed as the "forecast".
era5_t850_initial = (
    combined_ds.sel(model='ERA5', level=850).temperature.isel(time=0)
)
persistence = _weighted_rmse(era5_t850_initial, era5_t850_final, w)

print(f'day-4 T850 RMSE: forecast {rmse:.2f} K, persistence {persistence:.2f} K')
day-4 T850 RMSE: forecast 1.00 K, persistence 4.28 K
# Facet grid of 850 hPa specific humidity: one row per saved time, one column
# per source (ERA5 vs. forecast). The trailing semicolon suppresses the
# notebook's text repr of the returned plot object.
humidity_850 = combined_ds.specific_humidity.sel(level=850)
humidity_850.plot(
    x='longitude',
    y='latitude',
    row='time',
    col='model',
    robust=True,
    aspect=2,
    size=2,
);
_images/b18e44379cbe892c14fccf48e94894792661b00a03bac39ebbfbd9275f2337c3.png