1D GMM example#
Optimize bin boundaries for a single discriminant using a Gaussian mixture model. The script builds the toy dataset, trains multiple category counts, and compares GATO-derived significances to equidistant baselines.
Run:
python examples/1D_example/run_gmm_example.py --gato-bins 5 10 20 --epochs 300
Key outputs#
Stacked histograms for both equidistant and optimized binning schemes.
Loss, boundary and penalty histories saved under examples/1D_example/Plots*/.
checkpoints/<N>_bins directories storing model weights for later inspection.
Source code#
Source of the 1D GMM toy example#
1import argparse
2from gatohep.data_generation import generate_toy_data_1D
3from gatohep.losses import (
4 high_bkg_uncertainty_penalty,
5 low_bkg_penalty,
6)
7from gatohep.models import (
8 gato_gmm_model,
9)
10from gatohep.plotting_utils import (
11 plot_bias_history,
12 plot_gmm_1d,
13 plot_history,
14 plot_significance_comparison,
15 plot_stacked_histograms,
16 plot_yield_vs_uncertainty,
17)
18from gatohep.utils import (
19 LearningRateScheduler,
20 TemperatureScheduler,
21 asymptotic_significance,
22 compute_significance_from_hists,
23 create_hist,
24 df_dict_to_tensors,
25)
26import hist
27import numpy as np
28import tensorflow as tf
29import tensorflow_probability as tfp
30import os
31
32tfd = tfp.distributions
33
34
# One-dimensional GATO model for the toy example, built on gato_gmm_model.
class gato_1D(gato_gmm_model):
    """GMM-based GATO model specialised to a single discriminant.

    The base class is configured with ``dim=1``, a mean range of
    ``(0.0, 1.0)`` and ``cov_offdiag_damping=0.1``; the number of
    categories, the temperature and the mean normalisation are forwarded
    from the constructor arguments.
    """

    def __init__(self, n_cats, temperature=1.0, mean_norm="sigmoid"):
        # The dummy NN output is already in (0, 1), hence the fixed mean range.
        base_config = dict(
            n_cats=n_cats,
            dim=1,
            temperature=temperature,
            mean_norm=mean_norm,
            mean_range=(0.0, 1.0),
            cov_offdiag_damping=0.1,
        )
        super().__init__(**base_config)

    def call(self, data_dict):
        """
        Evaluate the training loss for the single "signal" process.

        The base-class helper ``get_differentiable_significance`` is called
        with ``return_details=True`` and its three results are unpacked.

        Returns a tuple ``(loss, bkg_yield, bkg_sum_w2)`` where ``loss`` is
        the negated "signal" entry of the significances mapping and the other
        two values are passed through unchanged from the helper.
        """
        details = self.get_differentiable_significance(
            data_dict,
            signal_labels=["signal"],
            return_details=True,
        )
        z_by_signal, background_yield, background_sum_w2 = details
        return -z_by_signal["signal"], background_yield, background_sum_w2
59
60
def _build_arg_parser():
    """Create the command-line parser for the 1-D GATO toy example."""
    parser = argparse.ArgumentParser(description="1-D GATO optimisation on toy data")
    parser.add_argument(
        "--epochs",
        type=int,
        default=250,
        help="number of training epochs (default: 250)",
    )

    parser.add_argument(
        "--gato-bins",
        type=int,
        nargs="+",
        default=[3, 5, 10],
        metavar="N",
        help="List of target bin counts for the GATO run (default: 3,5,10)",
    )

    parser.add_argument(
        "--lam-yield",
        type=float,
        default=0.0,
        help=r"lambda for the low-background-yield penalty (default: 0)",
    )

    parser.add_argument(
        "--lam-unc",
        type=float,
        default=0.0,
        help=r"lambda for the high-uncertainty penalty (default: 0)",
    )

    parser.add_argument(
        "--thr-yield",
        type=float,
        default=5.0,
        # The help text previously advertised a default of 10, which did not
        # match the actual default above.
        help="Threshold (events) below which the low-yield "
        "penalty turns on (default: 5)",
    )

    parser.add_argument(
        "--thr-unc",
        type=float,
        default=0.20,
        help="Relative uncertainty above which the uncertainty "
        "penalty turns on (default: 0.20)",
    )

    parser.add_argument(
        "--n-bkg",
        type=int,
        default=300000,
        help="Total number of background events to generate.",
    )

    parser.add_argument(
        "--out",
        type=str,
        default="Plots",
        help='Suffix for the output directory. Default: "Plots"',
    )
    return parser


def _plot_stacked_lin_and_log(
    *,
    stacked_hists,
    signal_hists,
    process_labels,
    signal_labels,
    output_filename,
    axis_labels,
):
    """Draw the same stacked histogram twice: linear scale, then log scale.

    The log-scale version is written next to ``output_filename`` with
    ``".pdf"`` replaced by ``"_log.pdf"``.
    """
    variants = (
        (False, output_filename),
        (True, output_filename.replace(".pdf", "_log.pdf")),
    )
    for use_log, filename in variants:
        plot_stacked_histograms(
            stacked_hists=stacked_hists,
            process_labels=process_labels,
            signal_hists=signal_hists,
            signal_labels=signal_labels,
            output_filename=filename,
            axis_labels=axis_labels,
            normalize=False,
            log=use_log,
        )


# main: Generate data, run fixed binning and optimization, then compare.
def main():
    """Run the 1-D GMM toy example end to end.

    Steps, in order:

    1. Parse the command-line options and generate the toy dataset.
    2. Fill finely binned histograms and plot the raw toy data.
    3. For each equidistant binning, rebin, compute the significance and plot.
    4. For each requested GATO bin count, train a ``gato_1D`` model, save it,
       rebuild histograms from the learned boundaries, compute the
       significance and write the diagnostic plots.
    5. Plot the equidistant vs. optimized significance comparison.

    All plots and checkpoints are written below
    ``examples/1D_example/<--out>/``.
    """
    args = _build_arg_parser().parse_args()

    gato_binning_options = args.gato_bins
    epochs = args.epochs
    lam_yield = args.lam_yield
    lam_unc = args.lam_unc
    yield_thr = args.thr_yield
    unc_thr = args.thr_unc
    n_bkg = args.n_bkg

    # 1. Generate toy data
    data = generate_toy_data_1D(
        n_signal=100000,
        n_bkg=n_bkg,
        seed=42,
    )

    # Create fixed histograms, will be rebinned afterwards
    n_bins = 1500
    low = 0.0
    high = 1.0
    hist_signal = create_hist(
        data["signal"]["NN_output"],
        weights=data["signal"]["weight"],
        bins=n_bins,
        low=low,
        high=high,
        name="Signal",
    )
    bkg_processes = [f"bkg{i}" for i in range(1, 6)]
    bkg_hists = [
        create_hist(
            data[proc]["NN_output"],
            weights=data[proc]["weight"],
            bins=n_bins,
            low=low,
            high=high,
            name=f"{proc.capitalize()}",
        )
        for proc in bkg_processes
    ]

    # Labels used in every stacked plot; the signal is drawn scaled by 100.
    process_labels = [f"Background {i}" for i in range(1, len(bkg_processes) + 1)]
    signal_labels = ["Signal x 100"]

    # For demonstration, we compare multiple binning schemes.
    equidistant_binning_options = [2, 5, 10, 20]
    equidistant_significances = {}
    optimized_significances = {}

    path_plots = f"examples/1D_example/{args.out}/"
    os.makedirs(path_plots, exist_ok=True)

    # Overview of the toy data, rebinned by a factor of 30 for display.
    _plot_stacked_lin_and_log(
        stacked_hists=[bkg_hist[:: hist.rebin(30)] for bkg_hist in bkg_hists],
        signal_hists=[hist_signal[:: hist.rebin(30)] * 100],
        process_labels=process_labels,
        signal_labels=signal_labels,
        output_filename=path_plots + "toy_data.pdf",
        axis_labels=("Toy discriminant", "Events"),
    )

    # Fixed binning significance
    nbins_hist = hist_signal.axes[0].size
    for nbins in equidistant_binning_options:
        # Merge groups of fine bins so that nbins coarse bins remain.
        factor = nbins_hist // nbins
        hist_signal_rb = hist_signal[:: hist.rebin(factor)]
        bkg_hists_rb = [h[:: hist.rebin(factor)] for h in bkg_hists]

        Z_equidistant = compute_significance_from_hists(hist_signal_rb, bkg_hists_rb)
        equidistant_significances[nbins] = Z_equidistant
        print(
            f"Fixed binning ({nbins} bins): Overall significance = {Z_equidistant:.3f}"
        )

        fixed_plot_filename = (
            path_plots + f"NN_output_distribution_fixed_{nbins}bins.pdf"
        )
        _plot_stacked_lin_and_log(
            stacked_hists=bkg_hists_rb,
            signal_hists=[hist_signal_rb * 100],
            process_labels=process_labels,
            signal_labels=signal_labels,
            output_filename=fixed_plot_filename,
            axis_labels=("Toy NN output", "Toy events"),
        )
        print(f"Fixed binning ({nbins} bins) plot saved as {fixed_plot_filename}")

    # GATO part
    tensor_data = df_dict_to_tensors(data)
    for nbins in gato_binning_options:

        # A new tf.function is defined for every bin count, so each model and
        # optimizer pair gets its own traced training step.
        # NOTE(review): presumably required because the optimizer creates its
        # variables on the first call of a tf.function — confirm before hoisting.
        @tf.function
        def train_step(
            model,
            tensor_data,
            optimizer,
            lam_yield=0.0,
            lam_unc=0.0,
            threshold_yield=5,
            rel_threshold_unc=0.2,
        ):
            with tf.GradientTape() as tape:
                # model.call returns (loss, B, B_sumw2), see gato_1D.call
                loss, B, B_sumw2 = model.call(tensor_data)

                penalty_yield = low_bkg_penalty(B, threshold=threshold_yield)
                penalty_unc = high_bkg_uncertainty_penalty(
                    B_sumw2, B, rel_threshold=rel_threshold_unc
                )
                total_loss = loss + lam_yield * penalty_yield + lam_unc * penalty_unc
            grads = tape.gradient(total_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            return total_loss, loss, penalty_yield, penalty_unc

        # --- Optimization: create a model instance with n_cats = nbins ---
        model = gato_1D(n_cats=nbins, temperature=1.0)

        optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
        lr_scheduler = LearningRateScheduler(
            optimizer,
            lr_initial=0.1,
            lr_final=0.001,
            total_epochs=epochs,
            mode="cosine",
        )

        # temperature scheduler
        temperature_scheduler = TemperatureScheduler(
            model,
            t_initial=1.0,
            t_final=0.05,
            total_epochs=epochs,
            mode="cosine",
        )

        loss_history = []
        penalty_yield_history = []
        penalty_unc_history = []
        mean_bias_history = []
        bias_epochs = []
        temperature_history = []
        for epoch in range(epochs):
            lr_scheduler.update(epoch)
            temperature_scheduler.update(epoch)

            total_loss, loss, penalty_yield, penalty_unc = train_step(
                model,
                tensor_data,
                optimizer,
                lam_yield=lam_yield,
                lam_unc=lam_unc,
                threshold_yield=yield_thr,
                rel_threshold_unc=unc_thr,
            )

            # Save the history
            loss_history.append(loss.numpy())
            penalty_yield_history.append(penalty_yield.numpy())
            penalty_unc_history.append(penalty_unc.numpy())

            # Bias and temperature are sampled together so that both lists
            # line up with bias_epochs in the bias-history plots.
            if epoch % 25 == 0 or epoch == epochs - 1:
                bias_vec = model.get_bias(tensor_data)
                mean_bias_history.append(float(np.mean(np.abs(bias_vec))))
                bias_epochs.append(epoch)
                temperature_history.append(float(model.temperature))

            if epoch % 10 == 0 or epoch == epochs - 1:
                # Optimizers expose the rate as "learning_rate" or "lr"
                # depending on the Keras version.
                lr_value = getattr(
                    optimizer, "learning_rate", getattr(optimizer, "lr", None)
                )
                if hasattr(lr_value, "numpy"):
                    lr_value = float(lr_value.numpy())
                else:
                    lr_value = float(lr_value)
                print(
                    f"[n_bins={nbins}] Epoch {epoch}: total_loss = {total_loss.numpy():.3f}, "
                    f"base_loss = {loss.numpy():.3f}, lr = {lr_value:.5f}"
                )
                print("Effective boundaries:", model.get_effective_boundaries_1d())

        # save the trained GATO model
        checkpoint_dir = os.path.join(
            path_plots, "checkpoints", f"{nbins}_bins"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        model.save(checkpoint_dir)

        # Rebuild optimized histograms using effective boundaries
        eff_boundaries = model.get_effective_boundaries_1d()
        print(f"Optimized boundaries for {nbins} bins: {eff_boundaries}")

        # check bias due to finite temperature in training
        bias = model.get_bias(tensor_data)
        print(f"T = {model.temperature:4.2f}; per-bin bias: {bias}")

        # Learned inner boundaries are padded with the discriminant range.
        opt_bin_edges = np.concatenate(([low], np.array(eff_boundaries), [high]))
        h_signal_opt = create_hist(
            data["signal"]["NN_output"],
            weights=data["signal"]["weight"],
            bins=opt_bin_edges,
            name="Signal_opt",
        )
        opt_bkg_hists = [
            create_hist(
                data[proc]["NN_output"],
                weights=data[proc]["weight"],
                bins=opt_bin_edges,
                name=f"{proc}_opt",
            )
            for proc in bkg_processes
        ]

        # Compute significance from these optimized histograms.
        Z_opt = compute_significance_from_hists(h_signal_opt, opt_bkg_hists)
        optimized_significances[nbins] = Z_opt

        print(f"Optimized binning ({nbins} bins): Overall significance = {Z_opt:.3f}")

        bias_plot_base = path_plots + f"bias_history_{nbins}bins"
        plot_bias_history(
            mean_bias_history,
            bias_plot_base + ".pdf",
            epochs=bias_epochs,
            temp_points=temperature_history,
            temp_label="Temperature",
        )
        plot_bias_history(
            mean_bias_history,
            bias_plot_base + "_log.pdf",
            epochs=bias_epochs,
            temp_points=temperature_history,
            temp_label="Temperature",
            log_scale=True,
        )

        opt_plot_filename = (
            path_plots + f"NN_output_distribution_optimized_{nbins}bins.pdf"
        )
        _plot_stacked_lin_and_log(
            stacked_hists=opt_bkg_hists,
            signal_hists=[h_signal_opt * 100],
            process_labels=process_labels,
            signal_labels=signal_labels,
            output_filename=opt_plot_filename,
            axis_labels=("Toy NN output", "Events"),
        )
        print(f"Optimized binning ({nbins} bins) plot saved as {opt_plot_filename}")

        # Plot the loss
        loss_plot_name = path_plots + f"history_loss_{nbins}bins.pdf"
        plot_history(
            history_data=loss_history,
            output_filename=loss_plot_name,
            y_label="Negative significance",
            x_label="Epoch",
            boundaries=False,
        )
        # NOTE(review): unlike the sibling history files there is no underscore
        # before {nbins} here; the name is kept so existing outputs still match.
        regularisation_plot_name = path_plots + f"history_penalty_yield{nbins}bins.pdf"
        plot_history(
            history_data=penalty_yield_history,
            output_filename=regularisation_plot_name,
            y_label="Low bkg. penalty",
            x_label="Epoch",
            boundaries=False,
        )
        regularisation_plot_name = path_plots + f"history_penalty_unc_{nbins}bins.pdf"
        plot_history(
            history_data=penalty_unc_history,
            output_filename=regularisation_plot_name,
            y_label="High bkg. unc. penalty",
            x_label="Epoch",
            boundaries=False,
        )

        B_sorted, rel_unc_sorted, _ = model.compute_hard_bkg_stats(tensor_data)
        plot_yield_vs_uncertainty(
            B_sorted,
            rel_unc_sorted,
            output_filename=path_plots + f"yield_vs_uncertainty_{nbins}bins_sorted.pdf",
        )
        plot_yield_vs_uncertainty(
            B_sorted,
            rel_unc_sorted,
            log=True,
            output_filename=path_plots
            + f"yield_vs_uncertainty_{nbins}bins_sorted_log.pdf",
        )
        # plot the learned GMM in 1D
        plot_gmm_1d(
            model,
            output_filename=os.path.join(path_plots, f"gmm_components_{nbins}bins.pdf"),
            x_range=(low, high),
            n_points=1000,
        )

    plot_significance_comparison(
        {"": {nb: equidistant_significances[nb] for nb in equidistant_binning_options}},
        {"": {nb: optimized_significances[nb] for nb in gato_binning_options}},
        output_filename=path_plots + "significanceComparison.pdf",
    )
458
459
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    main()