Checkpoint modifications#

Checkpoint modifications can be made to include more outputs or change certain aspects of the simulations (e.g., changing filters). Here, we demonstrate how to add an output for surface pressure and how to use a modified filter that fixes the global mean surface pressure (NeuralGCM models conserve moisture but not dry air mass; see the discussion in the NeuralGCM paper).

In the upstream package (and the legacy port), this is done by appending gin-config lines to the checkpoint’s config string. In neuralgcm-torch, converted checkpoints store the parsed config as plain data, so modifications are ordinary dict edits — no gin involved. This notebook mirrors the upstream checkpoint_modifications notebook and needs network access for the checkpoint and the ERA5 data (both anonymous GCS).

import copy
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import torch
import xarray

from dinosaur_torch import horizontal_interpolation
from dinosaur_torch import xarray_utils
import neuralgcm_torch as neuralgcm
from neuralgcm_torch import checkpoint as checkpoint_lib

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

Load and inspect the checkpoint config#

# Name of the pretrained checkpoint to load; the #@param list enumerates the
# choices offered by the notebook form.
model_name = 'deterministic_2_8_deg'  #@param ['deterministic_0_7_deg', 'deterministic_1_4_deg', 'deterministic_2_8_deg', 'stochastic_1_4_deg']

from neuralgcm_torch import pretrained
# Fetched from the Hugging Face Hub on first use (cached), or reused from
# a local checkpoints/ directory if present.
converted_path = pretrained.fetch_checkpoint(model_name, local_root='checkpoints')
# Load the converted checkpoint. Later cells read its 'config', 'params',
# 'aux_features' and 'format_version' entries.
base_checkpoint = checkpoint_lib.load(converted_path)

The anatomy of a converted checkpoint#

Before editing anything, here is the whole converted checkpoint at a glance — the high-level config (just plain Python data you can edit), the auxiliary feature arrays, and the learned-parameter hierarchy with per-subtree counts. Expand any section to drill in.

# A compact "anatomy" of the converted checkpoint, rendered as a collapsible
# HTML card: high-level stats, the (editable) `config` tree, the auxiliary
# feature arrays, and the learned-parameter hierarchy with per-subtree counts.
import html

from IPython.display import HTML


def _human(n):
    """Format a count compactly with a K/M/B/T suffix, e.g. 14520000 -> '14.52M'.

    Args:
        n: A real number (typically a parameter count).

    Returns:
        ``n`` printed without decimals when it is below 1000; otherwise ``n``
        scaled down by powers of 1000 and printed with two decimals plus the
        matching unit suffix. Anything beyond the ``B`` range is reported
        in ``T``.
    """
    for unit in ('', 'K', 'M', 'B'):
        digits = 2 if unit else 0
        # Compare the value *as it will be printed*: without rounding first,
        # 999_999 would come out as '1000.00K' instead of '1.00M'.
        if round(abs(n), digits) < 1000:
            return f'{n:.{digits}f}{unit}'
        n /= 1000
    return f'{n:.2f}T'


def _leaf(v):
    """Return a short one-line text summary of a single value.

    Tensors and arrays are described by shape (plus dtype for NumPy arrays),
    containers by type and length, floats with four significant digits, and
    strings are quoted, truncated past 48 characters and HTML-escaped. Any
    other object is summarised by its (HTML-escaped) type name.
    """
    # Array-likes: describe rather than print the contents.
    if isinstance(v, torch.Tensor):
        return f'tensor{tuple(v.shape)}'
    if isinstance(v, np.ndarray):
        return f'array{v.shape}·{v.dtype}'

    # Scalars. `bool` is a subclass of `int`, so one branch covers both.
    if isinstance(v, float):
        return format(v, '.4g')
    if isinstance(v, int):
        return str(v)

    # Containers: type (or element kind) and size only.
    if isinstance(v, (list, tuple)):
        return f'{type(v).__name__}[{len(v)}]'
    if isinstance(v, dict):
        return f'dict·{len(v)}'

    if isinstance(v, str):
        shown = v[:45] + '…' if len(v) > 48 else v
        return '"' + html.escape(shown) + '"'

    if v is None:
        return 'None'
    return html.escape(type(v).__name__)


# Inline CSS fragments shared by the HTML renderers in this cell.
# _K: style of key / section-name labels.
_K = 'color:#1d6fb8;font-weight:600'
# _V: style of rendered values (monospace, slightly smaller).
_V = 'color:#444;font-family:ui-monospace,monospace;font-size:90%'
# _C: style of the muted counts shown next to section names.
_C = 'color:#999'
# _PAD: style of the indented container that holds one nesting level.
_PAD = 'margin-left:1.1em;border-left:1px solid #e3e3e3;padding-left:.7em'


def _tree(d, open_depth=1, depth=0, max_depth=2):
    """Render a nested dict as indented, collapsible HTML.

    Non-empty dict values become ``<details>`` sections (rendered
    recursively) as long as ``depth`` is below ``max_depth``; every other
    value becomes a single row summarised by ``_leaf``. Sections whose depth
    is below ``open_depth`` start expanded.
    """
    def render(name, value):
        label = f'<span style="{_K}">{html.escape(str(name))}</span>'
        expandable = isinstance(value, dict) and value and depth < max_depth
        if not expandable:
            return f'<div>{label} <span style="{_V}">{_leaf(value)}</span></div>'
        state = ' open' if depth < open_depth else ''
        body = _tree(value, open_depth, depth + 1, max_depth)
        return (
            f'<details{state}><summary>{label} '
            f'<span style="{_C}">·{len(value)}</span></summary>'
            f'{body}</details>')

    inner = ''.join(render(k, v) for k, v in d.items())
    return f'<div style="{_PAD}">{inner}</div>'


def _params_tree(params):
    """Turn flat '/'-separated parameter paths into a nested dict.

    Path segments equal to '~' are dropped. Each bundle is stored at its leaf
    as ``{name: shape_tuple}``.

    Returns:
        ``(tree, totals)`` where ``totals`` maps every ancestor path
        (segments joined with '/') to the number of parameter elements
        found beneath it.
    """
    tree = {}
    totals = {}
    for path, bundle in params.items():
        segments = [seg for seg in path.split('/') if seg != '~']
        size = 0
        for tensor in bundle.values():
            size += int(np.prod(tensor.shape))

        # Walk down (creating levels as needed), crediting this bundle's
        # size to every ancestor on the way.
        cursor = tree
        trail = []
        for seg in segments[:-1]:
            cursor = cursor.setdefault(seg, {})
            trail.append(seg)
            ancestor = '/'.join(trail)
            totals[ancestor] = totals.get(ancestor, 0) + size

        cursor[segments[-1]] = {
            pname: tuple(tensor.shape) for pname, tensor in bundle.items()
        }
    return tree, totals


def _ptree(d, counts, prefix='', depth=0, open_depth=1):
    """Render the nested parameter tree as collapsible HTML.

    A dict whose values are all tuples is treated as a leaf bundle and shown
    as one row listing ``name`` + ``shape`` pairs. Any other dict becomes a
    ``<details>`` section annotated with the total looked up in ``counts``
    under its '/'-joined path. Sections at a depth below ``open_depth``
    start expanded.
    """
    state = ' open' if depth < open_depth else ''
    pieces = []
    for name, child in d.items():
        full_path = name if not prefix else f'{prefix}/{name}'
        label = f'<span style="{_K}">{html.escape(str(name))}</span>'
        is_leaf = not isinstance(child, dict) or all(
            isinstance(shape, tuple) for shape in child.values())
        if is_leaf:
            described = ', '.join(
                f'{pname}{shape}' for pname, shape in child.items())
            pieces.append(
                f'<div>{label} <span style="{_V}">{html.escape(described)}</span></div>')
            continue
        total = _human(counts.get(full_path, 0))
        nested = _ptree(child, counts, full_path, depth + 1, open_depth)
        pieces.append(
            f'<details{state}><summary>{label} <span style="{_C}">· '
            f'{total}</span></summary>'
            f'{nested}</details>')
    return f'<div style="{_PAD}">{"".join(pieces)}</div>'


def checkpoint_anatomy(ck, name=''):
    """A collapsible HTML overview of a converted NeuralGCM checkpoint.

    Args:
        ck: Converted checkpoint mapping. The keys read here are
            ``'config'``, ``'params'``, ``'aux_features'`` and
            ``'format_version'``.
        name: Label shown in the card's title bar.

    Returns:
        An ``IPython.display.HTML`` object: a title bar, a row of summary
        "chips", and expandable ``config``, ``aux_features`` and ``params``
        sections.
    """
    cfg = ck['config']
    # Total number of parameter elements across all bundles.
    n_par = sum(int(np.prod(t.shape))
                for b in ck['params'].values() for t in b.values())

    def grid(g):
        # e.g. "128×64 · real SH"
        return (f"{g['longitude_nodes']}×{g['latitude_nodes']} "
                f"· {g['spherical_harmonics']} SH")

    def flat_rows(mapping):
        # One non-collapsible "key value-summary" row per entry.
        return ''.join(
            f'<div><span style="{_K}">{html.escape(k)}</span> '
            f'<span style="{_V}">{_leaf(v)}</span></div>'
            for k, v in mapping.items())

    stats = [
        ('format', f"v{ck['format_version']}"),
        ('model grid', grid(cfg['model_grid'])),
        ('data grid', grid(cfg['data_grid'])),
        ('vertical', f"{len(cfg['model_sigma_boundaries']) - 1} σ-levels "
                     f"→ {len(cfg['data_pressure_levels'])} p-levels"),
        ('timestep', f"{cfg['timestep_seconds']:.0f} s · dt={cfg['dt']:.4f}"),
        ('parameters', f"{_human(n_par)} · {len(ck['params'])} bundles"),
        ('gin bindings', f"{len(cfg['model'])} configurables"),
    ]
    chips = ''.join(
        f'<div style="padding:.25em .8em"><div style="font-size:75%;'
        f'color:#7a8aa0;text-transform:uppercase;letter-spacing:.04em">'
        f'{html.escape(k)}</div><div style="font-weight:600;color:#10212f">'
        f'{html.escape(str(v))}</div></div>' for k, v in stats)

    # 'model' gets its own flat section below; 'gin_config_str' is left out
    # of the rendered config tree.
    cfg_main = {k: v for k, v in cfg.items()
                if k not in ('gin_config_str', 'model')}
    model_rows = flat_rows(cfg['model'])
    model_block = (
        f'<details><summary><span style="{_K}">model</span> <span style="{_C}">'
        f'· {len(cfg["model"])} gin configurables (edit these)</span></summary>'
        f'<div style="{_PAD}">{model_rows}</div></details>')
    aux_rows = flat_rows(ck['aux_features'])
    root, counts = _params_tree(ck['params'])

    return HTML(
        '<div style="font-family:system-ui,sans-serif;max-width:760px;'
        'border:1px solid #dfe6ee;border-radius:12px;overflow:hidden">'
        '<div style="background:linear-gradient(95deg,#0f8f93,#19c08a);'
        'color:#fff;padding:.7em 1em;font-weight:700;font-size:108%">'
        f'NeuralGCM checkpoint · {html.escape(name)}</div>'
        '<div style="display:flex;flex-wrap:wrap;gap:.1em;background:#f6f9fb;'
        f'padding:.5em;border-bottom:1px solid #e6edf3">{chips}</div>'
        '<div style="padding:.6em 1em;font-size:92%">'
        '<details open><summary style="font-weight:600">config '
        '<span style="color:#999">(editable — plain data)</span></summary>'
        f'{_tree(cfg_main)}{model_block}</details>'
        '<details><summary style="font-weight:600">aux_features</summary>'
        f'<div style="{_PAD}">{aux_rows}</div></details>'
        '<details><summary style="font-weight:600">params '
        f'<span style="color:#999">· {_human(n_par)}</span></summary>'
        f'{_ptree(root, counts)}</details></div></div>')


# Left as the cell's final bare expression so the notebook displays the
# returned HTML card.
checkpoint_anatomy(base_checkpoint, model_name)
NeuralGCM checkpoint · deterministic_2_8_deg
format
v1
model grid
128×64 · real SH
data grid
128×64 · real SH
vertical
32 σ-levels → 37 p-levels
timestep
3600 s · dt=0.5250
parameters
14.52M · 121 bundles
gin bindings
133 configurables
config (editable — plain data)
model_grid ·8
longitude_wavenumbers 64
total_wavenumbers 65
longitude_nodes 128
latitude_nodes 64
latitude_spacing "gauss"
longitude_offset 0
radius 1
spherical_harmonics "real"
data_grid ·8
longitude_wavenumbers 64
total_wavenumbers 65
longitude_nodes 128
latitude_nodes 64
latitude_spacing "gauss"
longitude_offset 0
radius 1
spherical_harmonics "real"
model_sigma_boundaries list[33]
data_pressure_levels list[37]
dt 0.525
timestep_seconds 3600
reference_datetime "1979-01-01T00:00:00"
physics ·7
radius 1
angular_velocity 0.5
gravity_acceleration 72.36
ideal_gas_constant 0.0003323
water_vapor_gas_constant 0.000534
water_vapor_isobaric_heat_capacity 0.002153
kappa 0.2857
scale_si ·4
length_m 6.371e+06
time_s 6857
mass_kg 5.18e+18
temperature_K 1
input_variables list[7]
forcing_variables list[2]
tracer_variables list[3]
data ·2
orography_input_grid ·8
longitude_wavenumbers 64
total_wavenumbers 65
longitude_nodes 128
latitude_nodes 64
latitude_spacing "gauss"
longitude_offset 0
radius 1
spherical_harmonics "real"
covariate_units ·3
geopotential_at_surface "m**2 s**-2"
land_sea_mask "(0 - 1)"
orography "m**2 s**-2"
model · 133 gin configurables (edit these)
decode/ColumnTower dict·3
encode/ColumnTower dict·3
process/ColumnTower dict·3
surface_model_decode/ColumnTower dict·3
surface_model_encode/ColumnTower dict·3
surface_model_process/ColumnTower dict·3
advance/CombinedFeatures dict·5
decoder_model/CombinedFeatures dict·5
embedding_model/CombinedFeatures dict·5
encoder_data/CombinedFeatures dict·5
land_model/CombinedFeatures dict·5
sea_ice_model/CombinedFeatures dict·5
sea_model/CombinedFeatures dict·5
coordinate_system_from_dataset dict·2
CoordinateSystem dict·2
custom_corrds/CoordinateSystem dict·2
CustomCoordsCorrector dict·3
data_to_xarray_with_renaming dict·5
encoder/DataExponentialFilter dict·4
orography/DataExponentialFilter dict·4
DimensionalLearnedPrimitiveToWeatherbenchDecoder dict·9
DimensionalLearnedWeatherbenchToPrimitiveWithMemoryEncoder dict·10
DivCurlNeuralParameterization dict·6
DycoreWithPhysicsCorrector dict·6
DynamicDataForcing dict·6
advance/EmbeddingSurfaceFeatures dict·4
advance/EmbeddingVolumeFeatures dict·4
EncoderCombinedTransform dict·2
EncoderFilterTransform dict·2
EpdTower dict·9
surface_model/EpdTower dict·9
dycore/ExponentialFilter dict·4
stability/ExponentialFilter dict·4
FilteredCustomOrography dict·4
with_grads/FloatDataFeatures dict·5
without_grads/FloatDataFeatures dict·4
sea_ice_model/ForcingFeatures dict·2
sea_model/ForcingFeatures dict·2
gelu dict·1
get_model_specs dict·4
get_physics_specs dict·1
smaller/GridWithWavenumbers dict·5
advance/IdentityTransform dict·1
land_model/IdentityTransform dict·1
sea_ice_model/IdentityTransform dict·1
sea_model/IdentityTransform dict·1
InputClipTransform dict·2
advance/InverseLevelScale dict·3
decoder_model/InverseLevelScale dict·3
encoder_data/InverseLevelScale dict·3
decoder/InverseShiftAndNormalize dict·4
div_curl_tendency_outputs/InverseShiftAndNormalize dict·4
encoder/InverseShiftAndNormalize dict·4
advance/LatitudeFeatures dict·1
decoder_model/LatitudeFeatures dict·1
encoder_data/LatitudeFeatures dict·1
LearnedOrography dict·3
advance/LearnedPositionalFeatures dict·3
decoder_model/LearnedPositionalFeatures dict·3
encoder_data/LearnedPositionalFeatures dict·3
decoder/LevelScale dict·3
div_curl_tendency_outputs/LevelScale dict·3
encode/LevelScale dict·3
advance/MemoryVelocityAndValues dict·2
decode/MlpUniform dict·10
encode/MlpUniform dict·10
process/MlpUniform dict·10
surface_model_decode/MlpUniform dict·10
surface_model_encode/MlpUniform dict·10
surface_model_process/MlpUniform dict·10
advance/ModalToNodalEmbedding dict·4
land_model/ModalToNodalEmbedding dict·4
sea_ice_model/ModalToNodalEmbedding dict·4
sea_model/ModalToNodalEmbedding dict·4
MoistPrimitiveEquationsWithCloudMoisture dict·3
advance/NodalLandSeaIceEmbedding dict·5
NodalMapping dict·2
land_model/NodalMapping dict·2
sea_ice_model/NodalMapping dict·2
sea_model/NodalMapping dict·2
NodalVolumeMapping dict·2
NullFeatures dict·1
advance/PressureFeatures dict·1
embedding_model/PressureFeatures dict·1
primitive_eq_specs_constructor dict·1
advance/RadiationFeatures dict·1
decoder_model/RadiationFeatures dict·1
encoder_data/RadiationFeatures dict·1
dycore/SequentialStepFilter dict·2
ml/SequentialStepFilter dict·2
advance/SequentialTransform dict·2
decoder/SequentialTransform dict·2
decoder_model/SequentialTransform dict·2
div_curl_tendency_outputs/SequentialTransform dict·2
encode/SequentialTransform dict·2
encoder_data/SequentialTransform dict·2
advance/ShiftAndNormalize dict·5
decoder_model/ShiftAndNormalize dict·5
embedding_model/ShiftAndNormalize dict·5
encoder_data/ShiftAndNormalize dict·5
land_model/ShiftAndNormalize dict·5
sea_ice_model/ShiftAndNormalize dict·5
sea_model/ShiftAndNormalize dict·5
SigmaCoordinatesEquidistant dict·1
custom_corrds/SigmaCoordinatesEquidistant dict·1
advance/SoftClip dict·3
StochasticModularStepModel dict·5
StochasticPhysicsParameterizationStep dict·6
land_model/TakeSurfaceAdjacentSigmaLevel dict·1
sea_ice_model/TakeSurfaceAdjacentSigmaLevel dict·1
advance/ToModalDiffOperators dict·1
decoder_model/ToModalDiffOperators dict·1
encoder_data/ToModalDiffOperators dict·1
with_grads/ToModalDiffOperators dict·1
trajectory_from_step dict·3
advance/TruncateSigmaLevels dict·2
decoder_model/TruncateSigmaLevels dict·2
advance/VelocityAndPrognostics dict·3
decoder_model/VelocityAndPrognostics dict·3
embedding_model/VelocityAndPrognostics dict·2
encoder_data/VelocityAndPrognostics dict·3
land_model/VelocityAndPrognostics dict·3
sea_ice_model/VelocityAndPrognostics dict·3
VerticalConvTower dict·7
WhirlModel dict·3
xarray_to_data_with_renaming dict·2
xarray_to_dynamic_covariate_data dict·1
xarray_to_state_and_dynamic_covariate_data dict·3
xarray_to_weatherbench_data dict·2
ZerosRandomField dict·1
GridTL63 dict·1
GridTL127 dict·1
GridTL255 dict·1
aux_features
ref_temperatures array(32,)·float64
nodal_orography_m array(128, 64)·float64
land_sea_mask array(128, 64)·float64
covariate_geopotential_at_surface array(128, 64)·float64
covariate_land_sea_mask array(128, 64)·float64
covariate_orography array(128, 64)·float64
params · 14.52M
stochastic_modular_step_model · 14.52M
dimensional_learned_primitive_to_weatherbench_decoder · 3.42M
nodal_mapping · 3.36M
epd_tower · 3.36M
decode_tower · 99.46K
mlp_uniform · 99.46K
linear_0 w(384, 259)
encode_tower · 301.44K
mlp_uniform · 301.44K
linear_0 b(384,), w(784, 384)
process_tower · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_1 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_2 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_3 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_4 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
combined_features · 65.54K
learned_positional_features learned_positional_features(8, 128, 64)
dimensional_learned_weatherbench_to_primitive_with_memory_encoder · 7.01M
learned_weatherbench_to_primitive_encoder · 3.51M
nodal_mapping · 3.44M
epd_tower · 3.44M
decode_tower · 74.11K
mlp_uniform · 74.11K
linear_0 w(384, 193)
encode_tower · 404.35K
mlp_uniform · 404.35K
linear_0 b(384,), w(1052, 384)
process_tower · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_1 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_2 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_3 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_4 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
combined_features · 65.54K
learned_positional_features learned_positional_features(8, 128, 64)
learned_orography orography(4223,)
learned_weatherbench_to_primitive_encoder_1 · 3.51M
nodal_mapping · 3.44M
epd_tower · 3.44M
decode_tower · 74.11K
mlp_uniform · 74.11K
linear_0 w(384, 193)
encode_tower · 404.35K
mlp_uniform · 404.35K
linear_0 b(384,), w(1052, 384)
process_tower · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_1 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_2 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_3 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_4 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
combined_features · 65.54K
learned_positional_features learned_positional_features(8, 128, 64)
learned_orography orography(4223,)
stochastic_physics_parameterization_step · 4.08M
custom_coords_corrector · 4.09K
dycore_with_physics_corrector · 4.09K
learned_orography orography(4094,)
div_curl_neural_parameterization · 4.08M
nodal_mapping · 3.94M
epd_tower · 3.94M
decode_tower · 73.73K
mlp_uniform · 73.73K
linear_0 w(384, 192)
encode_tower · 908.54K
mlp_uniform · 908.54K
linear_0 b(384,), w(2365, 384)
process_tower · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_1 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_2 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_3 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
process_tower_4 · 591.36K
mlp_uniform · 591.36K
linear_0 b(384,), w(384, 384)
linear_1 b(384,), w(384, 384)
linear_2 b(384,), w(384, 384)
linear_3 b(384,), w(384, 384)
combined_features · 141.74K
embedding_surface_features · 1.35K
nodal_land_sea_ice_embedding · 1.35K
modal_to_nodal_embedding · 456
nodal_mapping · 456
epd_tower · 456
surface_model_decode_tower · 128
mlp_uniform · 128
linear_0 w(8, 8)
linear_1 w(8, 8)
surface_model_encode_tower · 40
mlp_uniform · 40
linear_0 b(8,), w(4, 8)
surface_model_process_tower · 288
mlp_uniform · 288
linear_0 b(8,), w(8, 8)
linear_1 b(8,), w(8, 8)
linear_2 b(8,), w(8, 8)
linear_3 b(8,), w(8, 8)
modal_to_nodal_embedding_1 · 432
nodal_mapping · 432
epd_tower · 432
surface_model_decode_tower · 128
mlp_uniform · 128
linear_0 w(8, 8)
linear_1 w(8, 8)
surface_model_encode_tower · 16
mlp_uniform · 16
linear_0 b(8,), w(1, 8)
surface_model_process_tower · 288
mlp_uniform · 288
linear_0 b(8,), w(8, 8)
linear_1 b(8,), w(8, 8)
linear_2 b(8,), w(8, 8)
linear_3 b(8,), w(8, 8)
modal_to_nodal_embedding_2 · 464
nodal_mapping · 464
epd_tower · 464
surface_model_decode_tower · 128
mlp_uniform · 128
linear_0 w(8, 8)
linear_1 w(8, 8)
surface_model_encode_tower · 48
mlp_uniform · 48
linear_0 b(8,), w(5, 8)
surface_model_process_tower · 288
mlp_uniform · 288
linear_0 b(8,), w(8, 8)
linear_1 b(8,), w(8, 8)
linear_2 b(8,), w(8, 8)
linear_3 b(8,), w(8, 8)
embedding_volume_features · 74.85K
modal_to_nodal_embedding · 74.85K
nodal_volume_mapping · 74.85K
vertical_conv_tower · 74.85K
conv_level b(64, 1), w(5, 9, 64)
conv_level_1 b(64, 1), w(5, 64, 64)
conv_level_2 b(64, 1), w(5, 64, 64)
conv_level_3 b(64, 1), w(5, 64, 64)
conv_level_4 b(32, 1), w(5, 64, 32)
learned_positional_features learned_positional_features(8, 128, 64)

The converted checkpoint’s config['model'] holds the parsed bindings as plain data: a dict of 'scope/ClassName' -> {parameter: value} with references encoded as {'__ref__': name}. The pieces we are going to modify:

# `bindings` refers to the checkpoint's own config dict (no copy is made here).
bindings = base_checkpoint['config']['model']
# The dycore step-filter chain that the optional FixGlobalMeanFilter edit
# appends to.
print(bindings['dycore/SequentialStepFilter'])
# The decoder's variable -> units mapping that gains a 'diagnostics' entry.
print(bindings['DimensionalLearnedPrimitiveToWeatherbenchDecoder']
      ['inputs_to_units_mapping'])
{'filter_modules': [{'__ref__': 'dycore/ExponentialFilter', '__call__': False}, {'__ref__': 'stability/ExponentialFilter', '__call__': False}], 'name': None}
{'sim_time': 'dimensionless', 't': 'kelvin', 'tracers': {'specific_cloud_ice_water_content': 'dimensionless', 'specific_cloud_liquid_water_content': 'dimensionless', 'specific_humidity': 'dimensionless'}, 'u': 'meter / second', 'v': 'meter / second', 'z': 'm**2 s**-2'}

Modify the checkpoint#

Three edits add a surface pressure output:

  1. a step diagnostics module (SurfacePressureDiagnostics) that computes surface pressure from the model state at every step,

  2. a decoder diagnostics module (NodalModelDiagnosticsDecoder) that passes it into the decoded outputs,

  3. an entry in the decoder’s units mapping so the output is converted back to SI units (Pa).

A fourth, optional edit appends FixGlobalMeanFilter to the dycore step filters, which holds the global mean log surface pressure exactly constant:

def with_surface_pressure_output(checkpoint, fix_global_mean):
  """Return a modified deep copy of `checkpoint` that outputs surface pressure.

  The input checkpoint is left untouched. In the copy's
  `config['model']` bindings:

  * the decoder's `inputs_to_units_mapping` gains a `diagnostics` entry for
    `surface_pressure`,
  * the decoder's `diagnostics_module` points at
    `NodalModelDiagnosticsDecoder`,
  * the physics step's `diagnostics_module` points at
    `SurfacePressureDiagnostics`,
  * and, when `fix_global_mean` is truthy, a reference to
    `surface_pressure/FixGlobalMeanFilter` is appended to the dycore
    step filters.
  """
  def make_ref(target):
    # References are encoded as plain dicts in the converted config.
    return {'__ref__': target, '__call__': False}

  modified = copy.deepcopy(checkpoint)
  model_cfg = modified['config']['model']

  decoder_cfg = model_cfg['DimensionalLearnedPrimitiveToWeatherbenchDecoder']
  decoder_cfg['inputs_to_units_mapping'] = dict(
      decoder_cfg['inputs_to_units_mapping'],
      diagnostics={'surface_pressure': 'kg / (meter s**2)'},
  )
  decoder_cfg['diagnostics_module'] = make_ref('NodalModelDiagnosticsDecoder')

  step_cfg = model_cfg['StochasticPhysicsParameterizationStep']
  step_cfg['diagnostics_module'] = make_ref('SurfacePressureDiagnostics')

  if fix_global_mean:
    filter_cfg = model_cfg['dycore/SequentialStepFilter']
    filter_cfg['filter_modules'] = [
        *filter_cfg['filter_modules'],
        make_ref('surface_pressure/FixGlobalMeanFilter'),
    ]
  return modified


# Build two models from independently modified copies of the same base
# checkpoint. Both gain the surface-pressure output; only `model_fix` also
# gets FixGlobalMeanFilter appended to the dycore step filters.
model_fix = neuralgcm.PressureLevelModel.from_checkpoint(
    with_surface_pressure_output(base_checkpoint, fix_global_mean=True),
    device=device,
)
model_no_fix = neuralgcm.PressureLevelModel.from_checkpoint(
    with_surface_pressure_output(base_checkpoint, fix_global_mean=False),
    device=device,
)

Initial condition from ERA5#

# ERA5 Zarr store in a public Google Cloud Storage bucket.
era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
# token='anon' requests anonymous (unauthenticated) access.
# NOTE(review): chunks=None is expected to open the store without dask
# chunking — confirm against the xarray docs.
full_era5 = xarray.open_zarr(
    era5_path, chunks=None, storage_options=dict(token='anon')
)
# Keep only the variables the model consumes, at a single initial time, and
# load them into memory.
variables = model_fix.input_variables + model_fix.forcing_variables
sliced_era5 = full_era5[variables].sel(time='2020-02-14T00').compute()

# Regrid from the ERA5 grid onto the model's data grid, then fill remaining
# NaNs (presumably from the nearest valid value, per the helper's name).
era5_grid = xarray_utils.grid_spec_from_dataset(full_era5)
regridder = horizontal_interpolation.ConservativeRegridder(
    era5_grid, model_fix.data_grid, skipna=True, device=device
)
eval_era5 = xarray_utils.regrid_horizontal(sliced_era5, regridder)
eval_era5 = xarray_utils.fill_nan_with_nearest(eval_era5)

Run both models for a month#

A 30-day integration with daily outputs, using persistent forcings. surface_pressure now appears among the outputs:

# Rollout settings shared by both runs (read as globals by month_long_run).
outer_steps = 30  # number of saved outputs, one per `timedelta`
timedelta = np.timedelta64(24, 'h')  # interval between saved outputs
times = np.arange(outer_steps)  # time axis in days


def month_long_run(model):
  """Roll `model` forward from the ERA5 snapshot and return the outputs.

  Reads the module-level `eval_era5`, `outer_steps`, `timedelta` and `times`.
  The forcings taken from the single ERA5 snapshot are passed unchanged to
  the whole rollout, and the encoder is seeded with a fixed `rng=42`.

  Returns:
    Whatever `model.data_to_xarray` produces for the unrolled predictions,
    labelled with `times`.
  """
  initial_inputs = model.inputs_from_xarray(eval_era5)
  fixed_forcings = model.forcings_from_xarray(eval_era5)
  initial_state = model.encode(initial_inputs, fixed_forcings, rng=42)
  # Only the trajectory is needed; the final model state is discarded.
  _final_state, trajectory = model.unroll(
      initial_state,
      fixed_forcings,
      steps=outer_steps,
      timedelta=timedelta,
      start_with_input=True,
  )
  return model.data_to_xarray(trajectory, times=times)


# Same initial condition and settings for both models; they differ only in
# whether FixGlobalMeanFilter was appended.
predictions_fix = month_long_run(model_fix)
predictions_no_fix = month_long_run(model_no_fix)
# Show the output variable names (the cell's displayed value).
list(predictions_fix.data_vars)
['u_component_of_wind',
 'v_component_of_wind',
 'temperature',
 'geopotential',
 'specific_humidity',
 'specific_cloud_ice_water_content',
 'specific_cloud_liquid_water_content',
 'surface_pressure']

Compare global mean log surface pressure#

The area-weighted global mean of log surface pressure (proportional to the (0, 0) spherical-harmonic coefficient that FixGlobalMeanFilter pins) stays constant in the modified model. Without the filter it is not held fixed and drifts slightly over the month (here by about 0.05%) — a small numerical effect that the filter removes:

# Quadrature weights of the model's data grid, wrapped as a DataArray along
# `latitude` so xarray can apply them in the weighted mean.
weights = xarray.DataArray(
    model_fix.data_grid.quadrature_weights, dims='latitude'
)


def global_mean_log_sp(predictions):
  """Weighted horizontal mean of log surface pressure.

  Takes the natural log of `predictions.surface_pressure`, weights it with
  the module-level `weights`, and averages over `longitude` and `latitude`.
  """
  log_sp = np.log(predictions.surface_pressure)
  weighted_log_sp = log_sp.weighted(weights)
  return weighted_log_sp.mean(['longitude', 'latitude'])


# One curve per model, plotted against the forecast-day axis.
plt.plot(times, global_mean_log_sp(predictions_fix), label='Fixed global mean')
plt.plot(times, global_mean_log_sp(predictions_no_fix), label='No fix')
plt.xlabel('forecast day')
plt.ylabel('global mean log surface pressure')
# Trailing semicolon keeps the notebook from echoing the Legend object.
plt.legend();
_images/603aaa5c14164b3e4762b833a45f1786c4896edf927d1009eeda7253b8e299a8.png