Three-class softmax example
Optimize categories directly on a 3-class softmax output (visualized in two dimensions). The script trains several Gaussian mixtures, compares to argmax baselines, and produces animations of the learned components.
Run:
python examples/three_class_softmax_example/run_example.py --gato-bins 5 10 20 --epochs 500
Output plots
frames_<N> folders with boundary evolution frames (assembled into GIFs).
Stacked histograms contrasting background compositions with scaled signal templates.
Yield vs. uncertainty bar charts for each bin.
Source code
Source of the 3-class softmax example
1import os
2import argparse
3import numpy as np
4import tensorflow as tf
5import tensorflow_probability as tfp
6
7from gatohep.data_generation import generate_toy_data_3class_3D
8from gatohep.losses import high_bkg_uncertainty_penalty, low_bkg_penalty
9from gatohep.models import gato_gmm_model
10from gatohep.plotting_utils import (
11 assign_bins_and_order,
12 fill_histogram_from_assignments,
13 plot_bias_history,
14 plot_bin_boundaries_2D,
15 plot_history,
16 plot_learned_gaussians,
17 plot_significance_comparison,
18 plot_stacked_histograms,
19 plot_yield_vs_uncertainty,
20 make_gif
21)
22from gatohep.utils import (
23 LearningRateScheduler,
24 TemperatureScheduler,
25 asymptotic_significance,
26 compute_significance_from_hists,
27 create_hist
28)
29
30tfd = tfp.distributions
31
32
def convert_data_to_tensors(data):
    """Build float32 TensorFlow constants for every process in ``data``.

    Parameters
    ----------
    data : dict
        Maps a process name to a table exposing ``"NN_output"`` and
        ``"weight"`` columns (accessed through ``.values``).

    Returns
    -------
    dict
        ``{process: {"NN_output": tensor, "weight": tensor}}``. Only the
        first two columns of the stacked ``NN_output`` array are kept.
    """

    def _to_tensors(frame):
        # Stack the per-event vectors into one array, keep columns 0 and 1.
        scores = np.stack(frame["NN_output"].values)[:, :2]
        event_weights = frame["weight"].values
        return {
            "NN_output": tf.constant(scores, dtype=tf.float32),
            "weight": tf.constant(event_weights, dtype=tf.float32),
        }

    return {proc: _to_tensors(frame) for proc, frame in data.items()}
43
44
45# 2-D GMM gato model
class gato_2D(gato_gmm_model):
    """Two-dimensional ``gato_gmm_model`` optimised for two signals at once."""

    def __init__(self, n_cats, temperature=0.3, name="gato_2D"):
        """Configure the base model with ``dim=2`` and softmax mean normalisation.

        Parameters
        ----------
        n_cats : int
            Number of categories, forwarded to the base class.
        temperature : float, optional
            Forwarded to the base class (default 0.3).
        name : str, optional
            Model name, forwarded to the base class.
        """
        base_kwargs = dict(
            n_cats=n_cats,
            dim=2,
            temperature=temperature,
            mean_norm="softmax",
            cov_offdiag_damping=0.1,
            name=name,
        )
        super().__init__(**base_kwargs)

    def call(self, data_dict):
        """Return the training loss together with the background details.

        The loss is the negative square root of the product of the
        ``"signal1"`` and ``"signal2"`` significances, i.e. minus their
        geometric mean. The two remaining return values are passed through
        unchanged from ``get_differentiable_significance`` so that callers
        can build penalty terms from them.

        Returns
        -------
        tuple
            ``(loss, bkg_yield, bkg_sum_w2)``.
        """
        details = self.get_differentiable_significance(
            data_dict,
            signal_labels=["signal1", "signal2"],
            return_details=True,
        )
        per_signal_z, bkg_yield, bkg_sum_w2 = details
        z_product = per_signal_z["signal1"] * per_signal_z["signal2"]
        return -tf.sqrt(z_product), bkg_yield, bkg_sum_w2
71
72
73# Main function to run the example
def _argmax_baseline_significance(scores, weights, class_index, signal_name, n_bins):
    """Significance of ``signal_name`` in its argmax-selected region.

    Events whose largest score sits at ``class_index`` are selected and
    histogrammed in that score with ``n_bins`` bins between 0.33 and 1.
    Every process other than ``signal_name`` (the other signal included)
    is passed to ``compute_significance_from_hists`` as background.

    Parameters
    ----------
    scores : dict
        Process name -> stacked 2-D array of classifier outputs.
    weights : dict
        Process name -> per-event weight array.
    class_index : int
        Score column that defines the region.
    signal_name : str
        Process treated as signal.
    n_bins : int
        Number of histogram bins.
    """
    h_signal = None
    h_others = []
    for proc, proc_scores in scores.items():
        selected = np.argmax(proc_scores, axis=1) == class_index
        hist = create_hist(
            proc_scores[:, class_index][selected],
            weights[proc][selected],
            bins=n_bins, low=0.33, high=1.,
        )
        if proc == signal_name:
            h_signal = hist
        else:
            h_others.append(hist)
    return compute_significance_from_hists(h_signal, h_others)


def _fill_category_hists(model, data_2d, n_cats):
    """Assign events to categories and fill one histogram per process.

    Returns
    -------
    tuple
        ``(filled, order, inv)`` where ``filled`` maps process name to its
        histogram and ``order`` / ``inv`` are the second and fourth values
        returned by ``assign_bins_and_order``.
    """
    assign, order, _, inv = assign_bins_and_order(model, data_2d, reduce=True)
    filled = {
        proc: fill_histogram_from_assignments(
            assign[proc], data_2d[proc]["weight"], n_cats
        )
        for proc in data_2d
    }
    return filled, order, inv


def main():
    """Run the 3-class softmax GATO example end to end.

    Steps: parse the command line, generate the toy data, plot the input
    distributions, compute the argmax baselines, train one ``gato_2D`` model
    per requested number of categories (writing progress frames, a
    checkpoint and summary plots for each), and finally plot the
    baseline-vs-GATO significance comparison.
    """
    parser = argparse.ArgumentParser(description="2-D soft-max GATO optimisation")
    parser.add_argument("--epochs", type=int, default=250)
    parser.add_argument("--gato-bins", nargs="+", type=int, default=[3,5,10])
    parser.add_argument("--lam-yield", type=float, default=0.)
    parser.add_argument("--lam-unc", type=float, default=0.)
    parser.add_argument("--thr-yield", type=float, default=5.)
    parser.add_argument("--thr-unc", type=float, default=0.20)
    parser.add_argument("--n-bkg", type=int, default=1_000_000)
    parser.add_argument("--out", type=str, default="Plots")
    args = parser.parse_args()

    path_plots = f'./examples/three_class_softmax_example/{args.out}/'
    os.makedirs(path_plots, exist_ok=True)

    # ---------- toy data
    data = generate_toy_data_3class_3D(
        n_signal1=100_000, n_signal2=100_000,
        n_bkg=args.n_bkg,
        seed=42
    )
    tensor_data = convert_data_to_tensors(data)

    # Everything that is not a signal counts as background in the plots and
    # in the final significance computation.
    bg_procs = [p for p in data if not p.startswith("signal")]

    # Stack the per-event score vectors once; the dataset plots and the
    # baselines below all read from these arrays.
    scores = {p: np.stack(df["NN_output"].values) for p, df in data.items()}
    weights = {p: df["weight"].values for p, df in data.items()}

    # ---------- dataset plots: one per score dimension, linear and log
    for dim in range(3):
        dim_hists = {
            p: create_hist(scores[p][:, dim], weights[p], bins=50, low=0., high=1.)
            for p in data
        }
        for log in (False, True):
            suf = "_log" if log else ""
            plot_stacked_histograms(
                stacked_hists=[dim_hists[p] for p in bg_procs],
                process_labels=bg_procs,
                signal_hists=[100*dim_hists["signal1"], 500*dim_hists["signal2"]],
                signal_labels=['Signal1 x 100', 'Signal2 x 500'],
                log=log,
                output_filename=os.path.join(path_plots, f"data_dim{dim}{suf}.pdf"),
                axis_labels=(f"soft-max dim {dim}", "Events"),
            )

    # ---------- argmax-classification as comparison
    baseline_results = {'signal1': {}, 'signal2': {}}
    for nb in [2, 5, 10]:
        for class_index, signal_name in enumerate(("signal1", "signal2")):
            baseline_results[signal_name][nb] = _argmax_baseline_significance(
                scores, weights, class_index, signal_name, nb
            )

    # ---------- GATO optimisation below
    gato_results = {'signal1': {}, 'signal2': {}}
    path_gato = os.path.join(path_plots, "gato")
    os.makedirs(path_gato, exist_ok=True)

    # for validations, use not tf tensors (first two score dimensions only)
    data_2d = {}
    for proc, df in data.items():
        df2 = df.copy()
        df2["NN_output"] = [v[:2] for v in df["NN_output"].values]
        data_2d[proc] = df2
    sig1_scale = 100
    sig2_scale = 500

    for n_cats in args.gato_bins:

        # Defined inside the loop, as in the original script; presumably so
        # that every model gets its own traced function -- TODO confirm.
        @tf.function
        def train_step(model, tdata, opt, lamY, lamU, thrY, thrU):
            with tf.GradientTape() as tape:
                loss, B, Bw2 = model.call(tdata)
                penY = low_bkg_penalty(B, threshold=thrY)
                penU = high_bkg_uncertainty_penalty(Bw2, B, rel_threshold=thrU)
                total = loss + lamY*penY + lamU*penU
            g = tape.gradient(total, model.trainable_variables)
            opt.apply_gradients(zip(g, model.trainable_variables))
            return loss

        model = gato_2D(n_cats=n_cats, temperature=1.0)

        optimizer = tf.keras.optimizers.RMSprop(0.1)
        lr_scheduler = LearningRateScheduler(
            optimizer,
            lr_initial=0.05,
            lr_final=0.001,
            total_epochs=args.epochs,
            mode="cosine",
        )

        # temperature scheduler
        temperature_scheduler = TemperatureScheduler(
            model,
            t_initial=1.0,
            t_final=0.05,
            total_epochs=args.epochs,
            mode="cosine",
        )

        # Make sure the frame directories exist before anything is saved
        # into them; nothing else in this script creates them.
        frames_dir = os.path.join(path_gato, f"frames_{n_cats}")
        progress_dir = os.path.join(path_gato, f"progress_plots_{n_cats}")
        os.makedirs(frames_dir, exist_ok=True)
        os.makedirs(progress_dir, exist_ok=True)

        loss_history = []
        boundary_frames = []
        hist_frames = []
        mean_bias_history = []
        bias_epochs = []
        temperature_history = []
        for ep in range(args.epochs):
            lr_scheduler.update(ep)
            temperature_scheduler.update(ep)
            loss = train_step(
                model, tensor_data, optimizer,
                args.lam_yield, args.lam_unc,
                args.thr_yield, args.thr_unc
            )
            # One entry per epoch: the loss curve is plotted against "Epoch".
            loss_history.append(loss.numpy())

            # Track the bias every 25 epochs and once more at the very end.
            if ep % 25 == 0 or ep == args.epochs - 1:
                bias_vec = model.get_bias(tensor_data)
                mean_bias_history.append(float(np.mean(np.abs(bias_vec))))
                bias_epochs.append(ep)
                temperature_history.append(float(model.temperature))

            if ep % 25 == 0:
                lr_value = getattr(optimizer, "learning_rate", getattr(optimizer, "lr", None))
                if hasattr(lr_value, "numpy"):
                    lr_value = float(lr_value.numpy())
                else:
                    lr_value = float(lr_value)
                print(f"[{ep:03d}] loss = {loss.numpy():.3f}, lr = {lr_value:.5f}")

                filled, _, _ = _fill_category_hists(model, data_2d, n_cats)

                # 1) histogram frame
                hist_fname = os.path.join(progress_dir, f"hist_{ep:04d}.png")
                plot_stacked_histograms(
                    stacked_hists=[filled[p] for p in bg_procs],
                    process_labels=bg_procs,
                    signal_hists=[
                        sig1_scale * filled["signal1"],
                        sig2_scale * filled["signal2"],
                    ],
                    signal_labels=[f"Signal1 x{sig1_scale}", f"Signal2 x{sig2_scale}"],
                    output_filename=hist_fname,
                    axis_labels=("Bin index", "Events"),
                    normalize=False,
                    log=False,
                )
                hist_frames.append(hist_fname)

                # 2) boundary frame
                boundary_fname = os.path.join(frames_dir, f"boundary_{ep:04d}.png")
                plot_bin_boundaries_2D(
                    model,
                    list(range(n_cats)),
                    boundary_fname,
                    resolution=500,
                    annotation=f"Epoch {ep}",
                )
                boundary_frames.append(boundary_fname)

        checkpoint_dir = os.path.join(path_gato, "checkpoints", f"{n_cats}_bins")
        os.makedirs(checkpoint_dir, exist_ok=True)
        model.save(checkpoint_dir)

        bias_plot_base = os.path.join(path_gato, f"bias_history_{n_cats}bins")
        plot_bias_history(
            mean_bias_history,
            bias_plot_base + ".pdf",
            epochs=bias_epochs,
            temp_points=temperature_history,
            temp_label="Temperature",
        )
        plot_bias_history(
            mean_bias_history,
            bias_plot_base + "_log.pdf",
            epochs=bias_epochs,
            temp_points=temperature_history,
            temp_label="Temperature",
            log_scale=True,
        )

        # check bias due to finite temperature in training
        bias = model.get_bias(tensor_data)
        # float() mirrors the cast used for temperature_history above, so the
        # ":4.2f" format spec is always applied to a plain Python float.
        print(f"T = {float(model.temperature):4.2f}; per-bin bias: {bias}")

        filled, order, inv = _fill_category_hists(model, data_2d, n_cats)

        # Each signal is evaluated with the other signal added to the
        # backgrounds.
        opt_bkgs = [filled[p] for p in bg_procs]
        Z1 = compute_significance_from_hists(
            filled["signal1"], opt_bkgs+[filled["signal2"]]
        )
        Z2 = compute_significance_from_hists(
            filled["signal2"], opt_bkgs+[filled["signal1"]]
        )
        gato_results['signal1'][n_cats] = Z1
        gato_results['signal2'][n_cats] = Z2

        # quick plots
        plot_learned_gaussians(
            data=data, model=model, dim_x=0, dim_y=1,
            output_filename=os.path.join(path_gato, f"Gaussians_{n_cats}bins.pdf"),
            inv_mapping=inv,
        )

        plot_bin_boundaries_2D(
            model,
            order,
            path_plot=os.path.join(path_gato, f"Bin_boundaries_{n_cats}_bins.pdf")
        )

        # loss curve
        plot_history(
            np.array(loss_history),
            os.path.join(path_gato, f"loss_{n_cats}.pdf"),
            y_label=r"Geometric mean $(Z_1,Z_2)$", x_label="Epoch"
        )

        # Stacked histogram of the optimized bins, linear and log
        for use_log in (False, True):
            suffix = "log" if use_log else "linear"

            plot_stacked_histograms(
                stacked_hists=opt_bkgs,
                process_labels=bg_procs,
                signal_hists=[
                    sig1_scale * filled["signal1"],
                    sig2_scale * filled["signal2"]
                ],
                signal_labels=[
                    f"Signal1 x{sig1_scale}",
                    f"Signal2 x{sig2_scale}"
                ],
                output_filename=os.path.join(
                    path_gato, f"optimized_dist_{n_cats}bins_{suffix}.pdf"
                ),
                axis_labels=("Bin index", "Events"),
                normalize=False,
                log=use_log
            )
            print(
                f"Saved optimized {suffix} histogram: "
                f"optimized_dist_{n_cats}bins_{suffix}.pdf"
            )

        B_sorted, rel_unc_sorted, _ = model.compute_hard_bkg_stats(tensor_data)
        B_ord = B_sorted[order]
        unc_ord = rel_unc_sorted[order]

        for use_log in (False, True):
            suffix = "log" if use_log else "linear"

            plot_yield_vs_uncertainty(
                B_ord,
                unc_ord,
                log=use_log,
                output_filename=os.path.join(
                    path_gato, f"yield_vs_unc_{n_cats}bins_{suffix}.pdf"
                )
            )
            print(
                f"Saved yield vs. unc ({suffix}): "
                f"yield_vs_unc_{n_cats}bins_{suffix}.pdf"
            )

        make_gif(
            hist_frames, os.path.join(frames_dir, "hist_evolution.gif")
        )
        make_gif(
            boundary_frames, os.path.join(frames_dir, "boundaries_evolution.gif")
        )

    # summary comparison
    # Baseline entries are re-keyed from nb to 2*nb+1; presumably nb bins in
    # each of the two signal regions plus one background region -- TODO confirm.
    plot_significance_comparison(
        baseline_results={
            k: {
                2*n+1: baseline_results[k][n] for n in baseline_results[k]
            } for k in baseline_results
        },
        optimized_results=gato_results,
        output_filename=os.path.join(path_gato, "significance_comparison.pdf"),
    )
405
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()