Deep-dive into trained models

Recall the high-level structure of NeuralGCM models, shown in Figure 1 of the NeuralGCM paper.

In neuralgcm-torch this structure is a plain torch.nn.Module tree: a StochasticModularStepModel composed of an encoder, a decoder, an advance step (dycore corrector + neural physics + stochastic fields) and a forcing module. This notebook inspects each of those pieces. The small checkpoint is fetched on first use, unless a local checkpoints/ directory already holds it. Apart from that, everything runs from files packaged with the repository.
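To make the encoder → advance → decoder composition concrete, here is a toy module. ToyModularStepModel is hypothetical and not part of neuralgcm-torch. It reuses the real model's child names, but each component is a single linear layer. The forcing module is omitted, and the real advance step is a dycore plus learned physics, not one linear map.

```python
import torch
from torch import nn


class ToyModularStepModel(nn.Module):
  """Hypothetical stand-in with the encoder/advance/decoder layout."""

  def __init__(self, n_obs: int = 8, n_state: int = 4):
    super().__init__()
    self.encoder = nn.Linear(n_obs, n_state)           # data -> model state
    self.advance_module = nn.Linear(n_state, n_state)  # one time step
    self.decoder = nn.Linear(n_state, n_obs)           # model state -> data

  def forward(self, obs: torch.Tensor, steps: int) -> torch.Tensor:
    state = self.encoder(obs)
    for _ in range(steps):
      state = state + self.advance_module(state)  # residual update per step
    return self.decoder(state)


toy = ToyModularStepModel()
out = toy(torch.randn(2, 8), steps=3)
print([name for name, _ in toy.named_children()])  # ['encoder', 'advance_module', 'decoder']
print(tuple(out.shape))  # (2, 8)
```

Because the pieces are registered submodules, named_children() enumerates them in assignment order. The same call is used on the real model below.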

Loading a model

Pre-trained models are loaded from converted checkpoints with PressureLevelModel.from_checkpoint, as shown in the forecast_quickstart notebook. Here we use the very small (toy) TL63 stochastic model, which pretrained.fetch_checkpoint retrieves and caches on first use:

import dataclasses

import matplotlib.pyplot as plt
import numpy as np
import torch
import xarray

from dinosaur_torch import horizontal_interpolation
from dinosaur_torch import spherical_harmonic
from dinosaur_torch import xarray_utils
import neuralgcm_torch as neuralgcm
from neuralgcm_torch import pretrained

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fetched from the Hugging Face Hub on first use (cached), or reused from
# a local checkpoints/ directory if present.
converted_path = pretrained.fetch_checkpoint('tl63_stochastic_mini', local_root='checkpoints')
model = neuralgcm.PressureLevelModel.from_checkpoint(converted_path, device=device)

The data_preparation notebook describes how to prepare data in detail. Here we’ll regrid the single ERA5 snapshot packaged with the repository (on a TL31 grid) onto the model's data grid, model.data_grid:

import importlib.resources

with importlib.resources.files('neuralgcm_torch').joinpath(
    'data/era5_tl31_19590102T00.nc'
).open('rb') as f:
  ds = xarray.load_dataset(f).expand_dims('time')

regridder = horizontal_interpolation.ConservativeRegridder(
    spherical_harmonic.GridSpec.TL31(), model.data_grid, device=device
)
ds = xarray_utils.regrid_horizontal(ds, regridder)
inputs, forcings = model.data_from_xarray(ds)
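Going by its name, ConservativeRegridder performs conservative remapping. Each target cell receives the overlap-weighted average of the source cells it intersects, so area-weighted integrals are preserved. The NumPy sketch below shows the idea in one dimension. conservative_weights is a hypothetical helper; the real regridder works on the sphere, where cell areas depend on latitude.

```python
import numpy as np


def conservative_weights(src_edges, dst_edges):
  """Weights W such that dst = W @ src preserves the length-weighted integral.

  Hypothetical 1-D helper: W[i, j] is the overlap length of target cell i and
  source cell j, divided by the length of target cell i.
  """
  lo = np.maximum(dst_edges[:-1, None], src_edges[None, :-1])
  hi = np.minimum(dst_edges[1:, None], src_edges[None, 1:])
  overlap = np.clip(hi - lo, 0.0, None)
  return overlap / np.diff(dst_edges)[:, None]


src_edges = np.linspace(0.0, 1.0, 7)   # 6 source cells
dst_edges = np.linspace(0.0, 1.0, 4)   # 3 coarser target cells
src = np.array([1.0, 3.0, 2.0, 4.0, 0.0, 5.0])

dst = conservative_weights(src_edges, dst_edges) @ src
print(dst)
# Both integrals agree: the remapping conserves the total.
print(np.sum(src * np.diff(src_edges)), np.sum(dst * np.diff(dst_edges)))
```

This example coarsens the grid. The same weights also work for refinement, as in the TL31-to-model-grid step above.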

Encoding & decoding data

encode transforms input data (on pressure levels) into model variables (on sigma levels). To use it, pass in dictionaries of input and forcing data and, for stochastic models, an integer seed:

encoded = model.encode(inputs, forcings, rng=0)

decode can be used to convert back from model levels to pressure levels:

decoded = model.decode(encoded, forcings)

Encoding/decoding is a lossy process, because NeuralGCM’s encoded model state uses a different coordinate system (sigma levels rather than pressure levels). For this model, it introduces a mean absolute temperature error of roughly 1 K (1.25 K here) at the last, near-surface pressure level:

float(abs(inputs['temperature'][0, -1] - decoded['temperature'][-1]).mean())
1.2497444152832031
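Some of this error can come from the change of vertical coordinate alone. Sigma levels follow the terrain (p = sigma * p_surface). A pressure level below the lowest sigma level, or below ground, therefore has to be extrapolated on the way back. The NumPy sketch below illustrates this with a made-up temperature column and plain linear interpolation. NeuralGCM's encoder and decoder are learned networks, so the sketch shows only the coordinate effect, not the model's actual error.

```python
import numpy as np

# Made-up column: 37 pressure levels (hPa) and a smooth temperature profile (K).
pressure = np.linspace(1.0, 1000.0, 37)
temperature = 288.0 * (pressure / 1000.0) ** 0.19

# 32 sigma-level centres; sigma = p / p_surface, so they move with the surface.
sigma = (np.arange(32) + 0.5) / 32
surface_pressure = 985.0  # hPa, so the 1000 hPa level lies below ground
sigma_pressure = sigma * surface_pressure

# Pressure -> sigma -> pressure with linear interpolation. np.interp holds the
# end values constant outside the sampled range (a crude extrapolation).
t_sigma = np.interp(sigma_pressure, pressure, temperature)
t_back = np.interp(pressure, sigma_pressure, t_sigma)

abs_error = np.abs(t_back - temperature)
print(f'~500 hPa error: {abs_error[18]:.4f} K')
print(f'1000 hPa error: {abs_error[-1]:.4f} K')
```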

Inside the model

The wrapped model, model.model, is an ordinary torch.nn.Module. Its weights are registered nn.Parameters, not a separate parameter tree kept on the side:

for name, child in model.model.named_children():
  n_params = sum(p.numel() for p in child.parameters())
  print(f'{name:16s} {type(child).__name__:55s} {n_params:>10,} params')
print(f'{"total":74s} {sum(p.numel() for p in model.model.parameters()):>10,} params')
encoder          DimensionalLearnedWeatherbenchToPrimitiveEncoder            56,628 params
decoder          DimensionalLearnedPrimitiveToWeatherbenchDecoder            58,016 params
advance_module   StochasticPhysicsParameterizationStep                       76,530 params
forcing_module   DynamicDataForcing                                               0 params
total                                                                         191,174 params

Standard PyTorch tooling applies: state_dict() for saving/loading, named_parameters() for optimizer param groups, hooks, etc. The advance step is itself composed of a dycore corrector, the neural physics parameterization and the stochastic field module:

for name, child in model.model.advance_module.named_children():
  print(f'{name:28s} {type(child).__name__}')
corrector                    CustomCoordsCorrector
physics_parameterization     DivCurlNeuralParameterization
randomness_module            BatchGaussianRandomFieldModule
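The standard PyTorch tooling mentioned above can be tried on a stand-in. The sketch below uses a hypothetical nn.ModuleDict with the same top-level child names as model.model. It shows a state_dict() round trip, optimizer parameter groups selected by submodule name, and a forward hook. The learning rates are arbitrary placeholders, not recommended values.

```python
import io

import torch
from torch import nn


def make_toy():
  # Hypothetical stand-in with the same top-level child names as model.model.
  return nn.ModuleDict({
      'encoder': nn.Linear(8, 4),
      'advance_module': nn.Linear(4, 4),
      'decoder': nn.Linear(4, 8),
  })


toy = make_toy()

# state_dict() round trip through an in-memory buffer.
buffer = io.BytesIO()
torch.save(toy.state_dict(), buffer)
buffer.seek(0)
restored = make_toy()
restored.load_state_dict(torch.load(buffer))

# Parameter groups keyed on submodule names, e.g. a smaller learning rate
# for the time-stepping component.
advance_params = [p for n, p in toy.named_parameters() if n.startswith('advance_module.')]
other_params = [p for n, p in toy.named_parameters() if not n.startswith('advance_module.')]
optimizer = torch.optim.Adam([
    {'params': advance_params, 'lr': 1e-5},
    {'params': other_params, 'lr': 1e-4},
])

# A forward hook that records the output shape of one submodule.
seen_shapes = []
handle = toy['advance_module'].register_forward_hook(
    lambda module, args, output: seen_shapes.append(tuple(output.shape))
)
toy['advance_module'](torch.randn(2, 4))
handle.remove()
print(seen_shapes)  # [(2, 4)]
```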

Advancing in time

advance and unroll step the encoded model state forward in time, using a combination of the NeuralGCM dycore and learned physics. advance takes a single time-step of size model.timestep:

assert model.timestep == np.timedelta64(3600, 's')
advanced = model.advance(encoded, forcings)

unroll is the higher-level method for stepping forward multiple steps at once, decoding outputs at a given time interval:

advanced, outputs = model.unroll(
    encoded, forcings, steps=4, timedelta=np.timedelta64(1, 'h')
)
{k: tuple(v.shape) for k, v in outputs.items()}
{'u_component_of_wind': (4, 37, 128, 64),
 'v_component_of_wind': (4, 37, 128, 64),
 'temperature': (4, 37, 128, 64),
 'geopotential': (4, 37, 128, 64),
 'sim_time': (4,),
 'specific_humidity': (4, 37, 128, 64),
 'specific_cloud_ice_water_content': (4, 37, 128, 64),
 'specific_cloud_liquid_water_content': (4, 37, 128, 64)}
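Conceptually, unroll repeats advance and decodes at the requested output interval. The pure-Python sketch below shows that pattern. It assumes an output is recorded after each block of timedelta / timestep inner steps; the library's actual ordering (for example, whether the initial state is also decoded) may differ. toy_unroll, toy_advance and toy_decode are hypothetical stand-ins.

```python
def toy_unroll(state, advance, decode, steps, inner_steps):
  """Advance `inner_steps` times between outputs; collect `steps` outputs."""
  outputs = []
  for _ in range(steps):
    for _ in range(inner_steps):
      state = advance(state)
    outputs.append(decode(state))
  return state, outputs


def toy_advance(state):
  return state + 1  # stand-in for one model time-step


def toy_decode(state):
  return {'value': state}


# steps=4 with timedelta == timestep, as in the cell above.
final, outs = toy_unroll(0, toy_advance, toy_decode, steps=4, inner_steps=1)
print(final, [o['value'] for o in outs])  # 4 [1, 2, 3, 4]

# Feeding the final state back in continues the same trajectory.
final2, outs2 = toy_unroll(final, toy_advance, toy_decode, steps=2, inner_steps=3)
print(final2, [o['value'] for o in outs2])  # 10 [7, 10]
```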

At each model time-step, the forcing nearest in time is used. This lets you supply forcings at a coarser time resolution than the model timestep (here a single snapshot, i.e., persistent forcing). The returned advanced state is the encoded state after all of the requested time-steps; feeding it back into unroll continues the trajectory.
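Nearest-in-time selection can be sketched with NumPy, using made-up daily forcing values and a hypothetical nearest_forcing helper. The library's own lookup, including how it breaks ties and handles times outside the forcing range, may differ. With a single snapshot, as in this notebook, every step gets the same forcing.

```python
import numpy as np

# Hypothetical forcing time axis: daily snapshots, in hours since the start.
forcing_times = np.array([0.0, 24.0, 48.0])
sea_surface_temperature = np.array([290.0, 291.0, 292.5])  # made-up values (K)


def nearest_forcing(sim_time, times, values):
  """Return the forcing value whose timestamp is closest to sim_time."""
  index = int(np.argmin(np.abs(times - sim_time)))
  return values[index]


# Hourly model steps pick up whichever daily snapshot is nearest in time.
for hour in [0.0, 11.0, 13.0, 30.0, 100.0]:
  print(hour, nearest_forcing(hour, forcing_times, sea_surface_temperature))
```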

For long rollouts, see model.compile(...) (optionally with cudagraphs=True) in the README — it makes each step several times faster after a one-time compilation cost.

Autograd and fine-tuning

The entire model — encoder, neural physics, dynamical core, decoder — is differentiable end to end with plain autograd. As a demonstration, we compute a rollout loss against a persistence target and look at the gradient magnitudes reaching each component:

from neuralgcm_torch import training

dt = float(model.to_nondim_units(
    model.timestep / np.timedelta64(1, 's'), 's'
))
targets = {k: v.clone() for k, v in inputs.items()}
targets['sim_time'] = inputs['sim_time'] + dt  # persistence, one step later

loss = training.rollout_loss(model, inputs, forcings, targets, rng=0)
loss.backward()
for name, child in model.model.named_children():
  norms = [p.grad.norm() for p in child.parameters() if p.grad is not None]
  total = float(torch.stack(norms).norm()) if norms else 0.0
  print(f'{name:16s} grad norm {total:.3e}')
model.model.zero_grad(set_to_none=True)
encoder          grad norm 3.378e+04
decoder          grad norm 4.054e+04
advance_module   grad norm 3.666e+03
forcing_module   grad norm 0.000e+00

For an actual training loop, data.TrajectoryDataset windows an ERA5-style dataset into (inputs, forcings, targets) examples, and training.train_step performs torch.optim updates; see the README. Memory scales with the number of advance steps kept on the autodiff tape, so rollout training uses short rollouts (1-3 output frames).
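The memory point is easier to see in a generic PyTorch sketch of rollout training. Here step, rollout_loss and the persistence targets are hypothetical stand-ins, not data.TrajectoryDataset or training.train_step. Each additional target frame adds one more step's activations to the autograd graph, and backward() frees them only after traversing the whole rollout.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical one-step model on a 4-dimensional state, standing in for advance.
step = nn.Linear(4, 4)
optimizer = torch.optim.Adam(step.parameters(), lr=1e-2)

state0 = torch.randn(16, 4)                     # batch of initial states
targets = [state0.clone() for _ in range(3)]    # persistence targets, 3 frames


def rollout_loss(initial, frames):
  """Unroll one step per target frame; every step stays on the autograd tape."""
  state, loss = initial, 0.0
  for target in frames:
    state = state + step(state)
    loss = loss + torch.mean((state - target) ** 2)
  return loss / len(frames)


losses = []
for _ in range(50):
  optimizer.zero_grad()
  loss = rollout_loss(state0, targets)
  loss.backward()   # backpropagates through all unrolled steps
  optimizer.step()
  losses.append(loss.item())
print(f'{losses[0]:.4f} -> {losses[-1]:.4f}')
```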

Encoded model state

Warning: there are no stable API guarantees for encoded model state. Expect different models (especially future versions of NeuralGCM) to store different collections of data in different formats.

Model state is a nested dataclass registered as a torch pytree; mapping shapes over it gives a compact view of its structure:

torch.utils._pytree.tree_map(
    lambda x: tuple(x.shape) if hasattr(x, 'shape') else x, encoded
)
ModelState(state=State(vorticity=(32, 128, 65), divergence=(32, 128, 65), temperature_variation=(32, 128, 65), log_surface_pressure=(1, 128, 65), tracers={'specific_humidity': (32, 128, 65), 'specific_cloud_ice_water_content': (32, 128, 65), 'specific_cloud_liquid_water_content': (32, 128, 65)}, sim_time=()), memory=None, diagnostics={}, randomness=RandomnessState(core=(10, 128, 65), nodal_value=(10, 128, 64), modal_value=(10, 128, 65), prng_key=6220207618836018757, prng_step=0))
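Registering the dataclasses as torch pytrees is what lets tree_map see inside them. The traversal itself can be sketched in plain Python. tree_map_shapes and the Toy* classes are hypothetical; they mimic the printed layout, not the library's State or ModelState definitions.

```python
import dataclasses

import numpy as np


def tree_map_shapes(node):
  """Hypothetical helper: replace every array leaf in a nested structure by its shape."""
  if hasattr(node, 'shape'):
    return tuple(node.shape)
  if dataclasses.is_dataclass(node) and not isinstance(node, type):
    return dataclasses.replace(
        node, **{f.name: tree_map_shapes(getattr(node, f.name))
                 for f in dataclasses.fields(node)})
  if isinstance(node, dict):
    return {k: tree_map_shapes(v) for k, v in node.items()}
  return node  # non-array leaves (None, ints) pass through unchanged


@dataclasses.dataclass
class ToyState:  # hypothetical, loosely modelled on the State printed above
  vorticity: object
  tracers: dict


@dataclasses.dataclass
class ToyModelState:
  state: ToyState
  memory: object = None
  prng_key: int = 0


toy = ToyModelState(
    state=ToyState(vorticity=np.zeros((32, 128, 65)),
                   tracers={'specific_humidity': np.zeros((32, 128, 65))}),
    prng_key=42,
)
print(tree_map_shapes(toy))
```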

Internally, model state is represented using a different set of variables (vorticity, divergence, temperature variation and log surface pressure) that suits NeuralGCM’s spectral dynamical core better. Each variable is stored as spherical harmonic coefficients on sigma levels. To get winds on the nodal (longitude/latitude) grid, convert vorticity and divergence to u and v velocities:

coords = model.model.decoder.coords
nodal_u, nodal_v = coords.horizontal.vor_div_to_uv_nodal(
    encoded.state.vorticity, encoded.state.divergence
)
encoded_ds = xarray.Dataset(
    {
        'u_wind': (('level', 'longitude', 'latitude'), nodal_u.cpu().numpy()),
        'v_wind': (('level', 'longitude', 'latitude'), nodal_v.cpu().numpy()),
    },
    coords={
        'level': coords.vertical.coordinates.centers,  # sigma (NumPy)
        'longitude': model.longitudes,
        'latitude': model.latitudes,
    },
)
encoded_ds.u_wind.sel(level=[0.1, 0.5, 0.9], method='nearest').plot.imshow(
    x='longitude', y='latitude', col='level', aspect=1.6, size=2.3
);
(Figure: u_wind on the sigma levels nearest 0.1, 0.5 and 0.9, plotted against longitude and latitude.)
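For reference, this conversion rests on the standard Helmholtz decomposition of the horizontal wind on a sphere of radius $a$. The wind splits into a streamfunction $\psi$ and a velocity potential $\chi$, with $\lambda$ longitude and $\phi$ latitude. The relations below are in textbook form; the library may use nondimensional or rescaled variables (such as $u\cos\phi$) internally.

```latex
\zeta = \nabla^2 \psi, \qquad \delta = \nabla^2 \chi,

u = -\frac{1}{a}\frac{\partial \psi}{\partial \phi}
    + \frac{1}{a\cos\phi}\frac{\partial \chi}{\partial \lambda},
\qquad
v = \frac{1}{a\cos\phi}\frac{\partial \psi}{\partial \lambda}
    + \frac{1}{a}\frac{\partial \chi}{\partial \phi},

\nabla^2 Y_l^m = -\frac{l(l+1)}{a^2}\, Y_l^m
\;\Longrightarrow\;
\psi_l^m = -\frac{a^2}{l(l+1)}\,\zeta_l^m,
\quad
\chi_l^m = -\frac{a^2}{l(l+1)}\,\delta_l^m \quad (l > 0).
```

The last line is why the spectral representation is convenient here. Inverting the Laplacian reduces to a per-coefficient division.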

You can also work with state in the spherical harmonic (modal) representation, which is convenient for, e.g., computing power spectra. Take care with its structural sparsity, though. Many entries of the modal arrays do not correspond to valid spherical harmonic coefficients and are structurally zero, as the plot of non-zero entries shows:

plt.matshow((encoded.state.temperature_variation[0] != 0).cpu());
(Figure: non-zero pattern of temperature_variation on the first sigma level, index 0.)

The non-zero mask can be accessed programmatically:

coords.horizontal.mask
array([[ True,  True,  True, ...,  True,  True,  True],
       [False, False, False, ..., False, False, False],
       [False,  True,  True, ...,  True,  True,  True],
       ...,
       [False, False, False, ...,  True,  True,  True],
       [False, False, False, ..., False,  True,  True],
       [False, False, False, ..., False,  True,  True]], shape=(128, 65))
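A NumPy sketch of a masked power spectrum follows. The layout is an assumption: axis 0 is taken to interleave the cosine and sine coefficients of zonal wavenumber m, and axis 1 to index total wavenumber l. This is consistent with the all-False second row printed above, but not confirmed by this notebook. The array sizes are toy values, power_spectrum is hypothetical, and basis-dependent normalization factors are omitted. The mask matters whenever an operation writes into the unused slots, for example adding noise in modal space.

```python
import numpy as np

# Assumed layout: rows are (m=0 cos, m=0 sin, m=1 cos, m=1 sin, ...),
# columns are total wavenumber l.
n_m, n_l = 4, 5
m = np.repeat(np.arange(n_m), 2)[:, None]   # m for each row: 0, 0, 1, 1, ...
l = np.arange(n_l)[None, :]
mask = l >= m                               # triangular truncation: m <= l
mask[1, :] = False                          # sine part of m = 0 is identically zero

rng = np.random.default_rng(0)
coeffs = np.where(mask, rng.normal(size=mask.shape), 0.0)
coeffs_with_junk = coeffs + np.where(mask, 0.0, 1e3)  # garbage in unused slots


def power_spectrum(modal, valid):
  """Summed squared amplitude per total wavenumber l, over valid slots only."""
  return np.sum(np.where(valid, modal, 0.0) ** 2, axis=0)


print(power_spectrum(coeffs_with_junk, mask))
```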

Random noise

NeuralGCM stochastic models use deterministic, explicit random number generation. In this PyTorch port, randomness is seeded by plain integers, which are expanded internally with splitmix64 into per-draw torch.Generators. Random streams are deterministic given the seed, and statistically (though not bitwise) equivalent to those of the original JAX models. Different seeds give different encodings:

encoded = model.encode(inputs, forcings, rng=0)
encoded2 = model.encode(inputs, forcings, rng=1)
# slightly different average temperatures
(float(encoded.state.temperature_variation.mean()),
 float(encoded2.state.temperature_variation.mean()))
(-0.003968254663050175, -0.003972785547375679)
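Because seeds are plain integers, the expansion step can be sketched with the standard splitmix64 mixing function. per_draw_seeds and draws are hypothetical: this notebook does not show exactly how the library derives and consumes per-draw seeds. NumPy generators stand in for torch.Generator to keep the sketch light.

```python
import numpy as np

MASK64 = (1 << 64) - 1


def splitmix64(state):
  """One splitmix64 step: returns (new_state, 64-bit output)."""
  state = (state + 0x9E3779B97F4A7C15) & MASK64
  z = state
  z = ((z ^ (z >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
  z = ((z ^ (z >> 27)) * 0x94D049BB133111EB) & MASK64
  return state, z ^ (z >> 31)


def per_draw_seeds(seed, n_draws):
  """Hypothetical expansion of one integer seed into a sequence of draw seeds."""
  state, seeds = seed & MASK64, []
  for _ in range(n_draws):
    state, out = splitmix64(state)
    seeds.append(out)
  return seeds


def draws(seed, n_draws=3, size=4):
  # One independent generator per draw, each seeded from the expanded sequence.
  return [np.random.default_rng(s).normal(size=size) for s in per_draw_seeds(seed, n_draws)]


a, b, c = draws(0), draws(0), draws(1)
print(all(np.array_equal(x, y) for x, y in zip(a, b)))  # same seed: identical streams
print(any(np.array_equal(x, y) for x, y in zip(a, c)))  # different seed: different streams
```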

Randomness in the learned physics is controlled via the same rng argument, which seeds the Gaussian random fields carried on the model state. You can visualize them:

xarray.DataArray(
    encoded.randomness.nodal_value.cpu().numpy(),
    dims=['field', 'longitude', 'latitude'],
).plot(x='longitude', y='latitude', col='field', col_wrap=2, aspect=2, size=2);
(Figure: the random fields in encoded.randomness.nodal_value, one panel per field, plotted against longitude and latitude.)
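A Gaussian random field is white noise smoothed to a chosen correlation length. The library builds its fields on the sphere. The NumPy sketch below uses a doubly periodic grid and an FFT instead. gaussian_random_field, its correlation_length (measured in grid cells) and the unit-variance normalization are illustrative assumptions, not the library's parameterization.

```python
import numpy as np


def gaussian_random_field(shape, correlation_length, rng):
  """Zero-mean, unit-variance field from spectrally smoothed white noise (periodic grid)."""
  white = rng.normal(size=shape)
  kx = np.fft.fftfreq(shape[0])[:, None]   # cycles per grid cell
  ky = np.fft.fftfreq(shape[1])[None, :]
  # Fourier transform of a Gaussian kernel with standard deviation correlation_length.
  filt = np.exp(-0.5 * (2 * np.pi * correlation_length) ** 2 * (kx**2 + ky**2))
  field = np.real(np.fft.ifft2(np.fft.fft2(white) * filt))
  return (field - field.mean()) / field.std()


def lag1_correlation(f):
  """Correlation between neighbouring cells along axis 0 (periodic)."""
  return np.corrcoef(f.ravel(), np.roll(f, 1, axis=0).ravel())[0, 1]


rng = np.random.default_rng(0)
rough = gaussian_random_field((128, 64), correlation_length=1.0, rng=rng)
smooth = gaussian_random_field((128, 64), correlation_length=8.0, rng=rng)
print(f'{lag1_correlation(rough):.2f} {lag1_correlation(smooth):.2f}')
```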

Adjusting random noise on an existing model state is also possible by replacing randomness.prng_key (a plain integer in this PyTorch port). The two states then diverge slowly as new noise is injected at each time step; after a single step, their mean temperatures still agree:

encoded_with_new_rng = dataclasses.replace(
    encoded,
    randomness=dataclasses.replace(encoded.randomness, prng_key=123),
)
advanced = model.advance(encoded, forcings)
advanced_with_new_rng = model.advance(encoded_with_new_rng, forcings)
(float(advanced.state.temperature_variation.mean()),
 float(advanced_with_new_rng.state.temperature_variation.mean()))
(-0.004131820518523455, -0.004131820518523455)
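The slow divergence can be pictured with an AR(1) random field, field[t+1] = phi * field[t] + sqrt(1 - phi**2) * noise[t]. The field persists from step to step while fresh noise enters. The toy below also assumes each step uses the current field before refreshing it. That ordering would be consistent with the identical means printed above, but both the AR(1) form and the ordering are assumptions for illustration, not a description of BatchGaussianRandomFieldModule. toy_trajectory is hypothetical.

```python
import numpy as np


def toy_trajectory(noise_seed, steps, phi=0.9, size=1000):
  """Toy state driven by an AR(1) random field; only the noise stream varies."""
  field = np.random.default_rng(123).normal(size=size)  # shared initial field
  rng = np.random.default_rng(noise_seed)               # per-trajectory noise
  state = np.zeros(size)
  history = []
  for _ in range(steps):
    state = state + 0.1 * field  # the step uses the current field...
    # ...then the field is refreshed with new noise (AR(1) update).
    field = phi * field + np.sqrt(1 - phi**2) * rng.normal(size=size)
    history.append(state.copy())
  return history


run_a = toy_trajectory(noise_seed=0, steps=10)
run_b = toy_trajectory(noise_seed=1, steps=10)
for t in [0, 1, 5, 9]:
  print(t + 1, float(np.abs(run_a[t] - run_b[t]).mean()))
```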