Bump hunt example

This workflow closely follows the three-class softmax example, but here the inference is performed on an unrelated event variable (“mass”), mirroring a bump-hunt workflow such as those used in Higgs-to-γγ analyses. Hence, the events in a category are not all in one bin (as they were in the other workflows) but are spread out over the full mass range, and only part of them contributes to the significance obtained in the small signal window.

  • Generate two resonant signals and five continuum backgrounds (building on the three-class softmax example).

  • Assign a diphoton mass to every event (Gaussian for signal, exponentials for the continuum).

  • Technically, we could perform the categorization optimization purely on the events in a small signal window around 125 GeV, but in practice the background statistics in such a window are often too low.

  • Therefore, we use the full power of the continuum background simulation by including all events in the gradient calculations, but reweighting the yield to match the expectation in the signal window (125 ± σ) during training. For this, we fit the background with exponentials in each category.

Run it with, for instance:

python examples/bumphunt_example/run_example.py \
    --epochs 400 \
    --gato-bins 5 8 \
    --out PlotsBumpHunt

Outputs land under examples/bumphunt_example/<out>/ and contain:

  • Inclusive diphoton-mass spectra before categorisation (linear/log).

  • Per-category diphoton spectra for all signals/backgrounds.

  • Loss, penalty, and bias histories with temperature annotations.

  • Boundary snapshots + GIFs showing the 2-D category evolution.

  • Yield vs. statistical-uncertainty bar plots per category.

  • Saved checkpoints for each trained configuration.

Source

Diphoton bump-hunt optimisation script
  1import argparse
  2import os
  3
  4import numpy as np
  5import tensorflow as tf
  6
  7from gatohep.losses import high_bkg_uncertainty_penalty, low_bkg_penalty
  8from gatohep.models import gato_gmm_model
  9from gatohep.plotting_utils import (
 10    assign_bins_and_order,
 11    make_gif,
 12    plot_bias_history,
 13    plot_bin_boundaries_2D,
 14    plot_category_mass_spectra,
 15    plot_history,
 16    plot_inclusive_mass,
 17    plot_significance_comparison,
 18    plot_yield_vs_uncertainty,
 19)
 20from gatohep.utils import (
 21    LearningRateScheduler,
 22    TemperatureScheduler,
 23    asymptotic_significance,
 24    build_category_mass_maps,
 25    compute_mass_reweight_factors,
 26    convert_mass_data_to_tensors,
 27    generate_resonance_toy_data,
 28    slice_to_2d_features,
 29)
 30
 31
 32class DiphotonSoftmax(gato_gmm_model):
 33    def __init__(self, n_cats, temperature=0.5, mass_sigma=1.5):
 34        super().__init__(
 35            n_cats=n_cats,
 36            dim=2,
 37            temperature=temperature,
 38            mean_norm="softmax",
 39            cov_offdiag_damping=0.1,
 40            name="gato_diphoton",
 41        )
 42        self.mass_center = tf.constant(125.0, dtype=tf.float32)
 43        self.mass_sigma = tf.constant(float(mass_sigma), dtype=tf.float32)
 44        self.mass_sig_low = self.mass_center - self.mass_sigma
 45        self.mass_sig_high = self.mass_center + self.mass_sigma
 46
 47    def call(self, data_dict, reweight=None, reweight_processes=None):
 48        masked = {}
 49        for proc, tensors in data_dict.items():
 50            weights = tensors["weight"]
 51            if proc in ("signal1", "signal2"):
 52                masses = tensors["mass"]
 53                window_mask = tf.cast(
 54                    tf.logical_and(
 55                        masses >= self.mass_sig_low, masses <= self.mass_sig_high
 56                    ),
 57                    tf.float32,
 58                )
 59                weights = weights * window_mask
 60            masked[proc] = {
 61                "NN_output": tensors["NN_output"],
 62                "weight": weights,
 63            }
 64
 65        significances, bkg_yield, bkg_sum_w2 = self.get_differentiable_significance(
 66            masked,
 67            signal_labels=["signal1", "signal2"],
 68            background_reweight=reweight,
 69            reweight_processes=reweight_processes,
 70            return_details=True,
 71        )
 72        z1 = significances["signal1"]
 73        z2 = significances["signal2"]
 74        loss = -tf.sqrt(z1 * z2)
 75        return loss, bkg_yield, bkg_sum_w2, z1, z2
 76
 77
 78def compute_significances_from_assignments(
 79    assignments, data_dict, n_bins, mass_low, mass_high
 80):
 81    """
 82    Sum signal/background yields per bin using a provided assignment map.
 83
 84    Parameters
 85    ----------
 86    assignments : dict[str, np.ndarray]
 87        Hard bin indices per process (negative entries ignored).
 88    data_dict : Mapping[str, pandas.DataFrame]
 89        Event tables containing ``mass`` and ``weight`` columns.
 90    n_bins : int
 91        Number of categories / bins.
 92    mass_low, mass_high : float
 93        Higgs-window boundaries.
 94
 95    Returns
 96    -------
 97    tuple[float, float]
 98        Significances for ``signal1`` and ``signal2``.
 99    """
100    s1 = np.zeros(n_bins, dtype=np.float64)
101    s2 = np.zeros_like(s1)
102    bkg = np.zeros_like(s1)
103
104    for proc, assign in assignments.items():
105        if assign.size == 0:
106            continue
107        df = data_dict[proc]
108        masses = df["mass"].values
109        weights = df["weight"].values
110        mask = (
111            (assign >= 0)
112            & (masses >= mass_low)
113            & (masses <= mass_high)
114        )
115        if not np.any(mask):
116            continue
117        bins = assign[mask]
118        w = weights[mask]
119        accum = np.zeros(n_bins, dtype=np.float64)
120        np.add.at(accum, bins, w)
121        if proc == "signal1":
122            s1 += accum
123        elif proc == "signal2":
124            s2 += accum
125        else:
126            bkg += accum
127
128    s1_tf = tf.constant(s1, dtype=tf.float32)
129    s2_tf = tf.constant(s2, dtype=tf.float32)
130    bkg_tf = tf.constant(bkg, dtype=tf.float32)
131
132    z1_bins = asymptotic_significance(s1_tf, bkg_tf + s2_tf)
133    z2_bins = asymptotic_significance(s2_tf, bkg_tf + s1_tf)
134    z1 = float(tf.sqrt(tf.reduce_sum(z1_bins**2)))
135    z2 = float(tf.sqrt(tf.reduce_sum(z2_bins**2)))
136    return z1, z2
137
138
def build_argmax_assignments(data_dict, nbins, sig_index, low=0.33, high=1.0):
    """
    Produce equidistant bin indices based on a softmax component.

    Only events whose argmax equals ``sig_index`` receive a valid bin,
    reproducing the baseline used in the three-class example.

    Parameters
    ----------
    data_dict : Mapping[str, pandas.DataFrame]
        Event tables with an ``NN_output`` column holding one score vector
        per event.
    nbins : int
        Number of equal-width bins; must be at least 1.
    sig_index : int
        Index of the score component that is binned and compared against
        the argmax.
    low, high : float, optional
        Range covered by the ``nbins + 1`` equidistant edges. The defaults
        (0.33 and 1.0) are the values that were previously hard-coded.

    Returns
    -------
    dict[str, np.ndarray]
        ``int32`` bin indices per process; ``-1`` marks events whose argmax
        differs from ``sig_index``. Empty tables map to an empty array.

    Raises
    ------
    ValueError
        If ``nbins`` is smaller than 1. Without this guard the clip below
        would have an upper bound below its lower bound and every event
        would silently be marked ``-1``.
    """
    if nbins < 1:
        raise ValueError(f"nbins must be >= 1, got {nbins}")

    edges = np.linspace(low, high, nbins + 1, dtype=np.float32)
    assignments = {}
    for proc, df in data_dict.items():
        if df.empty:
            assignments[proc] = np.array([], dtype=np.int32)
            continue
        outputs = np.stack(df["NN_output"].values)
        # Scores below ``low`` or at/above ``high`` fold into the first/last bin.
        bins = np.clip(
            np.digitize(outputs[:, sig_index], edges, right=False) - 1,
            0,
            nbins - 1,
        )
        is_selected = np.argmax(outputs, axis=1) == sig_index
        assignments[proc] = np.where(is_selected, bins, -1).astype(np.int32)
    return assignments
160
161
def main():
    """
    Run the diphoton bump-hunt example end to end.

    Parses the command-line options, generates the toy data, evaluates the
    equidistant argmax baseline for 2, 5 and 10 bins, then trains one
    ``DiphotonSoftmax`` model per entry of ``--gato-bins``. Plots, boundary
    frames / GIFs and checkpoints are written below
    ``examples/bumphunt_example/<out>/`` (relative to the working directory).
    """
    parser = argparse.ArgumentParser(
        description="Diphoton bump-hunt optimisation with GATO."
    )
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--gato-bins", nargs="+", type=int, default=[3, 5])
    parser.add_argument("--lam-yield", type=float, default=0.0)
    parser.add_argument("--lam-unc", type=float, default=0.0)
    parser.add_argument("--thr-yield", type=float, default=10)
    parser.add_argument("--thr-unc", type=float, default=0.1)
    parser.add_argument("--n-bkg", type=int, default=1_000_000)
    parser.add_argument("--n-signal", type=int, default=100_000)
    parser.add_argument("--rewt-interval", type=int, default=50)
    parser.add_argument("--mass-sigma", type=float, default=1.5)
    parser.add_argument("--out", type=str, default="PlotsDiphotonBumpHunt")
    args = parser.parse_args()

    path_plots = os.path.join("examples", "bumphunt_example", args.out)
    os.makedirs(path_plots, exist_ok=True)

    data_full = generate_resonance_toy_data(
        n_signal1=args.n_signal,
        n_signal2=args.n_signal,
        n_bkg=args.n_bkg,
        mass_sigma=args.mass_sigma,
        background_slopes=(0.05, 0.04, 0.035, 0.03, 0.025),
    )
    data_2d = slice_to_2d_features(data_full)
    tensor_data = convert_mass_data_to_tensors(data_2d)

    signal_names = ("signal1", "signal2")

    def select_tensors(*keys):
        # Fresh per-process dict holding only the requested tensor entries.
        return {
            proc: {key: tensors[key] for key in keys}
            for proc, tensors in tensor_data.items()
        }

    sig_low = 125.0 - args.mass_sigma
    sig_high = 125.0 + args.mass_sigma
    plot_inclusive_mass(data_2d, path_plots, sig_scales=(50, 250))

    baseline_bins = [2, 5, 10]
    baseline_results = {name: {} for name in signal_names}
    gato_results = {name: {} for name in signal_names}

    # Baseline: equidistant bins in the score of each signal in turn; only
    # the significance of the signal being binned is recorded.
    for nbins in baseline_bins:
        for sig_idx, sig_name in enumerate(signal_names):
            assignments = build_argmax_assignments(data_2d, nbins, sig_idx)
            z1, z2 = compute_significances_from_assignments(
                assignments,
                data_2d,
                nbins,
                sig_low,
                sig_high,
            )
            baseline_results[sig_name][nbins] = z1 if sig_idx == 0 else z2

    for n_cats in args.gato_bins:
        print(f"\n--- Optimising {n_cats} bins ---")
        model = DiphotonSoftmax(
            n_cats=n_cats, temperature=1.0, mass_sigma=args.mass_sigma
        )
        optimizer = tf.keras.optimizers.RMSprop(0.05)
        lr_scheduler = LearningRateScheduler(
            optimizer,
            lr_initial=0.05,
            lr_final=0.001,
            total_epochs=args.epochs,
            mode="cosine",
        )
        temp_scheduler = TemperatureScheduler(
            model,
            t_initial=1.0,
            t_final=0.1,
            total_epochs=args.epochs,
            mode="cosine",
        )

        # Defined per configuration because it closes over this iteration's
        # ``model`` and ``optimizer``.
        @tf.function
        def train_step(tdata, reweight_tensor, lamY, lamU, thrY, thrU):
            with tf.GradientTape() as tape:
                loss, B_sig, B_sig_w2, z1, z2 = model.call(tdata, reweight_tensor)
                penalty_y = low_bkg_penalty(B_sig, threshold=thrY)
                penalty_u = high_bkg_uncertainty_penalty(
                    B_sig_w2, B_sig, rel_threshold=thrU
                )
                total = loss + lamY * penalty_y + lamU * penalty_u
            grads = tape.gradient(total, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            return total, loss, penalty_y, penalty_u, z1, z2, B_sig

        reweight = tf.ones(n_cats, dtype=tf.float32)
        loss_history = []
        penalty_y_hist = []
        penalty_u_hist = []
        continuum_history = []
        bias_history = []
        bias_epochs = []
        temp_history = []
        path_bins = os.path.join(path_plots, f"gato_{n_cats}bins")
        os.makedirs(path_bins, exist_ok=True)
        frames_dir = os.path.join(path_bins, "boundary_frames")
        os.makedirs(frames_dir, exist_ok=True)
        boundary_frames = []

        for epoch in range(args.epochs):
            # Refresh the per-category reweight factors at a fixed interval;
            # ``max(1, ...)`` guards against a zero or negative interval.
            if epoch % max(1, args.rewt_interval) == 0:
                factors = compute_mass_reweight_factors(
                    model,
                    data_2d,
                    signal_labels=list(signal_names),
                    mass_sig_low=sig_low,
                    mass_sig_high=sig_high,
                )
                reweight = tf.constant(factors, dtype=tf.float32)
                print(f"Updated reweight factors: {factors}")

            _, loss, penY, penU, z1, z2, B_bins = train_step(
                tensor_data,
                reweight,
                args.lam_yield,
                args.lam_unc,
                args.thr_yield,
                args.thr_unc,
            )
            lr_scheduler.update(epoch)
            temp_scheduler.update(epoch)

            loss_history.append(float(loss.numpy()))
            penalty_y_hist.append(float(penY.numpy()))
            penalty_u_hist.append(float(penU.numpy()))
            reweight_np = reweight.numpy()
            B_np = B_bins.numpy()
            # Presumably undoes the window reweighting to recover the
            # full-range yield (TODO confirm); the floor avoids division by
            # zero.
            continuum_history.append(B_np / np.maximum(reweight_np, 1e-6))

            if epoch % 10 == 0 or epoch == args.epochs - 1:
                lr_value = getattr(
                    optimizer, "learning_rate", getattr(optimizer, "lr", None)
                )
                lr_value = (
                    float(lr_value.numpy())
                    if hasattr(lr_value, "numpy")
                    else float(lr_value)
                )
                # NOTE(review): if ``model.temperature`` is a mutable
                # variable this stores a reference, not a snapshot - confirm.
                temperature = model.temperature
                temp_history.append(temperature)
                bias_vec = model.get_bias(select_tensors("NN_output", "weight"))
                bias_history.append(float(np.mean(np.abs(bias_vec))))
                bias_epochs.append(epoch)
                boundary_fname = os.path.join(
                    frames_dir, f"boundary_{epoch:04d}.png"
                )
                plot_bin_boundaries_2D(
                    model,
                    list(range(n_cats)),
                    boundary_fname,
                    resolution=600,
                    annotation=f"Epoch {epoch}",
                )
                boundary_frames.append(boundary_fname)
                print(
                    f"[{epoch:04d}] loss={loss.numpy():.4f} "
                    f"Z1={z1.numpy():.3f} Z2={z2.numpy():.3f} lr={lr_value:.5f}"
                )

        ckpt_dir = os.path.join(path_plots, "checkpoints", f"{n_cats}_bins")
        os.makedirs(ckpt_dir, exist_ok=True)
        model.save(ckpt_dir)

        _, _, _, z1_final, z2_final = model.call(tensor_data, reweight)
        print(
            f"Final significances for {n_cats} bins: "
            f"Z(signal1)={float(z1_final.numpy()):.3f}, "
            f"Z(signal2)={float(z2_final.numpy()):.3f}"
        )

        plot_history(
            np.array(loss_history),
            os.path.join(path_bins, f"loss_{n_cats}.pdf"),
            y_label="Geometric mean significance",
            x_label="Epoch",
        )
        plot_history(
            np.array(penalty_y_hist),
            os.path.join(path_bins, f"penalty_yield_{n_cats}.pdf"),
            y_label="Low background penalty",
            x_label="Epoch",
        )
        plot_history(
            np.array(penalty_u_hist),
            os.path.join(path_bins, f"penalty_unc_{n_cats}.pdf"),
            y_label="High-uncertainty penalty",
            x_label="Epoch",
        )
        category_labels = [f"Cat. {i}" for i in range(n_cats)]
        continuum_array = np.array(continuum_history)
        plot_history(
            continuum_array,
            os.path.join(path_bins, f"continuum_background_{n_cats}.pdf"),
            y_label="Continuum background (100-180 GeV)",
            x_label="Epoch",
            boundaries=True,
            boundary_labels=category_labels,
        )
        plot_history(
            continuum_array,
            os.path.join(path_bins, f"continuum_background_{n_cats}_log.pdf"),
            y_label="Continuum background (100-180 GeV)",
            x_label="Epoch",
            boundaries=True,
            log_scale=True,
            boundary_labels=category_labels,
        )
        plot_bias_history(
            bias_history,
            os.path.join(path_bins, f"bias_history_{n_cats}.pdf"),
            epochs=bias_epochs,
            temp_points=temp_history,
            temp_label="Temperature",
        )

        raw_assign = model.get_bin_indices(select_tensors("NN_output"))
        raw_assign_np = {k: v.numpy() for k, v in raw_assign.items()}

        # Only the category ordering is used below (for the boundary plot).
        _, order, _, _ = assign_bins_and_order(model, data_2d, reduce=False)

        z1_opt, z2_opt = compute_significances_from_assignments(
            raw_assign_np,
            data_2d,
            n_cats,
            sig_low,
            sig_high,
        )
        gato_results["signal1"][n_cats] = z1_opt
        gato_results["signal2"][n_cats] = z2_opt
        per_cat_hists = build_category_mass_maps(raw_assign_np, data_2d, n_cats)
        plot_category_mass_spectra(
            per_cat_hists,
            os.path.join(path_bins, "mass_spectra"),
            sig_scales=(2, 10),
        )

        plot_bin_boundaries_2D(
            model,
            order,
            os.path.join(path_bins, f"bin_boundaries_{n_cats}_bins.pdf"),
        )

        if boundary_frames:
            gif_path = os.path.join(path_bins, f"boundary_evolution_{n_cats}.gif")
            make_gif(boundary_frames, gif_path, interval=500)
        B_sorted, rel_unc, _ = model.compute_hard_bkg_stats(
            select_tensors("NN_output", "weight"),
            signal_labels=list(signal_names),
        )
        plot_yield_vs_uncertainty(
            B_sorted,
            rel_unc,
            output_filename=os.path.join(path_bins, f"yield_unc_{n_cats}.pdf"),
        )
        plot_yield_vs_uncertainty(
            B_sorted,
            rel_unc,
            output_filename=os.path.join(path_bins, f"yield_unc_{n_cats}_log.pdf"),
            log=True,
        )

    # Baseline entries are re-keyed from n to 2 * n + 1 before the comparison
    # plot - presumably the total category count of the baseline scheme
    # (TODO confirm against the three-class example).
    remapped_baseline = {
        sig: {2 * n + 1: z for n, z in results.items()}
        for sig, results in baseline_results.items()
    }
    plot_significance_comparison(
        remapped_baseline,
        gato_results,
        os.path.join(path_plots, "significance_comparison.pdf"),
    )
434
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()