Three-class softmax example

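Optimize analysis categories directly on the output of a 3-class softmax classifier. Only the first two scores are used; the third is fixed by normalization, so the problem can be visualized in two dimensions. For each requested number of bins, the script trains a Gaussian-mixture (GATO) model and compares the resulting signal significances with argmax-based baselines. It also produces animations of the evolving bin boundaries and bin histograms.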

Run:

python examples/three_class_softmax_example/run_example.py --gato-bins 5 10 20 --epochs 500

Output plots

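  • frames_<N> folders with per-epoch bin-boundary frames and the boundary/histogram evolution GIFs (the per-epoch histogram frames themselves are written to progress_plots_<N>).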

  • Stacked histograms contrasting background compositions with scaled signal templates.

  • Yield vs. uncertainty bar charts for each bin.

Source code

Source of the 3-class softmax example
  1import os
  2import argparse
  3import numpy as np
  4import tensorflow as tf
  5import tensorflow_probability as tfp
  6
  7from gatohep.data_generation import generate_toy_data_3class_3D
  8from gatohep.losses import high_bkg_uncertainty_penalty, low_bkg_penalty
  9from gatohep.models import gato_gmm_model
 10from gatohep.plotting_utils import (
 11    assign_bins_and_order,
 12    fill_histogram_from_assignments,
 13    plot_bias_history,
 14    plot_bin_boundaries_2D,
 15    plot_history,
 16    plot_learned_gaussians,
 17    plot_significance_comparison,
 18    plot_stacked_histograms,
 19    plot_yield_vs_uncertainty,
 20    make_gif
 21)
 22from gatohep.utils import (
 23    LearningRateScheduler,
 24    TemperatureScheduler,
 25    asymptotic_significance,
 26    compute_significance_from_hists,
 27    create_hist
 28)
 29
 30tfd = tfp.distributions
 31
 32
 33def convert_data_to_tensors(data):
 34    tensor_data = {}
 35    for proc, df in data.items():
 36        nn = np.stack(df["NN_output"].values)[:, :2]
 37        w = df["weight"].values
 38        tensor_data[proc] = {
 39            "NN_output": tf.constant(nn, dtype=tf.float32),
 40            "weight"   : tf.constant(w , dtype=tf.float32),
 41        }
 42    return tensor_data
 43
 44
 45#  2-D GMM gato model
 46class gato_2D(gato_gmm_model):
 47    def __init__(self, n_cats, temperature=0.3, name="gato_2D"):
 48        super().__init__(
 49            n_cats=n_cats,
 50            dim=2,
 51            temperature=temperature,
 52            mean_norm="softmax",
 53            cov_offdiag_damping=0.1,
 54            name=name
 55        )
 56
 57    def call(self, data_dict):
 58        """
 59            Compute the training loss and background yields,
 60            which can be used for penalty terms.
 61        """
 62        significances, bkg_yield, bkg_sum_w2 = self.get_differentiable_significance(
 63            data_dict,
 64            signal_labels=["signal1", "signal2"],
 65            return_details=True,
 66        )
 67        Z1 = significances["signal1"]
 68        Z2 = significances["signal2"]
 69        loss = -tf.sqrt(Z1 * Z2)
 70        return loss, bkg_yield, bkg_sum_w2
 71
 72
 73# Main function to run the example
 74def main():
 75    parser = argparse.ArgumentParser(description="2-D soft-max GATO optimisation")
 76    parser.add_argument("--epochs",      type=int,   default=250)
 77    parser.add_argument("--gato-bins",   nargs="+",  type=int, default=[3,5,10])
 78    parser.add_argument("--lam-yield",   type=float, default=0.)
 79    parser.add_argument("--lam-unc",     type=float, default=0.)
 80    parser.add_argument("--thr-yield",   type=float, default=5.)
 81    parser.add_argument("--thr-unc",     type=float, default=0.20)
 82    parser.add_argument("--n-bkg",       type=int,   default=1_000_000)
 83    parser.add_argument("--out",         type=str,   default="Plots")
 84    args = parser.parse_args()
 85
 86    path_plots = f'./examples/three_class_softmax_example/{args.out}/'
 87    os.makedirs(path_plots, exist_ok=True)
 88
 89    # ---------- toy data
 90    data = generate_toy_data_3class_3D(
 91        n_signal1=100_000, n_signal2=100_000,
 92        n_bkg=args.n_bkg,
 93        seed=42
 94    )
 95    tensor_data = convert_data_to_tensors(data)
 96
 97    # dataset plots (only 2 dims kept)
 98    # plot to show dataset
 99    for dim in range(3):              # show 0, 1 and 2
100        _h = {}
101        for proc, df in data.items():
102            vals = np.stack(df['NN_output'].values)[:, dim]   # full 3-D array here
103            _h[proc] = create_hist(
104                vals, df['weight'].values, bins=50, low=0.0, high=1.0
105            )
106        for proc, df in data.items():
107            vals = np.stack(df["NN_output"].values)[:,dim]
108            _h[proc] = create_hist(vals, df["weight"].values, bins=50, low=0., high=1.)
109        for log in (False, True):
110            suf = "_log" if log else ""
111            plot_stacked_histograms(
112                stacked_hists=[_h[p] for p in data if not p.startswith("signal")],
113                process_labels=[p for p in data if not p.startswith("signal")],
114                signal_hists=[100*_h["signal1"], 500*_h["signal2"]],
115                signal_labels=['Signal1 x 100', 'Signal2 x 500'],
116                log=log,
117                output_filename=os.path.join(path_plots, f"data_dim{dim}{suf}.pdf"),
118                axis_labels=(f"soft-max dim {dim}", "Events"),
119            )
120
121    # argmax-classification as comparison
122    baseline_results = {'signal1':{}, 'signal2':{}}
123    for nb in [2, 5, 10]:
124        h_sig1 = None
125        bkg1 = []
126        lbl1 = []
127        for p,df in data.items():
128            v = np.stack(df["NN_output"].values)[:,0]
129            m = np.argmax(np.stack(df["NN_output"].values), 1) == 0
130            if p == 'signal1':
131                h_sig1 = create_hist(
132                    v[m], df["weight"].values[m], bins=nb, low=0.33, high=1.
133                )
134            else:
135                bkg1.append(
136                    create_hist(
137                        v[m], df["weight"].values[m], bins=nb, low=0.33, high=1.
138                        )
139                )
140                lbl1.append(p)
141        baseline_results['signal1'][nb] = compute_significance_from_hists(h_sig1,bkg1)
142
143        h_sig2 = None
144        bkg2 = []
145        lbl2 = []
146        for p,df in data.items():
147            v = np.stack(df["NN_output"].values)[:,1]
148            m = np.argmax(np.stack(df["NN_output"].values),1) == 1
149            if p == 'signal2':
150                h_sig2 = create_hist(
151                    v[m], df["weight"].values[m], bins=nb, low=0.33, high=1.
152                )
153            else:
154                bkg2.append(
155                    create_hist(
156                        v[m], df["weight"].values[m], bins=nb, low=0.33, high=1.
157                    )
158                )
159                lbl2.append(p)
160        baseline_results['signal2'][nb] = compute_significance_from_hists(h_sig2,bkg2)
161
162    # GATO optimisation below
163    gato_results = {'signal1':{}, 'signal2':{}}
164    path_gato = os.path.join(path_plots, "gato")
165    os.makedirs(path_gato, exist_ok=True)
166
167    # for validations, use not tf tensors
168    data_2d = {}
169    for proc, df in data.items():
170        df2 = df.copy()
171        df2["NN_output"] = [v[:2] for v in df["NN_output"].values]
172        data_2d[proc] = df2
173    sig1_scale = 100
174    sig2_scale = 500
175
176    for n_cats in args.gato_bins:
177
178        @tf.function
179        def train_step(model, tdata, opt, lamY, lamU, thrY, thrU):
180            with tf.GradientTape() as tape:
181                loss, B, Bw2 = model.call(tdata)
182                penY = low_bkg_penalty(B, threshold=thrY)
183                penU = high_bkg_uncertainty_penalty(Bw2, B, rel_threshold=thrU)
184                total = loss + lamY*penY + lamU*penU
185            g = tape.gradient(total, model.trainable_variables)
186            opt.apply_gradients(zip(g, model.trainable_variables))
187            return loss
188
189        model = gato_2D(n_cats=n_cats, temperature=1.0)
190
191        optimizer = tf.keras.optimizers.RMSprop(0.1)
192        lr_scheduler = LearningRateScheduler(
193            optimizer,
194            lr_initial=0.05,
195            lr_final=0.001,
196            total_epochs=args.epochs,
197            mode="cosine",
198        )
199
200        # temperature scheduler
201        temperature_scheduler = TemperatureScheduler(
202            model,
203            t_initial=1.0,
204            t_final=0.05,
205            total_epochs=args.epochs,
206            mode="cosine",
207        )
208
209        loss_history = []
210        boundary_frames = []
211        hist_frames = []
212        mean_bias_history = []
213        bias_epochs = []
214        temperature_history = []
215        for ep in range(args.epochs):
216            lr_scheduler.update(ep)
217            temperature_scheduler.update(ep)
218            loss = train_step(
219                model, tensor_data, optimizer,
220                args.lam_yield, args.lam_unc,
221                args.thr_yield, args.thr_unc
222            )
223
224            if ep % 25 == 0 or ep == args.epochs - 1:
225                bias_vec = model.get_bias(tensor_data)
226                mean_bias_history.append(float(np.mean(np.abs(bias_vec))))
227                bias_epochs.append(ep)
228                temperature_history.append(float(model.temperature))
229
230            if ep % 25 == 0:
231                lr_value = getattr(optimizer, "learning_rate", getattr(optimizer, "lr", None))
232                if hasattr(lr_value, "numpy"):
233                    lr_value = float(lr_value.numpy())
234                else:
235                    lr_value = float(lr_value)
236                print(f"[{ep:03d}] loss = {loss.numpy():.3f}, lr = {lr_value:.5f}")
237                assign, order, _, inv = assign_bins_and_order(
238                    model, data_2d, reduce=True
239                )
240                filled = {p: fill_histogram_from_assignments(
241                    assign[p], data_2d[p]["weight"], n_cats
242                ) for p in data_2d}
243                bg_procs = [p for p in data if not p.startswith("signal")]
244                opt_bkgs = [filled[p] for p in bg_procs]
245                # 1) histogram
246                hist_fname = path_gato + f"/progress_plots_{n_cats}/hist_{ep:04d}.png"
247                plot_stacked_histograms(
248                    stacked_hists=opt_bkgs,
249                    process_labels=bg_procs,
250                    signal_hists=[
251                        sig1_scale * filled["signal1"],
252                        sig2_scale * filled["signal2"],
253                    ],
254                    signal_labels=[f"Signal1 x{sig1_scale}", f"Signal2 x{sig2_scale}"],
255                    output_filename=hist_fname,
256                    axis_labels=("Bin index", "Events"),
257                    normalize=False,
258                    log=False,
259                )
260                hist_frames.append(hist_fname)
261
262                # 2) boundaries
263                boundary_fname = path_gato + f"/frames_{n_cats}/boundary_{ep:04d}.png"
264                plot_bin_boundaries_2D(
265                    model,
266                    [i for i in range(n_cats)],
267                    boundary_fname,
268                    resolution=500,
269                    annotation=f"Epoch {ep}",
270                )
271                boundary_frames.append(boundary_fname)
272            loss_history.append(loss.numpy())
273
274        checkpoint_dir = os.path.join(path_gato, "checkpoints", f"{n_cats}_bins")
275        os.makedirs(checkpoint_dir, exist_ok=True)
276        model.save(checkpoint_dir)
277
278        bias_plot_base = os.path.join(path_gato, f"bias_history_{n_cats}bins")
279        plot_bias_history(
280            mean_bias_history,
281            bias_plot_base + ".pdf",
282            epochs=bias_epochs,
283            temp_points=temperature_history,
284            temp_label="Temperature",
285        )
286        plot_bias_history(
287            mean_bias_history,
288            bias_plot_base + "_log.pdf",
289            epochs=bias_epochs,
290            temp_points=temperature_history,
291            temp_label="Temperature",
292            log_scale=True,
293        )
294
295        # check bias due to finite temperature in training
296        bias = model.get_bias(tensor_data)
297        print(f"T = {model.temperature:4.2f};  per-bin bias: {bias}")
298
299        assign, order, _, inv = assign_bins_and_order(model, data_2d, reduce=True)
300
301        # 2) make per-process histograms
302        filled = {p: fill_histogram_from_assignments(
303            assign[p], data_2d[p]["weight"], n_cats
304        ) for p in data_2d}
305
306        opt_bkgs = [filled[f"bkg{i}"] for i in range(1,6)]
307        Z1 = compute_significance_from_hists(
308            filled["signal1"], opt_bkgs+[filled["signal2"]]
309        )
310        Z2 = compute_significance_from_hists(
311            filled["signal2"], opt_bkgs+[filled["signal1"]]
312        )
313        gato_results['signal1'][n_cats] = Z1
314        gato_results['signal2'][n_cats] = Z2
315
316        # quick plots
317        plot_learned_gaussians(
318            data=data, model=model, dim_x=0, dim_y=1,
319            output_filename=os.path.join(path_gato, f"Gaussians_{n_cats}bins.pdf"),
320            inv_mapping=inv,
321        )
322
323        plot_bin_boundaries_2D(
324            model,
325            order,
326            path_plot=os.path.join(path_gato, f"Bin_boundaries_{n_cats}_bins.pdf")
327        )
328
329        # loss curve
330        plot_history(
331            np.array(loss_history),
332            os.path.join(path_gato, f"loss_{n_cats}.pdf"),
333            y_label=r"Geometric mean $(Z_1,Z_2)$", x_label="Epoch"
334        )
335
336        # 1) Stacked histogram of optimized bins:
337        # Collect background processes
338        bg_procs = [p for p in data if not p.startswith("signal")]
339        opt_bkgs = [filled[p] for p in bg_procs]
340
341        for use_log in (False, True):
342            suffix = "log" if use_log else "linear"
343
344            plot_stacked_histograms(
345                stacked_hists=opt_bkgs,
346                process_labels=bg_procs,
347                signal_hists=[
348                    sig1_scale * filled["signal1"],
349                    sig2_scale * filled["signal2"]
350                ],
351                signal_labels=[
352                    f"Signal1 x{sig1_scale}",
353                    f"Signal2 x{sig2_scale}"
354                ],
355                output_filename=os.path.join(
356                    path_gato, f"optimized_dist_{n_cats}bins_{suffix}.pdf"
357                ),
358                axis_labels=("Bin index", "Events"),
359                normalize=False,
360                log=use_log
361            )
362            print(
363                f"Saved optimized {suffix} histogram:\
364                optimized_dist_{n_cats}bins_{suffix}.pdf"
365            )
366
367        B_sorted, rel_unc_sorted, _ = model.compute_hard_bkg_stats(tensor_data)
368        B_ord = B_sorted[order]
369        unc_ord = rel_unc_sorted[order]
370
371        for use_log in (False, True):
372            suffix = "log" if use_log else "linear"
373
374            plot_yield_vs_uncertainty(
375                B_ord,
376                unc_ord,
377                log=use_log,
378                output_filename=os.path.join(
379                    path_gato, f"yield_vs_unc_{n_cats}bins_{suffix}.pdf"
380                )
381            )
382            print(
383                f"Saved yield vs. unc ({suffix}):\
384                yield_vs_unc_{n_cats}bins_{suffix}.pdf"
385            )
386
387        make_gif(
388            hist_frames, path_gato + f"/frames_{n_cats}/hist_evolution.gif"
389        )
390        make_gif(
391            boundary_frames, path_gato + f"/frames_{n_cats}/boundaries_evolution.gif"
392        )
393
394    # summary comparison
395    plot_significance_comparison(
396        baseline_results={
397            k:{
398                2*n+1:baseline_results[k][n] for n in baseline_results[k]
399            } for k in baseline_results
400        },
401        optimized_results=gato_results,
402        output_filename=os.path.join(path_gato, "significance_comparison.pdf"),
403    )
404
405
406if __name__ == "__main__":
407    main()