1D GMM example

Optimize bin boundaries for a single discriminant using a Gaussian mixture model. The script builds the toy dataset, trains one model for each requested category count, and compares the GATO-derived significances to equidistant-binning baselines.

Run:

python examples/1D_example/run_gmm_example.py --gato-bins 5 10 20 --epochs 300

Key outputs

  • Stacked histograms for both equidistant and optimized binning schemes.

  • Loss, penalty and bias histories saved under examples/1D_example/<out>/, where <out> is the value of --out ("Plots" by default).

  • checkpoints/<N>_bins directories (inside the output directory) holding the saved model for each bin count, for later inspection.

Source code

Source of the 1D GMM toy example
  1import argparse
  2from gatohep.data_generation import generate_toy_data_1D
  3from gatohep.losses import (
  4    high_bkg_uncertainty_penalty,
  5    low_bkg_penalty,
  6)
  7from gatohep.models import (
  8    gato_gmm_model,
  9)
 10from gatohep.plotting_utils import (
 11    plot_bias_history,
 12    plot_gmm_1d,
 13    plot_history,
 14    plot_significance_comparison,
 15    plot_stacked_histograms,
 16    plot_yield_vs_uncertainty,
 17)
 18from gatohep.utils import (
 19    LearningRateScheduler,
 20    TemperatureScheduler,
 21    asymptotic_significance,
 22    compute_significance_from_hists,
 23    create_hist,
 24    df_dict_to_tensors,
 25)
 26import hist
 27import numpy as np
 28import tensorflow as tf
 29import tensorflow_probability as tfp
 30import os
 31
# Shorthand alias for the TensorFlow Probability distributions module.
# NOTE(review): `tfd` is not referenced anywhere else in this listing.
tfd = tfp.distributions
 33
 34
 35# Define a 1D GATO model for the toy example inheriting from gato_gmm_model
 36class gato_1D(gato_gmm_model):
 37    def __init__(self, n_cats, temperature=1.0, mean_norm="sigmoid"):
 38        super().__init__(
 39            n_cats=n_cats,
 40            dim=1,
 41            temperature=temperature,
 42            mean_norm=mean_norm,
 43            mean_range=(0.0, 1.0),
 44            cov_offdiag_damping=0.1,
 45        )  # dummy NN output is already in (0,1)
 46
 47    def call(self, data_dict):
 48        """
 49        Compute the loss for two signals vs. backgrounds
 50        using the generic helpers from the base class.
 51        """
 52        significances, bkg_yield, bkg_sum_w2 = self.get_differentiable_significance(
 53            data_dict,
 54            signal_labels=["signal"],
 55            return_details=True,
 56        )
 57        loss = -significances["signal"]
 58        return loss, bkg_yield, bkg_sum_w2
 59
 60
 61# main: Generate data, run fixed binning and optimization, then compare.
 62def main():
 63    parser = argparse.ArgumentParser(description="1-D GATO optimisation on toy data")
 64    parser.add_argument(
 65        "--epochs",
 66        type=int,
 67        default=250,
 68        help="number of training epochs (default: 250)",
 69    )
 70
 71    parser.add_argument(
 72        "--gato-bins",
 73        type=int,
 74        nargs="+",
 75        default=[3, 5, 10],
 76        metavar="N",
 77        help="List of target bin counts for the GATO run (default: 3,5,10)",
 78    )
 79
 80    parser.add_argument(
 81        "--lam-yield",
 82        type=float,
 83        default=0.0,
 84        help=r"lambda for the low-background-yield penalty (default: 0)",
 85    )
 86
 87    parser.add_argument(
 88        "--lam-unc",
 89        type=float,
 90        default=0.0,
 91        help=r"lambda for the high-uncertainty penalty (default: 0)",
 92    )
 93
 94    parser.add_argument(
 95        "--thr-yield",
 96        type=float,
 97        default=5.0,
 98        help="Threshold (events) below which the low-yield "
 99        "penalty turns on (default: 10)",
100    )
101
102    parser.add_argument(
103        "--thr-unc",
104        type=float,
105        default=0.20,
106        help="Relative uncertainty above which the uncertainty "
107        "penalty turns on (default: 0.20)",
108    )
109
110    parser.add_argument(
111        "--n-bkg",
112        type=int,
113        default=300000,
114        help="Total number of background events to generate.",
115    )
116
117    parser.add_argument(
118        "--out",
119        type=str,
120        default="Plots",
121        help='Suffix for the output directory. Default: "Plots"',
122    )
123
124    args = parser.parse_args()
125
126    gato_binning_options = args.gato_bins
127    epochs = args.epochs
128    lam_yield = args.lam_yield
129    lam_unc = args.lam_unc
130    yield_thr = args.thr_yield
131    unc_thr = args.thr_unc
132    n_bkg = args.n_bkg
133
134    # 1. Generate toy data
135    data = generate_toy_data_1D(
136        n_signal=100000,
137        n_bkg=n_bkg,
138        seed=42,
139    )
140
141    # Create fixed histograms, will be rebinned afterwards
142    n_bins = 1500
143    low = 0.0
144    high = 1.0
145    hist_signal = create_hist(
146        data["signal"]["NN_output"],
147        weights=data["signal"]["weight"],
148        bins=n_bins,
149        low=low,
150        high=high,
151        name="Signal",
152    )
153    bkg_processes = [f"bkg{i}" for i in range(1, 6)]
154    bkg_hists = [
155        create_hist(
156            data[proc]["NN_output"],
157            weights=data[proc]["weight"],
158            bins=n_bins,
159            low=low,
160            high=high,
161            name=f"{proc.capitalize()}",
162        )
163        for proc in bkg_processes
164    ]
165
166    # plot the backgrounds:
167    process_labels = [f"Background {i}" for i in range(1, len(bkg_processes) + 1)]
168    signal_labels = ["Signal x 100"]
169
170    # For demonstration, we compare multiple binning schemes.
171    equidistant_binning_options = [2, 5, 10, 20]
172    equidistant_significances = {}
173    optimized_significances = {}
174
175    path_plots = f"examples/1D_example/{args.out}/"
176    os.makedirs(path_plots, exist_ok=True)
177    fixed_plot_filename = path_plots + "toy_data.pdf"
178    plot_stacked_histograms(
179        stacked_hists=[bkg_hist[:: hist.rebin(30)] for bkg_hist in bkg_hists],
180        process_labels=process_labels,
181        signal_hists=[hist_signal[:: hist.rebin(30)] * 100],
182        signal_labels=signal_labels,
183        output_filename=fixed_plot_filename,
184        axis_labels=("Toy discriminant", "Events"),
185        normalize=False,
186        log=False,
187    )
188    plot_stacked_histograms(
189        stacked_hists=[bkg_hist[:: hist.rebin(30)] for bkg_hist in bkg_hists],
190        process_labels=process_labels,
191        signal_hists=[hist_signal[:: hist.rebin(30)] * 100],
192        signal_labels=signal_labels,
193        output_filename=fixed_plot_filename.replace(".pdf", "_log.pdf"),
194        axis_labels=("Toy discriminant", "Events"),
195        normalize=False,
196        log=True,
197    )
198
199    # Fixed binning significance
200    for nbins in equidistant_binning_options:
201        nbins_hist = hist_signal.axes[0].size
202        factor = int(nbins_hist / nbins)
203        hist_signal_rb = hist_signal[:: hist.rebin(factor)]
204        bkg_hists_rb = [h[:: hist.rebin(factor)] for h in bkg_hists]
205
206        Z_equidistant = compute_significance_from_hists(hist_signal_rb, bkg_hists_rb)
207        equidistant_significances[nbins] = Z_equidistant
208        print(
209            f"Fixed binning ({nbins} bins): Overall significance = {Z_equidistant:.3f}"
210        )
211
212        fixed_plot_filename = (
213            path_plots + f"NN_output_distribution_fixed_{nbins}bins.pdf"
214        )
215        plot_stacked_histograms(
216            stacked_hists=bkg_hists_rb,
217            process_labels=process_labels,
218            signal_hists=[hist_signal_rb * 100],
219            signal_labels=signal_labels,
220            output_filename=fixed_plot_filename,
221            axis_labels=("Toy NN output", "Toy events"),
222            normalize=False,
223            log=False,
224        )
225        plot_stacked_histograms(
226            stacked_hists=bkg_hists_rb,
227            process_labels=process_labels,
228            signal_hists=[hist_signal_rb * 100],
229            signal_labels=signal_labels,
230            output_filename=fixed_plot_filename.replace(".pdf", "_log.pdf"),
231            axis_labels=("Toy NN output", "Toy events"),
232            normalize=False,
233            log=True,
234        )
235        print(f"Fixed binning ({nbins} bins) plot saved as {fixed_plot_filename}")
236
237    # GATO part
238    tensor_data = df_dict_to_tensors(data)
239    for nbins in gato_binning_options:
240
241        @tf.function
242        def train_step(
243            model,
244            tensor_data,
245            optimizer,
246            lam_yield=0.0,
247            lam_unc=0.0,
248            threshold_yield=5,
249            rel_threshold_unc=0.2,
250        ):
251            with tf.GradientTape() as tape:
252                # assume your call() now returns (loss, B, B_sumsq)
253                loss, B, B_sumw2 = model.call(tensor_data)
254
255                penalty_yield = low_bkg_penalty(B, threshold=threshold_yield)
256                penalty_unc = high_bkg_uncertainty_penalty(
257                    B_sumw2, B, rel_threshold=rel_threshold_unc
258                )
259                total_loss = loss + lam_yield * penalty_yield + lam_unc * penalty_unc
260            grads = tape.gradient(total_loss, model.trainable_variables)
261            optimizer.apply_gradients(zip(grads, model.trainable_variables))
262            return total_loss, loss, penalty_yield, penalty_unc
263
264        # --- Optimization: create a model instance with n_cats = nbins ---
265        model = gato_1D(n_cats=nbins, temperature=1.0)
266
267        optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
268        lr_scheduler = LearningRateScheduler(
269            optimizer,
270            lr_initial=0.1,
271            lr_final=0.001,
272            total_epochs=epochs,
273            mode="cosine",
274        )
275
276        # temperature scheduler
277        temperature_scheduler = TemperatureScheduler(
278            model,
279            t_initial=1.0,
280            t_final=0.05,
281            total_epochs=args.epochs,
282            mode="cosine",
283        )
284
285        loss_history = []
286        penalty_yield_history = []
287        penalty_unc_history = []
288        mean_bias_history = []
289        bias_epochs = []
290        temperature_history = []
291        for epoch in range(epochs):
292            lr_scheduler.update(epoch)
293            temperature_scheduler.update(epoch)
294
295            total_loss, loss, penalty_yield, penalty_unc = train_step(
296                model,
297                tensor_data,
298                optimizer,
299                lam_yield=lam_yield,
300                lam_unc=lam_unc,
301                threshold_yield=yield_thr,
302                rel_threshold_unc=unc_thr,
303            )
304
305            # Save the history
306            loss_history.append(loss.numpy())
307            penalty_yield_history.append(penalty_yield.numpy())
308            penalty_unc_history.append(penalty_unc.numpy())
309            if epoch % 25 == 0 or epoch == epochs - 1:
310                bias_vec = model.get_bias(tensor_data)
311                mean_bias_history.append(float(np.mean(np.abs(bias_vec))))
312                bias_epochs.append(epoch)
313                temperature_history.append(float(model.temperature))
314
315            if epoch % 10 == 0 or epoch == epochs - 1:
316                lr_value = getattr(optimizer, "learning_rate", getattr(optimizer, "lr", None))
317                if hasattr(lr_value, "numpy"):
318                    lr_value = float(lr_value.numpy())
319                else:
320                    lr_value = float(lr_value)
321                print(
322                    f"[n_bins={nbins}] Epoch {epoch}: total_loss = {total_loss.numpy():.3f}, "
323                    f"base_loss = {loss.numpy():.3f}, lr = {lr_value:.5f}"
324                )
325                print("Effective boundaries:", model.get_effective_boundaries_1d())
326        # save the trained GATO model
327        checkpoint_dir = os.path.join(
328            path_plots, "checkpoints", f"{nbins}_bins"
329        )
330        os.makedirs(checkpoint_dir, exist_ok=True)
331        model.save(checkpoint_dir)
332        # Rebuild optimized histograms using effective boundaries
333        eff_boundaries = model.get_effective_boundaries_1d()
334        print(f"Optimized boundaries for {nbins} bins: {eff_boundaries}")
335
336        # check bias due to finite temperature in training
337        bias = model.get_bias(tensor_data)
338        print(f"T = {model.temperature:4.2f};  per-bin bias: {bias}")
339
340        opt_bin_edges = np.concatenate(([low], np.array(eff_boundaries), [high]))
341        h_signal_opt = create_hist(
342            data["signal"]["NN_output"],
343            weights=data["signal"]["weight"],
344            bins=opt_bin_edges,
345            name="Signal_opt",
346        )
347        opt_bkg_hists = [
348            create_hist(
349                data[proc]["NN_output"],
350                weights=data[proc]["weight"],
351                bins=opt_bin_edges,
352                name=f"{proc}_opt",
353            )
354            for proc in bkg_processes
355        ]
356
357        # Compute significance from these optimized histograms.
358        Z_opt = compute_significance_from_hists(h_signal_opt, opt_bkg_hists)
359        optimized_significances[nbins] = Z_opt
360
361        print(f"Optimized binning ({nbins} bins): Overall significance = {Z_opt:.3f}")
362
363        bias_plot_base = path_plots + f"bias_history_{nbins}bins"
364        plot_bias_history(
365            mean_bias_history,
366            bias_plot_base + ".pdf",
367            epochs=bias_epochs,
368            temp_points=temperature_history,
369            temp_label="Temperature",
370        )
371        plot_bias_history(
372            mean_bias_history,
373            bias_plot_base + "_log.pdf",
374            epochs=bias_epochs,
375            temp_points=temperature_history,
376            temp_label="Temperature",
377            log_scale=True,
378        )
379
380        opt_plot_filename = (
381            path_plots + f"NN_output_distribution_optimized_{nbins}bins.pdf"
382        )
383        plot_stacked_histograms(
384            stacked_hists=opt_bkg_hists,
385            process_labels=process_labels,
386            signal_hists=[h_signal_opt * 100],
387            signal_labels=signal_labels,
388            output_filename=opt_plot_filename,
389            axis_labels=("Toy NN output", "Events"),
390            normalize=False,
391            log=False,
392        )
393
394        plot_stacked_histograms(
395            stacked_hists=opt_bkg_hists,
396            process_labels=process_labels,
397            signal_hists=[h_signal_opt * 100],
398            signal_labels=signal_labels,
399            output_filename=opt_plot_filename.replace(".pdf", "_log.pdf"),
400            axis_labels=("Toy NN output", "Events"),
401            normalize=False,
402            log=True,
403        )
404        print(f"Optimized binning ({nbins} bins) plot saved as {opt_plot_filename}")
405
406        # Plot the loss
407        loss_plot_name = path_plots + f"history_loss_{nbins}bins.pdf"
408        plot_history(
409            history_data=loss_history,
410            output_filename=loss_plot_name,
411            y_label="Negative significance",
412            x_label="Epoch",
413            boundaries=False,
414        )
415        regularisation_plot_name = path_plots + f"history_penalty_yield{nbins}bins.pdf"
416        plot_history(
417            history_data=penalty_yield_history,
418            output_filename=regularisation_plot_name,
419            y_label="Low bkg. penalty",
420            x_label="Epoch",
421            boundaries=False,
422        )
423        regularisation_plot_name = path_plots + f"history_penalty_unc_{nbins}bins.pdf"
424        plot_history(
425            history_data=penalty_unc_history,
426            output_filename=regularisation_plot_name,
427            y_label="High bkg. unc. penalty",
428            x_label="Epoch",
429            boundaries=False,
430        )
431
432        B_sorted, rel_unc_sorted, _ = model.compute_hard_bkg_stats(tensor_data)
433        plot_yield_vs_uncertainty(
434            B_sorted,
435            rel_unc_sorted,
436            output_filename=path_plots + f"yield_vs_uncertainty_{nbins}bins_sorted.pdf",
437        )
438        plot_yield_vs_uncertainty(
439            B_sorted,
440            rel_unc_sorted,
441            log=True,
442            output_filename=path_plots
443            + f"yield_vs_uncertainty_{nbins}bins_sorted_log.pdf",
444        )
445        # plot the learned GMM in 1D
446        plot_gmm_1d(
447            model,
448            output_filename=os.path.join(path_plots, f"gmm_components_{nbins}bins.pdf"),
449            x_range=(low, high),
450            n_points=1000,
451        )
452
453    plot_significance_comparison(
454        {"": {nb: equidistant_significances[nb] for nb in equidistant_binning_options}},
455        {"": {nb: optimized_significances[nb] for nb in gato_binning_options}},
456        output_filename=path_plots + "significanceComparison.pdf",
457    )
458
459
# Run the example only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()