1D example based on sigmoids

Use monotonic sigmoids to approximate bin boundaries in a single discriminant. This variant exposes steepness annealing alongside the usual yield/uncertainty penalties.

Run:

python examples/1D_example/run_sigmoid_example.py --gato-bins 5 10 20 --epochs 300

To regenerate plots and tables from trained checkpoints without re-running the optimization, use:

python examples/analyse_sigmoid_models.py --checkpoint-root examples/1D_example/PlotsSigmoidModel/checkpoints

Outputs mirror the GMM example: diagnostic PDFs, boundary histories, and saved models for each category count.

Source code

Source of the 1D sigmoid toy example
  1import argparse
  2from gatohep.data_generation import generate_toy_data_1D
  3from gatohep.losses import (
  4    high_bkg_uncertainty_penalty,
  5    low_bkg_penalty,
  6)
  7from gatohep.models import gato_sigmoid_model
  8from gatohep.plotting_utils import (
  9    plot_bias_history,
 10    plot_history,
 11    plot_significance_comparison,
 12    plot_stacked_histograms,
 13    plot_yield_vs_uncertainty,
 14)
 15from gatohep.utils import (
 16    LearningRateScheduler,
 17    SteepnessScheduler,
 18    asymptotic_significance,
 19    compute_significance_from_hists,
 20    create_hist,
 21    df_dict_to_tensors,
 22)
 23import hist
 24import numpy as np
 25import os
 26import tensorflow as tf
 27import tensorflow_probability as tfp
 28
 29tfd = tfp.distributions
 30
 31
 32# Define the 1D GATO model using sigmoid boundaries
 33class gato_1D(gato_sigmoid_model):
 34    """
 35    One-dimensional GATO model that uses monotonic sigmoid boundaries.
 36
 37    The class replicates the call/significance logic of the 1-D GMM example
 38    so the training script needs only to replace the model import.
 39
 40    Parameters
 41    ----------
 42    n_cats : int
 43        Number of bins (Gaussian components in the GMM analogue).
 44    steepness : float, optional
 45        Initial slope k of the sigmoids.  Can be annealed externally.
 46    """
 47
 48    def __init__(self, n_cats: int, *, steepness: float = 50.0):
 49        variables = [
 50            {"name": "NN_output", "bins": n_cats, "range": (0.0, 1.0)},
 51        ]
 52        super().__init__(variables_config=variables, global_steepness=steepness)
 53
 54    def call(self, data_dict):
 55        """
 56        Compute the loss and background yields for one optimisation step.
 57
 58        The loss is  -sqrt(Z1 * Z2)  where Z1, Z2 are the per-signal
 59        significances obtained with asymptotic formulae.
 60
 61        Parameters
 62        ----------
 63        data_dict : dict
 64            Mapping from process name to a dictionary with keys
 65            ``"NN_output"`` (tensor, shape Nx1) and ``"weight"`` (tensor).
 66
 67        Returns
 68        -------
 69        loss : tf.Tensor  scalar
 70            Negative combined significance (to minimise).
 71        bkg_y : tf.Tensor  (n_cats,)
 72            Background yields per bin.
 73        bkg_w2 : tf.Tensor (n_cats,)
 74            Sum of squared weights per bin for the background.
 75        """
 76        significances, bkg_yield, bkg_sum_w2 = self.get_differentiable_significance(
 77            data_dict,
 78            signal_labels=["signal"],
 79            return_details=True,
 80        )
 81        loss = -significances["signal"]
 82        return loss, bkg_yield, bkg_sum_w2
 83
 84
 85# main: Generate data, run fixed binning and optimization, then compare.
 86def main():
 87    parser = argparse.ArgumentParser(description="1-D GATO optimisation on toy data")
 88    parser.add_argument(
 89        "--epochs",
 90        type=int,
 91        default=250,
 92        help="number of training epochs (default: 250)",
 93    )
 94
 95    parser.add_argument(
 96        "--gato-bins",
 97        type=int,
 98        nargs="+",
 99        default=[3, 5, 10],
100        metavar="N",
101        help="List of target bin counts for the GATO run (default: 3,5,10)",
102    )
103
104    parser.add_argument(
105        "--lam-yield",
106        type=float,
107        default=0.0,
108        help=r"lambda for the low-background-yield penalty (default: 0)",
109    )
110
111    parser.add_argument(
112        "--lam-unc",
113        type=float,
114        default=0.0,
115        help=r"lambda for the high-uncertainty penalty (default: 0)",
116    )
117
118    parser.add_argument(
119        "--thr-yield",
120        type=float,
121        default=5.0,
122        help="Threshold (events) below which the low-yield "
123        "penalty turns on (default: 10)",
124    )
125
126    parser.add_argument(
127        "--thr-unc",
128        type=float,
129        default=0.20,
130        help="Relative uncertainty above which the uncertainty "
131        "penalty turns on (default: 0.20)",
132    )
133
134    parser.add_argument(
135        "--n-bkg",
136        type=int,
137        default=300000,
138        help="Total number of background events to generate.",
139    )
140
141    parser.add_argument(
142        "--out",
143        type=str,
144        default="PlotsSigmoidModel",
145        help='Suffix for the output directory. Default: "Plots"',
146    )
147
148    args = parser.parse_args()
149
150    gato_binning_options = args.gato_bins
151    epochs = args.epochs
152    lam_yield = args.lam_yield
153    lam_unc = args.lam_unc
154    yield_thr = args.thr_yield
155    unc_thr = args.thr_unc
156    n_bkg = args.n_bkg
157
158    # 1. Generate toy data
159    data = generate_toy_data_1D(
160        n_signal=100000,
161        n_bkg=n_bkg,
162        seed=42,
163    )
164
165    # Create fixed histograms, will be rebinned afterwards
166    n_bins = 1500
167    low = 0.0
168    high = 1.0
169    hist_signal = create_hist(
170        data["signal"]["NN_output"],
171        weights=data["signal"]["weight"],
172        bins=n_bins,
173        low=low,
174        high=high,
175        name="Signal",
176    )
177    bkg_processes = [f"bkg{i}" for i in range(1, 6)]
178    bkg_hists = [
179        create_hist(
180            data[proc]["NN_output"],
181            weights=data[proc]["weight"],
182            bins=n_bins,
183            low=low,
184            high=high,
185            name=f"{proc.capitalize()}",
186        )
187        for proc in bkg_processes
188    ]
189
190    # plot the backgrounds:
191    process_labels = [f"Background {i}" for i in range(1, len(bkg_processes) + 1)]
192    signal_labels = ["Signal x 100"]
193
194    # For demonstration, we compare multiple binning schemes.
195    equidistant_binning_options = [2, 5, 10, 20]
196    equidistant_significances = {}
197    optimized_significances = {}
198
199    path_plots = f"examples/1D_example/{args.out}/"
200    os.makedirs(path_plots, exist_ok=True)
201    fixed_plot_filename = path_plots + "toy_data.pdf"
202    plot_stacked_histograms(
203        stacked_hists=[bkg_hist[:: hist.rebin(30)] for bkg_hist in bkg_hists],
204        process_labels=process_labels,
205        signal_hists=[hist_signal[:: hist.rebin(30)] * 100],
206        signal_labels=signal_labels,
207        output_filename=fixed_plot_filename,
208        axis_labels=("Toy discriminant", "Events"),
209        normalize=False,
210        log=False,
211    )
212    plot_stacked_histograms(
213        stacked_hists=[bkg_hist[:: hist.rebin(30)] for bkg_hist in bkg_hists],
214        process_labels=process_labels,
215        signal_hists=[hist_signal[:: hist.rebin(30)] * 100],
216        signal_labels=signal_labels,
217        output_filename=fixed_plot_filename.replace(".pdf", "_log.pdf"),
218        axis_labels=("Toy discriminant", "Events"),
219        normalize=False,
220        log=True,
221    )
222
223    # Fixed binning significance
224    for nbins in equidistant_binning_options:
225        nbins_hist = hist_signal.axes[0].size
226        factor = int(nbins_hist / nbins)
227        hist_signal_rb = hist_signal[:: hist.rebin(factor)]
228        bkg_hists_rb = [h[:: hist.rebin(factor)] for h in bkg_hists]
229
230        Z_equidistant = compute_significance_from_hists(hist_signal_rb, bkg_hists_rb)
231        equidistant_significances[nbins] = Z_equidistant
232        print(
233            f"Fixed binning ({nbins} bins): Overall significance = {Z_equidistant:.3f}"
234        )
235
236        fixed_plot_filename = (
237            path_plots + f"NN_output_distribution_fixed_{nbins}bins.pdf"
238        )
239        plot_stacked_histograms(
240            stacked_hists=bkg_hists_rb,
241            process_labels=process_labels,
242            signal_hists=[hist_signal_rb * 100],
243            signal_labels=signal_labels,
244            output_filename=fixed_plot_filename,
245            axis_labels=("Toy NN output", "Toy events"),
246            normalize=False,
247            log=False,
248        )
249        plot_stacked_histograms(
250            stacked_hists=bkg_hists_rb,
251            process_labels=process_labels,
252            signal_hists=[hist_signal_rb * 100],
253            signal_labels=signal_labels,
254            output_filename=fixed_plot_filename.replace(".pdf", "_log.pdf"),
255            axis_labels=("Toy NN output", "Toy events"),
256            normalize=False,
257            log=True,
258        )
259        print(f"Fixed binning ({nbins} bins) plot saved as {fixed_plot_filename}")
260
261    # GATO part
262    tensor_data = df_dict_to_tensors(data)
263    for nbins in gato_binning_options:
264
265        @tf.function
266        def train_step(
267            model,
268            tensor_data,
269            optimizer,
270            lam_yield=0.0,
271            lam_unc=0.0,
272            threshold_yield=5,
273            rel_threshold_unc=0.2,
274        ):
275            with tf.GradientTape() as tape:
276                # assume your call() now returns (loss, B, B_sumsq)
277                loss, B, B_sumw2 = model.call(tensor_data)
278
279                penalty_yield = low_bkg_penalty(B, threshold=threshold_yield)
280                penalty_unc = high_bkg_uncertainty_penalty(
281                    B_sumw2, B, rel_threshold=rel_threshold_unc
282                )
283                total_loss = loss + lam_yield * penalty_yield + lam_unc * penalty_unc
284            grads = tape.gradient(total_loss, model.trainable_variables)
285            optimizer.apply_gradients(zip(grads, model.trainable_variables))
286            return total_loss, loss, penalty_yield, penalty_unc
287
288        # --- Optimization: create a model instance with n_cats = nbins ---
289        model = gato_1D(n_cats=nbins, steepness=50.0)
290
291        optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
292        lr_scheduler = LearningRateScheduler(
293            optimizer,
294            lr_initial=0.1,
295            lr_final=0.001,
296            total_epochs=epochs,
297            mode="cosine",
298        )
299
300        # steepness scheduler
301        steepness_scheduler = SteepnessScheduler(
302            model,
303            t_initial=50.0,  # starting k
304            t_final=10000.0,  # final k
305            total_epochs=args.epochs,
306            mode="cosine",
307        )
308
309        loss_history = []
310        penalty_yield_history = []
311        penalty_unc_history = []
312        boundary_history = []
313        mean_bias_history = []
314        bias_epochs = []
315        steepness_history = []
316        for epoch in range(epochs):
317            lr_scheduler.update(epoch)
318            steepness_scheduler.update(epoch)  # adjust all k_j in-place
319
320            total_loss, loss, penalty_yield, penalty_unc = train_step(
321                model,
322                tensor_data,
323                optimizer,
324                lam_yield=lam_yield,
325                lam_unc=lam_unc,
326                threshold_yield=yield_thr,
327                rel_threshold_unc=unc_thr,
328            )
329
330            # Save the history
331            loss_history.append(loss.numpy())
332            penalty_yield_history.append(penalty_yield.numpy())
333            penalty_unc_history.append(penalty_unc.numpy())
334            boundary_history.append(
335                model.calculate_boundaries().numpy()
336            )  # list(ndarray)
337            if epoch % 25 == 0 or epoch == epochs - 1:
338                bias_vec = model.get_bias(tensor_data)
339                mean_bias_history.append(float(np.mean(np.abs(bias_vec))))
340                bias_epochs.append(epoch)
341                steepness_history.append(float(model.var_cfg[0]["k"].numpy()))
342
343            if epoch % 10 == 0 or epoch == epochs - 1:
344                lr_value = getattr(optimizer, "learning_rate", getattr(optimizer, "lr", None))
345                if hasattr(lr_value, "numpy"):
346                    lr_value = float(lr_value.numpy())
347                else:
348                    lr_value = float(lr_value)
349                print(
350                    f"[n_bins={nbins}] Epoch {epoch}: total_loss = {total_loss.numpy():.3f}, "
351                    f"base_loss = {loss.numpy():.3f}, lr = {lr_value:.5f}"
352                )
353                print("Effective boundaries:", model.calculate_boundaries())
354        # save the trained GATO model
355        checkpoint_dir = os.path.join(
356            path_plots, "checkpoints", f"{nbins}_bins"
357        )
358        os.makedirs(checkpoint_dir, exist_ok=True)
359        model.save(checkpoint_dir)
360        # Rebuild optimized histograms using effective boundaries
361        eff_boundaries = model.calculate_boundaries()
362        print(f"Optimized boundaries for {nbins} bins: {eff_boundaries}")
363
364        # check bias due to finite temperature in training
365        bias = model.get_bias(tensor_data)
366        print(f"Steepness = {float(model.var_cfg[0]['k'].numpy()):.1f};  per-bin bias: {bias}")
367
368        opt_bin_edges = np.concatenate(([low], np.array(eff_boundaries), [high]))
369        h_signal_opt = create_hist(
370            data["signal"]["NN_output"],
371            weights=data["signal"]["weight"],
372            bins=opt_bin_edges,
373            name="Signal_opt",
374        )
375        opt_bkg_hists = [
376            create_hist(
377                data[proc]["NN_output"],
378                weights=data[proc]["weight"],
379                bins=opt_bin_edges,
380                name=f"{proc}_opt",
381            )
382            for proc in bkg_processes
383        ]
384
385        # Compute significance from these optimized histograms.
386        Z_opt = compute_significance_from_hists(h_signal_opt, opt_bkg_hists)
387        optimized_significances[nbins] = Z_opt
388
389        print(f"Optimized binning ({nbins} bins): Overall significance = {Z_opt:.3f}")
390
391        opt_plot_filename = (
392            path_plots + f"NN_output_distribution_optimized_{nbins}bins.pdf"
393        )
394        plot_stacked_histograms(
395            stacked_hists=opt_bkg_hists,
396            process_labels=process_labels,
397            signal_hists=[h_signal_opt * 100],
398            signal_labels=signal_labels,
399            output_filename=opt_plot_filename,
400            axis_labels=("Toy NN output", "Events"),
401            normalize=False,
402            log=False,
403        )
404
405        plot_stacked_histograms(
406            stacked_hists=opt_bkg_hists,
407            process_labels=process_labels,
408            signal_hists=[h_signal_opt * 100],
409            signal_labels=signal_labels,
410            output_filename=opt_plot_filename.replace(".pdf", "_log.pdf"),
411            axis_labels=("Toy NN output", "Events"),
412            normalize=False,
413            log=True,
414        )
415        print(f"Optimized binning ({nbins} bins) plot saved as {opt_plot_filename}")
416
417        bias_plot_base = path_plots + f"bias_history_{nbins}bins"
418        plot_bias_history(
419            mean_bias_history,
420            bias_plot_base + ".pdf",
421            epochs=bias_epochs,
422            temp_points=steepness_history,
423            temp_label="Steepness",
424        )
425        plot_bias_history(
426            mean_bias_history,
427            bias_plot_base + "_log.pdf",
428            epochs=bias_epochs,
429            temp_points=steepness_history,
430            temp_label="Steepness",
431            log_scale=True,
432        )
433
434        # Plot the loss
435        loss_plot_name = path_plots + f"history_loss_{nbins}bins.pdf"
436        plot_history(
437            history_data=loss_history,
438            output_filename=loss_plot_name,
439            y_label="Negative significance",
440            x_label="Epoch",
441            boundaries=False,
442        )
443        regularisation_plot_name = path_plots + f"history_penalty_yield{nbins}bins.pdf"
444        plot_history(
445            history_data=penalty_yield_history,
446            output_filename=regularisation_plot_name,
447            y_label="Low bkg. penalty",
448            x_label="Epoch",
449            boundaries=False,
450        )
451        regularisation_plot_name = path_plots + f"history_penalty_unc_{nbins}bins.pdf"
452        plot_history(
453            history_data=penalty_unc_history,
454            output_filename=regularisation_plot_name,
455            y_label="High bkg. unc. penalty",
456            x_label="Epoch",
457            boundaries=False,
458        )
459
460        boundary_plot_name = path_plots + f"history_boundaries_{nbins}bins.pdf"
461        plot_history(
462            history_data=boundary_history,
463            output_filename=boundary_plot_name,
464            y_label="Boundary position",
465            x_label="Epoch",
466            boundaries=True,
467        )
468
469        (
470            B,
471            rel_unc,
472        ) = model.compute_hard_bkg_stats(tensor_data)
473        plot_yield_vs_uncertainty(
474            B,
475            rel_unc,
476            output_filename=path_plots + f"yield_vs_uncertainty_{nbins}bins_sorted.pdf",
477        )
478        plot_yield_vs_uncertainty(
479            B,
480            rel_unc,
481            log=True,
482            output_filename=path_plots
483            + f"yield_vs_uncertainty_{nbins}bins_sorted_log.pdf",
484        )
485
486    plot_significance_comparison(
487        {"": {nb: equidistant_significances[nb] for nb in equidistant_binning_options}},
488        {"": {nb: optimized_significances[nb] for nb in gato_binning_options}},
489        output_filename=path_plots + "significanceComparison.pdf",
490    )
491
492
493if __name__ == "__main__":
494    main()