CbmRoot
Loading...
Searching...
No Matches
CbmKresTrainAnnDirectPhotons.cxx
Go to the documentation of this file.
1/* Copyright (C) 2017-2020 GSI Helmholtzzentrum fuer Schwerionenforschung, Darmstadt
2 SPDX-License-Identifier: GPL-3.0-only
3 Authors: Ievgenii Kres, Florian Uhlig [committer] */
4
23
24#include "CbmDrawHist.h"
25
26#include "TCanvas.h"
27#include "TH1D.h"
28#include "TH2D.h"
29#include "TMath.h"
30#include "TMultiLayerPerceptron.h"
31#include "TSystem.h"
32#include "TTree.h"
33
34#include <boost/assign/list_of.hpp>
35
36#include <iostream>
37#include <string>
38#include <vector>
39
40#include <cmath>
41
42
43using boost::assign::list_of;
44using namespace std;
45
// Constructor member-initializer list.
// NOTE(review): the constructor's own signature line (original line 46) is absent from this listing.
// The *_correct / *_wrong vectors start empty and receive one entry per pair from Exec();
// the histogram pointers stay null until the histograms are booked.
47 : fMaxNofTrainSamples(1500)  // Exec() requires at least this many correct pairs before its event-15000 step; also caps tree entries per class
48 , fAnnCut(0.6)  // threshold on the network output separating "correct" from "wrong" pairs
49 , fNofWrongLikeCorrect(0)  // wrong pairs whose network output is >= fAnnCut
50 , fNofCorrectLikeWrong(0)  // correct pairs whose network output is < fAnnCut
51 , IM_correct()
52 , OA_correct()
53 , Angle_correct()
54 , Z_correct()
55 , Mom1_correct()
56 , Mom2_correct()
57 , IM_wrong()
58 , OA_wrong()
59 , Angle_wrong()
60 , Z_wrong()
61 , Mom1_wrong()
62 , Mom2_wrong()
63 , fHists()  // every booked histogram is appended here
64 , fhAnnOutput_correct(nullptr)
65 , fhAnnOutput_wrong(nullptr)
66 , fhCumProb_correct(nullptr)
67 , fhCumProb_wrong(nullptr)
68{
69}
70
72
74
75void CbmKresTrainAnnDirectPhotons::Exec(int event, int IdForANN, double InvariantMass, double OpeningAngle,
76 double PlaneAngle_last, double ZPos, TVector3 Momentum1, TVector3 Momentum2)
77{
78 double p1 =
79 TMath::Sqrt(Momentum1.X() * Momentum1.X() + Momentum1.Y() * Momentum1.Y() + Momentum1.Z() * Momentum1.Z());
80 double p2 =
81 TMath::Sqrt(Momentum2.X() * Momentum2.X() + Momentum2.Y() * Momentum2.Y() + Momentum2.Z() * Momentum2.Z());
82 if (IdForANN == 1) {
83 IM_correct.push_back(InvariantMass);
84 OA_correct.push_back(OpeningAngle);
85 Angle_correct.push_back(PlaneAngle_last);
86 Z_correct.push_back(ZPos);
87 Mom1_correct.push_back(p1);
88 Mom2_correct.push_back(p2);
89 cout << "correct = " << IM_correct.size() << "; wrong = " << IM_wrong.size() << endl;
90 }
91 else {
92 IM_wrong.push_back(InvariantMass);
93 OA_wrong.push_back(OpeningAngle);
94 Angle_wrong.push_back(PlaneAngle_last);
95 Z_wrong.push_back(ZPos);
96 Mom1_wrong.push_back(p1);
97 Mom2_wrong.push_back(p2);
98 }
99
100 if (IM_correct.size() % 100 == 0 && IdForANN == 1)
101 cout << "correct = " << IM_correct.size() << "; wrong = " << IM_wrong.size() << endl;
102
103
104 if (event == 15000 && IM_correct.size() >= fMaxNofTrainSamples) {
106 Draw();
107
108 IM_correct.clear();
109 OA_correct.clear();
110 Angle_correct.clear();
111 Z_correct.clear();
112 Mom1_correct.clear();
113 Mom2_correct.clear();
114 IM_wrong.clear();
115 OA_wrong.clear();
116 Angle_wrong.clear();
117 Z_wrong.clear();
118 Mom1_wrong.clear();
119 Mom2_wrong.clear();
120 }
121}
122
123
125{
126 TTree* simu = new TTree("MonteCarlo", "MontecarloData");
127 Double_t x[6];
128 Double_t xOut;
129
130 simu->Branch("x0", &x[0], "x0/D");
131 simu->Branch("x1", &x[1], "x1/D");
132 simu->Branch("x2", &x[2], "x2/D");
133 simu->Branch("x3", &x[3], "x3/D");
134 simu->Branch("x4", &x[4], "x4/D");
135 simu->Branch("x5", &x[5], "x5/D");
136 simu->Branch("xOut", &xOut, "xOut/D");
137
138 for (size_t i = 0; i < IM_correct.size(); i++) {
139 x[0] = IM_correct[i] / 0.1;
140 x[1] = OA_correct[i] / 30;
141 x[2] = Angle_correct[i] / 30;
142 x[3] = Z_correct[i] / 100;
143 x[4] = Mom1_correct[i] / 5;
144 x[5] = Mom2_correct[i] / 5;
145
146 if (x[0] > 1.0) x[0] = 1.0;
147 if (x[1] > 1.0) x[1] = 1.0;
148 if (x[2] > 1.0) x[2] = 1.0;
149 if (x[3] > 1.0) x[3] = 1.0;
150 if (x[4] > 1.0) x[4] = 1.0;
151 if (x[5] > 1.0) x[5] = 1.0;
152
153 xOut = 1.;
154 simu->Fill();
155 if (i >= fMaxNofTrainSamples) break;
156 }
157 for (size_t i = 0; i < IM_wrong.size(); i++) {
158 x[0] = IM_wrong[i] / 0.1;
159 x[1] = OA_wrong[i] / 30;
160 x[2] = Angle_wrong[i] / 30;
161 x[3] = Z_wrong[i] / 100;
162 x[4] = Mom1_wrong[i] / 5;
163 x[5] = Mom2_wrong[i] / 5;
164
165 if (x[0] > 1.0) x[0] = 1.0;
166 if (x[1] > 1.0) x[1] = 1.0;
167 if (x[2] > 1.0) x[2] = 1.0;
168 if (x[3] > 1.0) x[3] = 1.0;
169 if (x[4] > 1.0) x[4] = 1.0;
170 if (x[5] > 1.0) x[5] = 1.0;
171
172 xOut = -1.;
173 simu->Fill();
174 if (i >= fMaxNofTrainSamples) break;
175 }
176
177 TMultiLayerPerceptron network("x0,x1,x2,x3,x4,x5:12:xOut", simu, "Entry$+1");
178 network.Train(300, "text,update=10");
179 network.DumpWeights("../../../analysis/conversion2/KresAnalysis_ann_photons_weights.txt");
180
181
182 Double_t params[6];
185
186 for (size_t i = 0; i < IM_correct.size(); i++) {
187 params[0] = IM_correct[i] / 0.1;
188 params[1] = OA_correct[i] / 30;
189 params[2] = Angle_correct[i] / 30;
190 params[3] = Z_correct[i] / 100;
191 params[4] = Mom1_correct[i] / 5;
192 params[5] = Mom2_correct[i] / 5;
193
194 if (params[0] > 1.0) params[0] = 1.0;
195 if (params[1] > 1.0) params[1] = 1.0;
196 if (params[2] > 1.0) params[2] = 1.0;
197 if (params[3] > 1.0) params[3] = 1.0;
198 if (params[4] > 1.0) params[4] = 1.0;
199 if (params[5] > 1.0) params[5] = 1.0;
200
201 Double_t netEval = network.Evaluate(0, params);
202 fhAnnOutput_correct->Fill(netEval);
203 if (netEval < fAnnCut) fNofCorrectLikeWrong++;
204 }
205 for (size_t i = 0; i < IM_wrong.size(); i++) {
206 params[0] = IM_wrong[i] / 0.1;
207 params[1] = OA_wrong[i] / 30;
208 params[2] = Angle_wrong[i] / 30;
209 params[3] = Z_wrong[i] / 100;
210 params[4] = Mom1_wrong[i] / 5;
211 params[5] = Mom2_wrong[i] / 5;
212
213 if (params[0] > 1.0) params[0] = 1.0;
214 if (params[1] > 1.0) params[1] = 1.0;
215 if (params[2] > 1.0) params[2] = 1.0;
216 if (params[3] > 1.0) params[3] = 1.0;
217 if (params[4] > 1.0) params[4] = 1.0;
218 if (params[5] > 1.0) params[5] = 1.0;
219
220 Double_t netEval = network.Evaluate(0, params);
221 fhAnnOutput_wrong->Fill(netEval);
222 if (netEval >= fAnnCut) fNofWrongLikeCorrect++;
223 }
224}
225
226
228{
229 cout << "nof correct pairs = " << IM_correct.size() << endl;
230 cout << "nof wrong pairs = " << IM_wrong.size() << endl;
231 cout << "wrong like correct = " << fNofWrongLikeCorrect
232 << ", wrong supp = " << (Double_t) IM_wrong.size() / fNofWrongLikeCorrect << endl;
233 cout << "Correct like wrong = " << fNofCorrectLikeWrong
234 << ", correct lost eff = " << 100. * (Double_t) fNofCorrectLikeWrong / IM_correct.size() << endl;
235
236 Double_t cumProbFake = 0.;
237 Double_t cumProbTrue = 0.;
238 Int_t nofTrue = (Int_t) fhAnnOutput_correct->GetEntries();
239 Int_t nofFake = (Int_t) fhAnnOutput_wrong->GetEntries();
240
241 for (Int_t i = 1; i <= fhAnnOutput_wrong->GetNbinsX(); i++) {
242 cumProbTrue += fhAnnOutput_correct->GetBinContent(i);
243 fhCumProb_correct->SetBinContent(i, 1. - (Double_t) cumProbTrue / nofTrue);
244
245 cumProbFake += fhAnnOutput_wrong->GetBinContent(i);
246 fhCumProb_wrong->SetBinContent(i, (Double_t) cumProbFake / nofFake);
247 }
248
249
250 TCanvas* c1 = new TCanvas("ann_correct_ann_output", "ann_correct_ann_output", 400, 400);
251 c1->SetTitle("ann_correct_ann_output");
252 fhAnnOutput_correct->Draw();
253
254 TCanvas* c2 = new TCanvas("ann_wrong_ann_output", "ann_wrong_ann_output", 400, 400);
255 c2->SetTitle("ann_wrong_ann_output");
256 fhAnnOutput_wrong->Draw();
257
258 TCanvas* c3 = new TCanvas("ann_correct_cum_prob", "ann_correct_cum_prob", 400, 400);
259 c3->SetTitle("ann_correct_cum_output");
260 fhCumProb_correct->Draw();
261
262 TCanvas* c4 = new TCanvas("ann_wrong_cum_prob", "ann_wrong_cum_prob", 400, 400);
263 c4->SetTitle("ann_wrong_cum_output");
264 fhCumProb_wrong->Draw();
265}
266
267
269{
270
271 fhAnnOutput_correct = new TH1D("fhAnnOutput_correct", "ANN output;ANN output;Counter", 100, -1.2, 1.2);
272 fHists.push_back(fhAnnOutput_correct);
273 fhAnnOutput_wrong = new TH1D("fhAnnOutput_wrong", "ANN output;ANN output;Counter", 100, -1.2, 1.2);
274 fHists.push_back(fhAnnOutput_wrong);
275
276 fhCumProb_correct = new TH1D("fhCumProb_correct", "ANN output;ANN output;Cumulative probability", 100, -1.2, 1.2);
277 fHists.push_back(fhCumProb_correct);
278 fhCumProb_wrong = new TH1D("fhCumProb_wrong", "ANN output;ANN output;Cumulative probability", 100, -1.2, 1.2);
279 fHists.push_back(fhCumProb_wrong);
280}
Helper functions for drawing 1D and 2D histograms and graphs.
void Exec(int event, int IdForANN, double InvariantMass, double OpeningAngle, double PlaneAngle_last, double ZPos, TVector3 Momentum1, TVector3 Momentum2)
Hash for CbmL1LinkKey.