Forecasting with NeuralGCM-0.7#
This notebook runs the same forecast as the
forecast_quickstart notebook, but with the
flagship 0.7° deterministic model (NeuralGCM-0.7): a TL255 dynamical
core with data on a 512×256 Gaussian grid — 16× the columns of the 2.8°
model — and 31M parameters. On a recent GPU a 4-day forecast takes about
two minutes in eager mode; for longer rollouts, call
model.compile(state, forcings, cudagraphs=True) first.
Needs network access for the model checkpoint (fetched from the Hugging Face Hub on first use, then cached) and for the ERA5 data (anonymous GCS), plus a GPU with ~6 GB of free memory.
import pathlib
import numpy as np
import torch
import xarray
from dinosaur_torch import horizontal_interpolation
from dinosaur_torch import xarray_utils
import neuralgcm_torch as neuralgcm
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
Load the model#
As always in neuralgcm-torch, the original checkpoint is converted
once into a plain torch.save file and cached:
from neuralgcm_torch import pretrained

# The checkpoint is fetched from the Hugging Face Hub on first use and then
# cached; a copy already present under the local checkpoints/ directory is
# reused instead.
converted_path = pretrained.fetch_checkpoint(
    'deterministic_0_7_deg',
    local_root='checkpoints',
)
model = neuralgcm.PressureLevelModel.from_checkpoint(
    converted_path,
    device=device,
)

# Notebook display: (number of longitudes, number of latitudes, total number
# of elements across all model parameters).
(
    model.longitudes.size,
    model.latitudes.size,
    sum(param.numel() for param in model.model.parameters()),
)
(512, 256, 31123612)
Load and regrid ERA5#
# ERA5 Zarr store on GCS, opened with anonymous credentials and without dask
# chunking (chunks=None).
era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
full_era5 = xarray.open_zarr(
    era5_path,
    chunks=None,
    storage_options={'token': 'anon'},
)

demo_start_time = '2020-02-14'
demo_end_time = '2020-02-18'
data_inner_steps = 24  # process every 24th hour

# Keep only the variables the model reads, apply the 24-hour shift to the
# forcing variables, subsample the time axis, and load the result.
# NOTE(review): the reason for the 24-hour forcing shift is not visible in
# this notebook -- see xarray_utils.selective_temporal_shift.
sliced_era5 = xarray_utils.selective_temporal_shift(
    full_era5[model.input_variables + model.forcing_variables],
    variables=model.forcing_variables,
    time_shift='24 hours',
)
sliced_era5 = sliced_era5.sel(
    time=slice(demo_start_time, demo_end_time, data_inner_steps)
)
sliced_era5 = sliced_era5.compute()

# Regrid from the ERA5 grid onto the model's data grid, then fill NaNs with
# fill_nan_with_nearest.
era5_grid = xarray_utils.grid_spec_from_dataset(full_era5)
regridder = horizontal_interpolation.ConservativeRegridder(
    era5_grid,
    model.data_grid,
    skipna=True,
    device=device,
)
eval_era5 = xarray_utils.fill_nan_with_nearest(
    xarray_utils.regrid_horizontal(sliced_era5, regridder)
)
Make the forecast#
inner_steps = 24  # save model outputs once every 24 hours
outer_steps = 4 * 24 // inner_steps  # total of 4 days
timedelta = np.timedelta64(inner_steps, 'h')  # interval between saved outputs
times = np.arange(outer_steps) * inner_steps  # time axis in hours

# Encode the first ERA5 snapshot (inputs + forcings) into the initial state.
first_snapshot = eval_era5.isel(time=0)
inputs = model.inputs_from_xarray(first_snapshot)
input_forcings = model.forcings_from_xarray(first_snapshot)
initial_state = model.encode(inputs, input_forcings, rng=42)

# use persistence for forcing variables (SST and sea ice cover): only the
# first time step of the forcings is handed to the rollout.
all_forcings = model.forcings_from_xarray(eval_era5.head(time=1))

# Roll the model forward, saving outer_steps snapshots spaced `timedelta`
# apart, then convert the predictions to an xarray dataset on the hourly
# `times` axis.
final_state, predictions = model.unroll(
    initial_state,
    all_forcings,
    steps=outer_steps,
    timedelta=timedelta,
    start_with_input=True,
)
predictions_ds = model.data_to_xarray(predictions, times=times)
Compare forecast to ERA5#
At 0.7° the forecast resolves much finer structures than the 2.8° model of the quickstart — compare the specific-humidity fields below with the quickstart’s. Day-4 skill (area-weighted T850 RMSE) against ERA5:
def _weighted_rmse(field, reference, weights):
    """Return sqrt of the `weights`-weighted mean of (field - reference)**2 as a float."""
    squared_error = (field - reference) ** 2
    return float(np.sqrt(squared_error.weighted(weights).mean()))


# ERA5 on the same time axis as the forecast: thin to the output cadence and
# keep the first outer_steps snapshots.
target_trajectory = model.inputs_from_xarray(
    eval_era5
    .thin(time=(inner_steps // data_inner_steps))
    .isel(time=slice(outer_steps))
)
target_data_ds = model.data_to_xarray(target_trajectory, times=times)

# Stack truth and forecast along a new, labelled 'model' dimension.
combined_ds = xarray.concat([target_data_ds, predictions_ds], dim='model')
combined_ds.coords['model'] = ['ERA5', 'NeuralGCM-0.7']

# cos(latitude) weights for the area-weighted mean.
w = np.cos(np.deg2rad(combined_ds.latitude))

# Temperature at level 850 at the last saved time step.
# NOTE(review): times[-1] is (outer_steps - 1) * inner_steps hours, i.e. 72 h
# with the settings above; confirm that the "day-4" label below is intended.
t850_all_times = combined_ds.sel(level=850).temperature
t850 = t850_all_times.isel(time=-1)
era5_final = t850.sel(model='ERA5')

# Forecast error versus the error of simply persisting the initial ERA5 field.
rmse = _weighted_rmse(t850.sel(model='NeuralGCM-0.7'), era5_final, w)
persistence = _weighted_rmse(
    t850_all_times.sel(model='ERA5').isel(time=0), era5_final, w
)
print(f'day-4 T850 RMSE: forecast {rmse:.2f} K, persistence {persistence:.2f} K')
day-4 T850 RMSE: forecast 1.00 K, persistence 4.28 K
# Facet grid of specific humidity at level 850: one row per saved time step,
# one column per 'model' entry. The trailing semicolon suppresses the
# notebook's echo of the returned plot object.
q850 = combined_ds.specific_humidity.sel(level=850)
q850.plot(
    x='longitude',
    y='latitude',
    row='time',
    col='model',
    robust=True,
    aspect=2,
    size=2,
);