1D example based on sigmoids
This example uses monotonic sigmoids to approximate the bin boundaries of a single discriminant. This variant adds steepness annealing on top of the usual low-background-yield and high-uncertainty penalties.
Run:
python examples/1D_example/run_sigmoid_example.py --gato-bins 5 10 20 --epochs 300
To regenerate plots and tables from trained checkpoints without re-running the optimization, use:
python examples/analyse_sigmoid_models.py --checkpoint-root examples/1D_example/PlotsSigmoidModel/checkpoints
The outputs mirror those of the GMM example: diagnostic PDFs, boundary histories, and saved models for each category count.
Source code
Source of the 1D sigmoid toy example
1import argparse
2from gatohep.data_generation import generate_toy_data_1D
3from gatohep.losses import (
4 high_bkg_uncertainty_penalty,
5 low_bkg_penalty,
6)
7from gatohep.models import gato_sigmoid_model
8from gatohep.plotting_utils import (
9 plot_bias_history,
10 plot_history,
11 plot_significance_comparison,
12 plot_stacked_histograms,
13 plot_yield_vs_uncertainty,
14)
15from gatohep.utils import (
16 LearningRateScheduler,
17 SteepnessScheduler,
18 asymptotic_significance,
19 compute_significance_from_hists,
20 create_hist,
21 df_dict_to_tensors,
22)
23import hist
24import numpy as np
25import os
26import tensorflow as tf
27import tensorflow_probability as tfp
28
# Short alias for the TensorFlow Probability distributions namespace.
# NOTE(review): not referenced anywhere else in this script.
tfd = tfp.distributions
30
31
# Define the 1D GATO model using sigmoid boundaries
class gato_1D(gato_sigmoid_model):
    """
    One-dimensional GATO model that uses monotonic sigmoid boundaries.

    The class replicates the call/significance interface of the 1-D GMM
    example so the training script only needs to swap the model import.

    Parameters
    ----------
    n_cats : int
        Number of bins (Gaussian components in the GMM analogue).
    steepness : float, optional
        Initial slope k of the sigmoids (default: 50.0). Can be annealed
        externally, e.g. with ``SteepnessScheduler``.
    """

    def __init__(self, n_cats: int, *, steepness: float = 50.0):
        # A single discriminant, "NN_output", on [0, 1] split into n_cats bins.
        variables = [
            {"name": "NN_output", "bins": n_cats, "range": (0.0, 1.0)},
        ]
        super().__init__(variables_config=variables, global_steepness=steepness)

    def call(self, data_dict):
        """
        Compute the loss and background yields for one optimisation step.

        The loss is the negative differentiable significance of the single
        signal process ``"signal"``, as returned by
        ``get_differentiable_significance``.

        Parameters
        ----------
        data_dict : dict
            Mapping from process name to a dictionary with keys
            ``"NN_output"`` (tensor, shape Nx1) and ``"weight"`` (tensor).
            Expected to contain a ``"signal"`` entry, since that is the
            only signal label requested.

        Returns
        -------
        loss : tf.Tensor
            Scalar negative significance (to minimise).
        bkg_yield : tf.Tensor, shape (n_cats,)
            Background yields per bin.
        bkg_sum_w2 : tf.Tensor, shape (n_cats,)
            Sum of squared background weights per bin.
        """
        significances, bkg_yield, bkg_sum_w2 = self.get_differentiable_significance(
            data_dict,
            signal_labels=["signal"],
            return_details=True,
        )
        # Negate so that minimising the loss maximises the significance.
        loss = -significances["signal"]
        return loss, bkg_yield, bkg_sum_w2
84
def _parse_args(argv=None):
    """Parse the command-line options of the 1-D sigmoid example.

    Parameters
    ----------
    argv : list[str] | None
        Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Parsed options. Exits via ``parser.error`` on invalid values.
    """
    parser = argparse.ArgumentParser(description="1-D GATO optimisation on toy data")
    parser.add_argument(
        "--epochs",
        type=int,
        default=250,
        help="number of training epochs (default: 250)",
    )
    parser.add_argument(
        "--gato-bins",
        type=int,
        nargs="+",
        default=[3, 5, 10],
        metavar="N",
        help="List of target bin counts for the GATO run (default: 3,5,10)",
    )
    parser.add_argument(
        "--lam-yield",
        type=float,
        default=0.0,
        help=r"lambda for the low-background-yield penalty (default: 0)",
    )
    parser.add_argument(
        "--lam-unc",
        type=float,
        default=0.0,
        help=r"lambda for the high-uncertainty penalty (default: 0)",
    )
    parser.add_argument(
        "--thr-yield",
        type=float,
        default=5.0,
        help="Threshold (events) below which the low-yield "
        "penalty turns on (default: 5)",
    )
    parser.add_argument(
        "--thr-unc",
        type=float,
        default=0.20,
        help="Relative uncertainty above which the uncertainty "
        "penalty turns on (default: 0.20)",
    )
    parser.add_argument(
        "--n-bkg",
        type=int,
        default=300000,
        help="Total number of background events to generate.",
    )
    parser.add_argument(
        "--out",
        type=str,
        default="PlotsSigmoidModel",
        help="Name of the output directory, created under examples/1D_example/ "
        '(default: "PlotsSigmoidModel")',
    )
    args = parser.parse_args(argv)

    # Zero epochs would leave every history empty for the schedulers/plots.
    if args.epochs < 1:
        parser.error("--epochs must be at least 1")
    if any(n < 1 for n in args.gato_bins):
        parser.error("--gato-bins values must be positive integers")
    return args


def _plot_stacked_lin_log(
    stacked_hists,
    process_labels,
    signal_hists,
    signal_labels,
    output_filename,
    axis_labels,
):
    """Write a linear-scale stacked plot and a log-scale twin ("_log.pdf")."""
    variants = (
        (False, output_filename),
        (True, output_filename.replace(".pdf", "_log.pdf")),
    )
    for log, filename in variants:
        plot_stacked_histograms(
            stacked_hists=stacked_hists,
            process_labels=process_labels,
            signal_hists=signal_hists,
            signal_labels=signal_labels,
            output_filename=filename,
            axis_labels=axis_labels,
            normalize=False,
            log=log,
        )


def _make_train_step():
    """Build a fresh ``tf.function`` that performs one optimisation step.

    A new function is created for every model: ``tf.function`` only allows
    variables (such as the optimizer's slot variables) to be created during
    its first trace, so a single function cannot be reused across models.
    """

    @tf.function
    def train_step(
        model,
        tensor_data,
        optimizer,
        lam_yield=0.0,
        lam_unc=0.0,
        threshold_yield=5,
        rel_threshold_unc=0.2,
    ):
        """Apply one gradient step.

        Returns (total_loss, loss, penalty_yield, penalty_unc).
        """
        with tf.GradientTape() as tape:
            # call() returns (loss, background yields, background sum of w^2)
            loss, B, B_sumw2 = model.call(tensor_data)
            # Penalties must be computed under the tape so they get gradients.
            penalty_yield = low_bkg_penalty(B, threshold=threshold_yield)
            penalty_unc = high_bkg_uncertainty_penalty(
                B_sumw2, B, rel_threshold=rel_threshold_unc
            )
            total_loss = loss + lam_yield * penalty_yield + lam_unc * penalty_unc
        grads = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return total_loss, loss, penalty_yield, penalty_unc

    return train_step


def _current_lr(optimizer):
    """Return the optimizer's current learning rate as a float.

    Checks ``learning_rate`` first and falls back to the legacy ``lr``
    attribute. Returns NaN if neither is available.
    """
    lr_value = getattr(optimizer, "learning_rate", None)
    if lr_value is None:
        lr_value = getattr(optimizer, "lr", None)
    if lr_value is None:
        return float("nan")
    if hasattr(lr_value, "numpy"):
        return float(lr_value.numpy())
    return float(lr_value)


# main: Generate data, run fixed binning and optimization, then compare.
def main():
    """Run the 1-D sigmoid GATO example end to end.

    Steps:

    1. Generate toy data.
    2. Evaluate equidistant binnings as a baseline.
    3. Train one sigmoid GATO model per requested bin count.
    4. Write diagnostic plots, checkpoints and a final significance
       comparison under ``examples/1D_example/<--out>/``.
    """
    args = _parse_args()

    gato_binning_options = args.gato_bins
    epochs = args.epochs
    lam_yield = args.lam_yield
    lam_unc = args.lam_unc
    yield_thr = args.thr_yield
    unc_thr = args.thr_unc
    n_bkg = args.n_bkg

    # 1. Generate toy data
    data = generate_toy_data_1D(
        n_signal=100000,
        n_bkg=n_bkg,
        seed=42,
    )

    # Fine-grained histograms on [low, high]. They are rebinned below for
    # plotting and for the equidistant-binning baselines.
    n_bins = 1500
    low = 0.0
    high = 1.0
    hist_signal = create_hist(
        data["signal"]["NN_output"],
        weights=data["signal"]["weight"],
        bins=n_bins,
        low=low,
        high=high,
        name="Signal",
    )
    bkg_processes = [f"bkg{i}" for i in range(1, 6)]
    bkg_hists = [
        create_hist(
            data[proc]["NN_output"],
            weights=data[proc]["weight"],
            bins=n_bins,
            low=low,
            high=high,
            name=f"{proc.capitalize()}",
        )
        for proc in bkg_processes
    ]

    process_labels = [f"Background {i}" for i in range(1, len(bkg_processes) + 1)]
    # Signal is scaled by 100 in every stacked plot so it stays visible.
    signal_labels = ["Signal x 100"]

    # Equidistant binnings used as the baseline in the comparison plot.
    equidistant_binning_options = [2, 5, 10, 20]
    equidistant_significances = {}
    optimized_significances = {}

    path_plots = f"examples/1D_example/{args.out}/"
    os.makedirs(path_plots, exist_ok=True)

    # Overview of the input distributions (1500 fine bins merged 30 at a time).
    _plot_stacked_lin_log(
        stacked_hists=[bkg_hist[:: hist.rebin(30)] for bkg_hist in bkg_hists],
        process_labels=process_labels,
        signal_hists=[hist_signal[:: hist.rebin(30)] * 100],
        signal_labels=signal_labels,
        output_filename=path_plots + "toy_data.pdf",
        axis_labels=("Toy discriminant", "Events"),
    )

    # Fixed binning significance
    for nbins in equidistant_binning_options:
        # Merge `factor` fine bins into one. The options above all divide
        # n_bins exactly, so this yields exactly `nbins` bins.
        factor = hist_signal.axes[0].size // nbins
        hist_signal_rb = hist_signal[:: hist.rebin(factor)]
        bkg_hists_rb = [h[:: hist.rebin(factor)] for h in bkg_hists]

        Z_equidistant = compute_significance_from_hists(hist_signal_rb, bkg_hists_rb)
        equidistant_significances[nbins] = Z_equidistant
        print(
            f"Fixed binning ({nbins} bins): Overall significance = {Z_equidistant:.3f}"
        )

        fixed_plot_filename = (
            path_plots + f"NN_output_distribution_fixed_{nbins}bins.pdf"
        )
        _plot_stacked_lin_log(
            stacked_hists=bkg_hists_rb,
            process_labels=process_labels,
            signal_hists=[hist_signal_rb * 100],
            signal_labels=signal_labels,
            output_filename=fixed_plot_filename,
            axis_labels=("Toy NN output", "Toy events"),
        )
        print(f"Fixed binning ({nbins} bins) plot saved as {fixed_plot_filename}")

    # GATO part
    tensor_data = df_dict_to_tensors(data)
    for nbins in gato_binning_options:
        train_step = _make_train_step()

        # --- Optimization: create a model instance with n_cats = nbins ---
        model = gato_1D(n_cats=nbins, steepness=50.0)

        optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
        lr_scheduler = LearningRateScheduler(
            optimizer,
            lr_initial=0.1,
            lr_final=0.001,
            total_epochs=epochs,
            mode="cosine",
        )

        # Anneal the sigmoid slope k from 50 to 10000 (larger k = sharper
        # bin edges) over the same number of epochs as the learning rate.
        steepness_scheduler = SteepnessScheduler(
            model,
            t_initial=50.0,  # starting k
            t_final=10000.0,  # final k
            total_epochs=epochs,
            mode="cosine",
        )

        loss_history = []
        penalty_yield_history = []
        penalty_unc_history = []
        boundary_history = []
        mean_bias_history = []
        bias_epochs = []
        steepness_history = []
        for epoch in range(epochs):
            lr_scheduler.update(epoch)
            steepness_scheduler.update(epoch)  # adjust all k_j in-place

            total_loss, loss, penalty_yield, penalty_unc = train_step(
                model,
                tensor_data,
                optimizer,
                lam_yield=lam_yield,
                lam_unc=lam_unc,
                threshold_yield=yield_thr,
                rel_threshold_unc=unc_thr,
            )

            # Save the history
            loss_history.append(loss.numpy())
            penalty_yield_history.append(penalty_yield.numpy())
            penalty_unc_history.append(penalty_unc.numpy())
            boundary_history.append(model.calculate_boundaries().numpy())

            is_last_epoch = epoch == epochs - 1
            # Sample the finite-steepness bias every 25 epochs and at the end.
            if epoch % 25 == 0 or is_last_epoch:
                bias_vec = model.get_bias(tensor_data)
                mean_bias_history.append(float(np.mean(np.abs(bias_vec))))
                bias_epochs.append(epoch)
                steepness_history.append(float(model.var_cfg[0]["k"].numpy()))

            if epoch % 10 == 0 or is_last_epoch:
                print(
                    f"[n_bins={nbins}] Epoch {epoch}: total_loss = {total_loss.numpy():.3f}, "
                    f"base_loss = {loss.numpy():.3f}, lr = {_current_lr(optimizer):.5f}"
                )
                print("Effective boundaries:", model.calculate_boundaries())

        # save the trained GATO model
        checkpoint_dir = os.path.join(path_plots, "checkpoints", f"{nbins}_bins")
        os.makedirs(checkpoint_dir, exist_ok=True)
        model.save(checkpoint_dir)

        # Rebuild optimized histograms using effective boundaries
        eff_boundaries = model.calculate_boundaries()
        print(f"Optimized boundaries for {nbins} bins: {eff_boundaries}")

        # check bias due to finite temperature in training
        bias = model.get_bias(tensor_data)
        print(f"Steepness = {float(model.var_cfg[0]['k'].numpy()):.1f}; per-bin bias: {bias}")

        opt_bin_edges = np.concatenate(([low], np.array(eff_boundaries), [high]))
        h_signal_opt = create_hist(
            data["signal"]["NN_output"],
            weights=data["signal"]["weight"],
            bins=opt_bin_edges,
            name="Signal_opt",
        )
        opt_bkg_hists = [
            create_hist(
                data[proc]["NN_output"],
                weights=data[proc]["weight"],
                bins=opt_bin_edges,
                name=f"{proc}_opt",
            )
            for proc in bkg_processes
        ]

        # Compute significance from these optimized (hard-cut) histograms.
        Z_opt = compute_significance_from_hists(h_signal_opt, opt_bkg_hists)
        optimized_significances[nbins] = Z_opt
        print(f"Optimized binning ({nbins} bins): Overall significance = {Z_opt:.3f}")

        opt_plot_filename = (
            path_plots + f"NN_output_distribution_optimized_{nbins}bins.pdf"
        )
        _plot_stacked_lin_log(
            stacked_hists=opt_bkg_hists,
            process_labels=process_labels,
            signal_hists=[h_signal_opt * 100],
            signal_labels=signal_labels,
            output_filename=opt_plot_filename,
            axis_labels=("Toy NN output", "Events"),
        )
        print(f"Optimized binning ({nbins} bins) plot saved as {opt_plot_filename}")

        bias_plot_base = path_plots + f"bias_history_{nbins}bins"
        plot_bias_history(
            mean_bias_history,
            bias_plot_base + ".pdf",
            epochs=bias_epochs,
            temp_points=steepness_history,
            temp_label="Steepness",
        )
        plot_bias_history(
            mean_bias_history,
            bias_plot_base + "_log.pdf",
            epochs=bias_epochs,
            temp_points=steepness_history,
            temp_label="Steepness",
            log_scale=True,
        )

        # Training histories: (data, file name, y label, is boundary history).
        history_plots = [
            (loss_history, f"history_loss_{nbins}bins.pdf",
             "Negative significance", False),
            (penalty_yield_history, f"history_penalty_yield_{nbins}bins.pdf",
             "Low bkg. penalty", False),
            (penalty_unc_history, f"history_penalty_unc_{nbins}bins.pdf",
             "High bkg. unc. penalty", False),
            (boundary_history, f"history_boundaries_{nbins}bins.pdf",
             "Boundary position", True),
        ]
        for history, filename, y_label, is_boundary in history_plots:
            plot_history(
                history_data=history,
                output_filename=path_plots + filename,
                y_label=y_label,
                x_label="Epoch",
                boundaries=is_boundary,
            )

        B, rel_unc = model.compute_hard_bkg_stats(tensor_data)
        plot_yield_vs_uncertainty(
            B,
            rel_unc,
            output_filename=path_plots + f"yield_vs_uncertainty_{nbins}bins_sorted.pdf",
        )
        plot_yield_vs_uncertainty(
            B,
            rel_unc,
            log=True,
            output_filename=path_plots
            + f"yield_vs_uncertainty_{nbins}bins_sorted_log.pdf",
        )

    plot_significance_comparison(
        {"": {nb: equidistant_significances[nb] for nb in equidistant_binning_options}},
        {"": {nb: optimized_significances[nb] for nb in gato_binning_options}},
        output_filename=path_plots + "significanceComparison.pdf",
    )
483 + f"yield_vs_uncertainty_{nbins}bins_sorted_log.pdf",
484 )
485
486 plot_significance_comparison(
487 {"": {nb: equidistant_significances[nb] for nb in equidistant_binning_options}},
488 {"": {nb: optimized_significances[nb] for nb in gato_binning_options}},
489 output_filename=path_plots + "significanceComparison.pdf",
490 )
491
492
493if __name__ == "__main__":
494 main()