Bump hunt example
This workflow closely follows the three-class softmax example, but here the inference is performed on an event variable unrelated to the classifier (the diphoton "mass"), mirroring a bump-hunt workflow such as those used in Higgs-to-γγ analyses. Hence, the events in a category are not all in one bin (as in the other workflows) but are spread over the full mass range, and only the fraction that falls inside the narrow signal window contributes to the significance.
Generate two resonant signals and five continuum backgrounds (building on the three-class softmax example).
Assign a diphoton mass to every event (Gaussian for signal, exponentials for the continuum).
In principle, we could optimise the categorisation using only the events in a small signal window around 125 GeV, but in practice the background statistics in that window are often too low.
Therefore, we exploit the full continuum-background simulation by including all events in the gradient calculation, while reweighting the background yield during training to match the expectation in the signal window (125 ± σ GeV). The reweighting factors are obtained by fitting the background in each category with an exponential.
Run it, for instance, with:
python examples/bumphunt_example/run_example.py \
--epochs 400 \
--gato-bins 5 8 \
--out PlotsBumpHunt
Outputs land under examples/bumphunt_example/<out>/ and contain:
Inclusive diphoton-mass spectra before categorisation (linear/log).
Per-category diphoton spectra for all signals/backgrounds.
Loss, penalty, and bias histories with temperature annotations.
Boundary snapshots + GIFs showing the 2-D category evolution.
Yield vs. statistical-uncertainty bar plots per category.
Saved checkpoints for each trained configuration.
Source
1import argparse
2import os
3
4import numpy as np
5import tensorflow as tf
6
7from gatohep.losses import high_bkg_uncertainty_penalty, low_bkg_penalty
8from gatohep.models import gato_gmm_model
9from gatohep.plotting_utils import (
10 assign_bins_and_order,
11 make_gif,
12 plot_bias_history,
13 plot_bin_boundaries_2D,
14 plot_category_mass_spectra,
15 plot_history,
16 plot_inclusive_mass,
17 plot_significance_comparison,
18 plot_yield_vs_uncertainty,
19)
20from gatohep.utils import (
21 LearningRateScheduler,
22 TemperatureScheduler,
23 asymptotic_significance,
24 build_category_mass_maps,
25 compute_mass_reweight_factors,
26 convert_mass_data_to_tensors,
27 generate_resonance_toy_data,
28 slice_to_2d_features,
29)
30
31
class DiphotonSoftmax(gato_gmm_model):
    """
    Two-dimensional GATO Gaussian-mixture categoriser for the bump hunt.

    Signal weights are restricted to the closed mass window
    ``[125 - mass_sigma, 125 + mass_sigma]`` before the differentiable
    significance is evaluated. Background weights are left untouched here;
    they are scaled through ``background_reweight`` instead.

    Parameters
    ----------
    n_cats : int
        Number of categories.
    temperature : float
        Initial softmax temperature forwarded to the base model.
    mass_sigma : float
        Half-width of the signal mass window in GeV.
    """

    def __init__(self, n_cats, temperature=0.5, mass_sigma=1.5):
        super().__init__(
            n_cats=n_cats,
            dim=2,
            temperature=temperature,
            mean_norm="softmax",
            cov_offdiag_damping=0.1,
            name="gato_diphoton",
        )
        centre = tf.constant(125.0, dtype=tf.float32)
        half_width = tf.constant(float(mass_sigma), dtype=tf.float32)
        self.mass_center = centre
        self.mass_sigma = half_width
        self.mass_sig_low = centre - half_width
        self.mass_sig_high = centre + half_width

    def call(self, data_dict, reweight=None, reweight_processes=None):
        """
        Return ``(loss, bkg_yield, bkg_sum_w2, z1, z2)``.

        The loss is the negated geometric mean ``-sqrt(z1 * z2)`` of the
        two signal significances, so minimising it maximises both.
        """
        signal_names = ("signal1", "signal2")

        def _windowed_weight(proc, tensors):
            # Only the resonant signals are cut to the mass window.
            w = tensors["weight"]
            if proc not in signal_names:
                return w
            m = tensors["mass"]
            inside = tf.logical_and(m >= self.mass_sig_low, m <= self.mass_sig_high)
            return w * tf.cast(inside, tf.float32)

        masked = {
            proc: {
                "NN_output": tensors["NN_output"],
                "weight": _windowed_weight(proc, tensors),
            }
            for proc, tensors in data_dict.items()
        }

        significances, bkg_yield, bkg_sum_w2 = self.get_differentiable_significance(
            masked,
            signal_labels=list(signal_names),
            background_reweight=reweight,
            reweight_processes=reweight_processes,
            return_details=True,
        )
        z1, z2 = significances["signal1"], significances["signal2"]
        return -tf.sqrt(z1 * z2), bkg_yield, bkg_sum_w2, z1, z2
76
77
def compute_significances_from_assignments(
    assignments, data_dict, n_bins, mass_low, mass_high
):
    """
    Compute per-signal significances from hard category assignments.

    Event weights inside the closed mass window ``[mass_low, mass_high]``
    are accumulated per category. Every process other than ``signal1`` and
    ``signal2`` is pooled into the background. When one signal is
    evaluated, the other signal is counted as background. The
    per-category significances are combined in quadrature.

    Parameters
    ----------
    assignments : dict[str, np.ndarray]
        Hard bin index per event, keyed by process (negative entries ignored).
    data_dict : Mapping[str, pandas.DataFrame]
        Event tables containing ``mass`` and ``weight`` columns, aligned
        event-by-event with ``assignments``.
    n_bins : int
        Number of categories / bins.
    mass_low, mass_high : float
        Inclusive Higgs-window boundaries.

    Returns
    -------
    tuple[float, float]
        Combined significances for ``signal1`` and ``signal2``.
    """
    yields = {
        name: np.zeros(n_bins, dtype=np.float64)
        for name in ("signal1", "signal2", "background")
    }

    for proc, assign in assignments.items():
        if assign.size == 0:
            continue
        table = data_dict[proc]
        mass = table["mass"].values
        weight = table["weight"].values
        in_window = (assign >= 0) & (mass >= mass_low) & (mass <= mass_high)
        if not np.any(in_window):
            continue
        # Sum this process on its own first, then add it to the running total.
        per_bin = np.zeros(n_bins, dtype=np.float64)
        np.add.at(per_bin, assign[in_window], weight[in_window])
        key = proc if proc in ("signal1", "signal2") else "background"
        yields[key] += per_bin

    def _combined(z_bins):
        # Quadrature sum of the per-category significances.
        return float(tf.sqrt(tf.reduce_sum(z_bins**2)))

    sig1, sig2, bkg = (
        tf.constant(yields[name], dtype=tf.float32)
        for name in ("signal1", "signal2", "background")
    )
    z1 = _combined(asymptotic_significance(sig1, bkg + sig2))
    z2 = _combined(asymptotic_significance(sig2, bkg + sig1))
    return z1, z2
137
138
def build_argmax_assignments(data_dict, nbins, sig_index, *, min_score=0.33):
    """
    Produce equidistant bin indices based on a softmax component.

    Only events whose argmax equals ``sig_index`` receive a valid bin. All
    other events get ``-1``. This follows the baseline used in the
    three-class example. The score range ``[min_score, 1.0]`` is split into
    ``nbins`` equal-width bins. Scores outside that range are clipped into
    the first or last bin.

    Parameters
    ----------
    data_dict : Mapping[str, pandas.DataFrame]
        Event tables with an ``NN_output`` column holding one score vector
        per event.
    nbins : int
        Number of equidistant bins (must be >= 1).
    sig_index : int
        Softmax component used both for the argmax selection and for binning.
    min_score : float, optional
        Lower edge of the first bin. The default 0.33 is just below 1/3,
        the smallest possible argmax score of a three-class softmax.

    Returns
    -------
    dict[str, np.ndarray]
        ``int32`` bin indices per process (``-1`` for unselected events).

    Raises
    ------
    ValueError
        If ``nbins`` is smaller than 1.
    """
    if nbins < 1:
        raise ValueError(f"nbins must be >= 1, got {nbins}")

    edges = np.linspace(min_score, 1.0, nbins + 1, dtype=np.float32)
    assignments = {}
    for proc, df in data_dict.items():
        if df.empty:
            # np.stack cannot handle an empty column, so short-circuit.
            assignments[proc] = np.array([], dtype=np.int32)
            continue
        outputs = np.stack(df["NN_output"].values)
        selected = np.argmax(outputs, axis=1) == sig_index
        # digitize is 1-based; clipping maps score == 1.0 into the last bin.
        bins = np.clip(
            np.digitize(outputs[:, sig_index], edges, right=False) - 1, 0, nbins - 1
        )
        assignments[proc] = np.where(selected, bins, -1).astype(np.int32)
    return assignments
160
161
def _parse_args():
    """Parse the command-line options of the bump-hunt example."""
    parser = argparse.ArgumentParser(
        description="Diphoton bump-hunt optimisation with GATO."
    )
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--gato-bins", nargs="+", type=int, default=[3, 5])
    parser.add_argument("--lam-yield", type=float, default=0.0)
    parser.add_argument("--lam-unc", type=float, default=0.0)
    parser.add_argument("--thr-yield", type=float, default=10)
    parser.add_argument("--thr-unc", type=float, default=0.1)
    parser.add_argument("--n-bkg", type=int, default=1_000_000)
    parser.add_argument("--n-signal", type=int, default=100_000)
    parser.add_argument("--rewt-interval", type=int, default=50)
    parser.add_argument("--mass-sigma", type=float, default=1.5)
    parser.add_argument("--out", type=str, default="PlotsDiphotonBumpHunt")
    return parser.parse_args()


def _nn_inputs(tensor_data, with_weight=True):
    """Return a fresh ``{process: {"NN_output"[, "weight"]}}`` view of the tensors."""
    keys = ("NN_output", "weight") if with_weight else ("NN_output",)
    return {p: {k: tensor_data[p][k] for k in keys} for p in tensor_data}


def _current_learning_rate(optimizer):
    """Read the optimizer's learning rate as a Python float."""
    lr = getattr(optimizer, "learning_rate", None)
    if lr is None:
        lr = getattr(optimizer, "lr", None)
    return float(lr.numpy()) if hasattr(lr, "numpy") else float(lr)


def _baseline_significances(data_2d, nbins_list, sig_low, sig_high):
    """Return ``{signal: {nbins: Z}}`` for the argmax + equidistant-bin baseline."""
    results = {"signal1": {}, "signal2": {}}
    for nbins in nbins_list:
        for sig_idx, sig_name in enumerate(("signal1", "signal2")):
            assignments = build_argmax_assignments(data_2d, nbins, sig_idx)
            z1, z2 = compute_significances_from_assignments(
                assignments, data_2d, nbins, sig_low, sig_high
            )
            # Keep only the significance of the signal this baseline targets.
            results[sig_name][nbins] = z1 if sig_idx == 0 else z2
    return results


def _plot_training_histories(hist, path_bins, n_cats):
    """Write the loss, penalty, continuum-yield and bias history plots."""
    plot_history(
        np.array(hist["loss"]),
        os.path.join(path_bins, f"loss_{n_cats}.pdf"),
        # The recorded loss is -sqrt(Z1 * Z2), i.e. the *negated* geometric mean.
        y_label="Loss (negative geometric-mean significance)",
        x_label="Epoch",
    )
    plot_history(
        np.array(hist["penalty_y"]),
        os.path.join(path_bins, f"penalty_yield_{n_cats}.pdf"),
        y_label="Low background penalty",
        x_label="Epoch",
    )
    plot_history(
        np.array(hist["penalty_u"]),
        os.path.join(path_bins, f"penalty_unc_{n_cats}.pdf"),
        y_label="High-uncertainty penalty",
        x_label="Epoch",
    )
    continuum = np.array(hist["continuum"])
    cat_labels = [f"Cat. {i}" for i in range(n_cats)]
    plot_history(
        continuum,
        os.path.join(path_bins, f"continuum_background_{n_cats}.pdf"),
        y_label="Continuum background (100-180 GeV)",
        x_label="Epoch",
        boundaries=True,
        boundary_labels=cat_labels,
    )
    plot_history(
        continuum,
        os.path.join(path_bins, f"continuum_background_{n_cats}_log.pdf"),
        y_label="Continuum background (100-180 GeV)",
        x_label="Epoch",
        boundaries=True,
        log_scale=True,
        boundary_labels=cat_labels,
    )
    plot_bias_history(
        hist["bias"],
        os.path.join(path_bins, f"bias_history_{n_cats}.pdf"),
        epochs=hist["bias_epochs"],
        temp_points=hist["temperature"],
        temp_label="Temperature",
    )


def _train_one_configuration(
    args, n_cats, data_2d, tensor_data, path_plots, sig_low, sig_high
):
    """
    Train a GATO model with ``n_cats`` categories and write all its outputs.

    Returns
    -------
    tuple[float, float]
        Hard-assignment significances ``(Z(signal1), Z(signal2))`` in the
        signal window.
    """
    print(f"\n--- Optimising {n_cats} bins ---")
    model = DiphotonSoftmax(
        n_cats=n_cats, temperature=1.0, mass_sigma=args.mass_sigma
    )
    optimizer = tf.keras.optimizers.RMSprop(0.05)
    lr_scheduler = LearningRateScheduler(
        optimizer,
        lr_initial=0.05,
        lr_final=0.001,
        total_epochs=args.epochs,
        mode="cosine",
    )
    temp_scheduler = TemperatureScheduler(
        model,
        t_initial=1.0,
        t_final=0.1,
        total_epochs=args.epochs,
        mode="cosine",
    )

    @tf.function
    def train_step(tdata, reweight_tensor, lamY, lamU, thrY, thrU):
        with tf.GradientTape() as tape:
            loss, B_sig, B_sig_w2, z1, z2 = model.call(tdata, reweight_tensor)
            # Penalties must be computed under the tape to contribute gradients.
            penalty_y = low_bkg_penalty(B_sig, threshold=thrY)
            penalty_u = high_bkg_uncertainty_penalty(
                B_sig_w2, B_sig, rel_threshold=thrU
            )
            total = loss + lamY * penalty_y + lamU * penalty_u
        grads = tape.gradient(total, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return total, loss, penalty_y, penalty_u, z1, z2, B_sig

    path_bins = os.path.join(path_plots, f"gato_{n_cats}bins")
    frames_dir = os.path.join(path_bins, "boundary_frames")
    os.makedirs(frames_dir, exist_ok=True)  # also creates path_bins

    hist = {
        "loss": [],
        "penalty_y": [],
        "penalty_u": [],
        "continuum": [],
        "bias": [],
        "bias_epochs": [],
        "temperature": [],
    }
    boundary_frames = []
    reweight = tf.ones(n_cats, dtype=tf.float32)
    rewt_every = max(1, args.rewt_interval)

    for epoch in range(args.epochs):
        if epoch % rewt_every == 0:
            factors = compute_mass_reweight_factors(
                model,
                data_2d,
                signal_labels=["signal1", "signal2"],
                mass_sig_low=sig_low,
                mass_sig_high=sig_high,
            )
            reweight = tf.constant(factors, dtype=tf.float32)
            print(f"Updated reweight factors: {factors}")

        _, loss, pen_y, pen_u, z1, z2, b_bins = train_step(
            tensor_data,
            reweight,
            args.lam_yield,
            args.lam_unc,
            args.thr_yield,
            args.thr_unc,
        )
        lr_scheduler.update(epoch)
        temp_scheduler.update(epoch)

        hist["loss"].append(float(loss.numpy()))
        hist["penalty_y"].append(float(pen_y.numpy()))
        hist["penalty_u"].append(float(pen_u.numpy()))
        # Undo the window reweighting to track the full-range continuum yield.
        hist["continuum"].append(b_bins.numpy() / np.maximum(reweight.numpy(), 1e-6))

        if epoch % 10 == 0 or epoch == args.epochs - 1:
            lr_value = _current_learning_rate(optimizer)
            # Snapshot the value: if ``model.temperature`` is a variable that the
            # scheduler updates in place, storing the object itself would make
            # every history entry show the final temperature.
            hist["temperature"].append(float(np.asarray(model.temperature)))
            bias_vec = model.get_bias(_nn_inputs(tensor_data))
            hist["bias"].append(float(np.mean(np.abs(bias_vec))))
            hist["bias_epochs"].append(epoch)
            boundary_fname = os.path.join(frames_dir, f"boundary_{epoch:04d}.png")
            plot_bin_boundaries_2D(
                model,
                list(range(n_cats)),
                boundary_fname,
                resolution=600,
                annotation=f"Epoch {epoch}",
            )
            boundary_frames.append(boundary_fname)
            print(
                f"[{epoch:04d}] loss={loss.numpy():.4f} "
                f"Z1={z1.numpy():.3f} Z2={z2.numpy():.3f} lr={lr_value:.5f}"
            )

    ckpt_dir = os.path.join(path_plots, "checkpoints", f"{n_cats}_bins")
    os.makedirs(ckpt_dir, exist_ok=True)
    model.save(ckpt_dir)

    _, _, _, z1_final, z2_final = model.call(tensor_data, reweight)
    print(
        f"Final significances for {n_cats} bins: "
        f"Z(signal1)={float(z1_final.numpy()):.3f}, "
        f"Z(signal2)={float(z2_final.numpy()):.3f}"
    )

    _plot_training_histories(hist, path_bins, n_cats)

    raw_assign = model.get_bin_indices(_nn_inputs(tensor_data, with_weight=False))
    raw_assign_np = {k: v.numpy() for k, v in raw_assign.items()}
    # Only the category plotting order is needed from the reordered assignment.
    _, order, _, _ = assign_bins_and_order(model, data_2d, reduce=False)

    z1_opt, z2_opt = compute_significances_from_assignments(
        raw_assign_np, data_2d, n_cats, sig_low, sig_high
    )
    per_cat_hists = build_category_mass_maps(raw_assign_np, data_2d, n_cats)
    plot_category_mass_spectra(
        per_cat_hists,
        os.path.join(path_bins, "mass_spectra"),
        sig_scales=(2, 10),
    )
    plot_bin_boundaries_2D(
        model,
        order,
        os.path.join(path_bins, f"bin_boundaries_{n_cats}_bins.pdf"),
    )
    if boundary_frames:
        gif_path = os.path.join(path_bins, f"boundary_evolution_{n_cats}.gif")
        make_gif(boundary_frames, gif_path, interval=500)

    B_sorted, rel_unc, _ = model.compute_hard_bkg_stats(
        _nn_inputs(tensor_data),
        signal_labels=["signal1", "signal2"],
    )
    plot_yield_vs_uncertainty(
        B_sorted,
        rel_unc,
        output_filename=os.path.join(path_bins, f"yield_unc_{n_cats}.pdf"),
    )
    plot_yield_vs_uncertainty(
        B_sorted,
        rel_unc,
        output_filename=os.path.join(path_bins, f"yield_unc_{n_cats}_log.pdf"),
        log=True,
    )
    return z1_opt, z2_opt


def main():
    """Run the full diphoton bump-hunt example: data, baselines, GATO, plots."""
    args = _parse_args()

    path_plots = os.path.join("examples", "bumphunt_example", args.out)
    os.makedirs(path_plots, exist_ok=True)

    data_full = generate_resonance_toy_data(
        n_signal1=args.n_signal,
        n_signal2=args.n_signal,
        n_bkg=args.n_bkg,
        mass_sigma=args.mass_sigma,
        background_slopes=(0.05, 0.04, 0.035, 0.03, 0.025),
    )
    data_2d = slice_to_2d_features(data_full)
    tensor_data = convert_mass_data_to_tensors(data_2d)

    mass_center = 125.0
    sig_low = mass_center - args.mass_sigma
    sig_high = mass_center + args.mass_sigma
    plot_inclusive_mass(data_2d, path_plots, sig_scales=(50, 250))

    baseline_results = _baseline_significances(data_2d, (2, 5, 10), sig_low, sig_high)

    gato_results = {"signal1": {}, "signal2": {}}
    for n_cats in args.gato_bins:
        z1_opt, z2_opt = _train_one_configuration(
            args, n_cats, data_2d, tensor_data, path_plots, sig_low, sig_high
        )
        gato_results["signal1"][n_cats] = z1_opt
        gato_results["signal2"][n_cats] = z2_opt

    # NOTE(review): the baseline with n bins per signal is presumably counted
    # as 2n signal categories plus one remainder category, so that its x-axis
    # position matches GATO's total category count. Confirm the intent.
    remapped_baseline = {
        sig: {2 * n + 1: z for n, z in per_n.items()}
        for sig, per_n in baseline_results.items()
    }
    plot_significance_comparison(
        remapped_baseline,
        gato_results,
        os.path.join(path_plots, "significance_comparison.pdf"),
    )
434
if __name__ == "__main__":
    # Run the workflow only when executed as a script, not when imported.
    main()