Calibrating a biogeochemical model with EnsembleKalmanProcesses

In this example we calibrate some of the parameters of the NPZD model in a simple box-model setup using the data assimilation package EnsembleKalmanProcesses. First we set up the model and generate synthetic data with "true" parameters. We then define priors and set up an ensemble Kalman process (EKP) to solve the inverse problem.

While this is a very simple situation, it illustrates how easily OceanBioME integrates with data assimilation tools. The examples in the EnsembleKalmanProcesses docs show how the package can be used with more complex forward models.

Install dependencies

First we ensure that the required dependencies are installed.

# Install the packages this example needs. `pkg"..."` is Pkg's REPL-mode string
# macro (there must be no space before the string). JLD2 is included because it
# is loaded explicitly with `using JLD2` below.
using Pkg
pkg"add OceanBioME, Oceananigans, CairoMakie, EnsembleKalmanProcesses, Distributions, JLD2"
using OceanBioME, EnsembleKalmanProcesses, JLD2, CairoMakie, Oceananigans.Units, Oceananigans
using LinearAlgebra, Random

using Distributions

using EnsembleKalmanProcesses
using EnsembleKalmanProcesses.ParameterDistributions

# One model year (365 days); used by the light forcing and the simulation stop time.
const year = years = 365day

# Seeded random number generator passed to the EKP routines below so that the
# calibration is reproducible.
rng_seed = 41
rng = Random.MersenneTwister(rng_seed)
Random.MersenneTwister(41)

Set up the forward model

# Surface PAR as a function of time `t`. It combines a seasonal (1 - cos) cycle
# over one `year` shifted by 15 days, a reduction by up to a factor 1/1.2 in a
# Gaussian window (width 50 days) centred on day 200 of each year, and a
# constant floor of 2.
@inline PAR⁰(t) = 60 * (1 - cos((t + 15days) * 2π / year)) * (1 / (1 + 0.2 * exp(-((mod(t, year) - 200days) / 50days)^2))) + 2

# Nominal depth of the box for the PAR profile. It is `const` so that `PAR`,
# which is evaluated repeatedly as a model forcing, reads a global with a
# fixed, inferrable type.
const z = -10

# Attenuate the surface PAR exponentially with depth (e-folding depth 1/0.2 = 5,
# in the units of `z`).
@inline PAR(t) = PAR⁰(t) * exp(0.2z)

# Build an NPZD box model with the four calibrated parameters, run it for three
# years while saving all fields every 8 hours, and return the phytoplankton
# concentration together with its output times for the last 1093 saved
# snapshots. `j` tags the output file, so runs with different `j` write to
# separate files.
function run_box_simulation(initial_photosynthetic_slope,
                            base_maximum_growth,
                            nutrient_half_saturation,
                            phyto_base_mortality_rate,
                            j)
    filename = "box_calibration_$j.jld2"

    biogeochemistry = NutrientPhytoplanktonZooplanktonDetritus(;
        grid = BoxModelGrid(),
        initial_photosynthetic_slope,
        base_maximum_growth,
        nutrient_half_saturation,
        phyto_base_mortality_rate,
        light_attenuation_model = nothing)

    model = BoxModel(; biogeochemistry, forcing = (; PAR))
    set!(model, N = 10.0, P = 0.1, Z = 0.01)

    simulation = Simulation(model; Δt = 20minutes, stop_time = 3years, verbose = false)

    writer = JLD2OutputWriter(model, model.fields;
                              filename,
                              schedule = TimeInterval(8hours),
                              overwrite_existing = true)
    simulation.output_writers[:fields] = writer

    @info "Running the model..."
    run!(simulation)

    # Reload the saved phytoplankton series and keep only the final window.
    phytoplankton = FieldTimeSeries(filename, "P")
    all_times = phytoplankton.times
    first_kept = length(all_times) - 1092

    return phytoplankton[1, 1, 1, first_kept:end], all_times[first_kept:end]
end
run_box_simulation (generic function with 1 method)

Define the forward map

# Forward map for the EKP. Unpack the parameter vector `u` (initial
# photosynthetic slope, base maximum growth, nutrient half-saturation,
# phytoplankton base mortality rate), run the box model tagged with `j`, and
# return the vector of five observables together with the raw phytoplankton
# trajectory.
function G(u, j)
    slope, max_growth, half_saturation, mortality = u

    trajectory, sample_times = run_box_simulation(slope, max_growth, half_saturation,
                                                  mortality, j)

    observables = extract_observables(trajectory, sample_times)

    return [observables...], trajectory
end

# Reduce a phytoplankton time series `P`, sampled at `times`, to the five scalar
# observables used for calibration:
#   peak         - maximum of `P`
#   winter       - minimum of `P`
#   average      - mean of `P`
#   peak_timing  - time of the (first) maximum, divided by `time_unit`
#   die_off_time - time at the start of the steepest decrease of `P`, searched only
#                  over forward differences from index `die_off_search_start`
#                  onwards, divided by `time_unit`
# A series that is empty or contains any non-positive or NaN value is treated as
# a failed model run, and all five observables are returned as NaN.
#
# Keywords:
#   time_unit            - divisor applied to the two timing observables (default `day`)
#   die_off_search_start - first forward-difference index searched for the die-off
#                          (default 546, about the middle of the 1093-sample
#                          window returned by `run_box_simulation`)
function extract_observables(P, times; time_unit = day, die_off_search_start = 546)
    # `NaN > 0` is false, so NaN-contaminated runs also take this branch.
    if isempty(P) || !all(>(0), P)
        # Model failure: NaN observables mark this ensemble member as failed.
        return NaN, NaN, NaN, NaN, NaN
    end

    peak = maximum(P)
    winter = minimum(P)
    average = mean(P)

    peak_timing = times[argmax(P)]

    # Forward differences: entry k of diff(P) is P[k + 1] - P[k].
    growth_rate = diff(P)[die_off_search_start:end]

    # argmin is relative to the search window; shift it back to an index into `times`.
    die_off_time = times[die_off_search_start - 1 + argmin(growth_rate)]

    return peak, winter, average, peak_timing / time_unit, die_off_time / time_unit
end
extract_observables (generic function with 1 method)

Generate the "truth" data (normally you would load real observations here)

# Observation-error covariance: diagonal, with one variance per observable
# (peak, winter minimum, mean, peak timing, die-off timing).
Γ = Diagonal([0.001, 0.0001, 0.002, 5., 5.])

noise_dist = MvNormal(zeros(5), Γ)

# "True" parameters used to generate the synthetic observations, in the order in
# which `G` unpacks them.
truth = (0.15/day, 0.7/day, 2.4, 0.01/day)
obs, P₀ = G(truth, 1)

# Draw the observational noise from the seeded `rng` rather than the global RNG,
# so the synthetic data, and therefore the whole calibration, is reproducible.
y = obs .+ rand(rng, noise_dist)
5-element Vector{Float64}:
  -0.052016148124423914
   0.013199965829515032
  -0.033848222282833657
 735.0285290912893
 914.3011062563494

Solve the inverse problem and record all of the results for plotting purposes

# Priors for the four calibrated parameters, in the order in which `G` unpacks
# them. Each is `constrained_gaussian(name, mean, std, lower, upper)` with bounds
# [0, ∞).
prior_u1 = constrained_gaussian("initial_photosynthetic_slope", 0.1953 / day, 0.05 / day, 0, Inf)
prior_u2 = constrained_gaussian("base_maximum_growth", 0.6989 / day, 0.1 / day, 0, Inf)
prior_u3 = constrained_gaussian("nutrient_half_saturation", 2.3868, 0.5, 0, Inf)
prior_u4 = constrained_gaussian("phyto_base_mortality_rate", 0.0101 / day, 0.01 / day, 0, Inf)

prior = combine_distributions([prior_u1, prior_u2, prior_u3, prior_u4])

N_ensemble = 8
N_iterations = 5

initial_ensemble = construct_initial_ensemble(rng, prior, N_ensemble)

# Ensemble Kalman inversion. `SampleSuccGauss` is the EKP failure handler applied
# to members whose forward run produced NaN observables.
ensemble_kalman_process = EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion();
                                                rng,
                                                failure_handler_method = SampleSuccGauss())

# Phytoplankton trajectory of every member at every iteration. This is kept only
# for the animation below; 1093 is the number of snapshots `run_box_simulation`
# returns.
P = zeros(1093, N_ensemble, N_iterations)

for i in 1:N_iterations
    @info "Iteration: $i"

    # Current parameters in constrained (physical) space: one column per member.
    ensemble_params = get_ϕ_final(prior, ensemble_kalman_process)

    # Forward-model output: one column of 5 observables per member.
    G_ens = zeros(5, N_ensemble)

    # Each member writes only its own column of `G_ens`, its own slice of `P`,
    # and its own output file (tagged by `member`), so members can run on
    # separate threads.
    Threads.@threads for member in 1:N_ensemble
        observables, trajectory = G(ensemble_params[:, member], member)
        G_ens[:, member] = observables
        P[:, member, i] = trajectory
    end

    update_ensemble!(ensemble_kalman_process, G_ens)
end

final_ensemble = get_ϕ_final(prior, ensemble_kalman_process)
4×8 Matrix{Float64}:
 2.97195e-6  2.3892e-6   1.58669e-6  2.3208e-6   1.63028e-6  3.33651e-6  1.87558e-6  1.78529e-6
 6.83832e-6  7.76094e-6  8.11689e-6  9.79117e-6  7.54454e-6  9.62738e-6  8.14584e-6  8.70185e-6
 2.9536      1.8138      2.78527     2.50635     2.1405      3.17535     2.74784     3.4374
 8.96118e-8  3.42416e-8  3.11699e-8  2.8425e-8   1.7344e-7   6.26795e-8  6.59976e-8  1.03756e-7

Plot the results

# Animate the ensemble across EKP generations. Each frame shows every member's
# phytoplankton trajectory for one generation (faint lines), overlaid on the
# trajectory from the "true" parameters (solid line).
n = Observable(1)

title = @lift string("Generation: ", $n)

P_plts = [@lift P[:, j, $n] for j in 1:N_ensemble]

# Elapsed time in days since the first retained snapshot. Output was saved every
# 8 hours, so sample k (1-based) sits at (k - 1) * 8 hours.
plot_days = (0:size(P, 1) - 1) .* (8hours / day)

fig = Figure(size = (1200, 800));

ax = Axis(fig[1, 1], xlabel = "Day of year", ylabel = "Phytoplankton concentration (mmol/m³)"; title)

for j in 1:N_ensemble
    lines!(ax, plot_days, P_plts[j], color = :black, alpha = 0.2)
end

lines!(ax, plot_days, P₀, color = :black)

record(fig, "data_assimilation.mp4", 1:size(P, 3); framerate = 2) do i
    n[] = i
end
┌ Warning: No strict ticks found
└ @ PlotUtils ~/.julia/packages/PlotUtils/8mrSm/src/ticks.jl:194
┌ Warning: No strict ticks found
└ @ PlotUtils ~/.julia/packages/PlotUtils/8mrSm/src/ticks.jl:194


This page was generated using Literate.jl.