CbmRoot
Loading...
Searching...
No Matches
CbmKresTrainAnn.cxx
Go to the documentation of this file.
1/* Copyright (C) 2017-2020 GSI Helmholtzzentrum fuer Schwerionenforschung, Darmstadt
2 SPDX-License-Identifier: GPL-3.0-only
3 Authors: Ievgenii Kres, Florian Uhlig [committer] */
4
22#include "CbmKresTrainAnn.h"
23
24#include "CbmDrawHist.h"
25
26#include "TCanvas.h"
27#include "TH1D.h"
28#include "TH2D.h"
29#include "TMath.h"
30#include "TMultiLayerPerceptron.h"
31#include "TSystem.h"
32#include "TTree.h"
33
34#include <boost/assign/list_of.hpp>
35
36#include <iostream>
37#include <string>
38#include <vector>
39
40#include <cmath>
41
42
43using boost::assign::list_of;
44using namespace std;
45
// Constructor initializer list. The constructor's signature line (listing line 46)
// is not shown on this page.
// Defaults set here:
//   fMaxNofTrainSamples = 1500 : per-class limit on samples copied into the training
//                                tree, and the minimum number of "correct" samples
//                                Exec() needs before the training/drawing branch runs.
//   fAnnCut = 0.6              : ANN-output threshold used to count misclassified pairs.
// The misclassification counters start at 0 and all twelve sample vectors start empty.
// The four histogram pointers stay nullptr until the histogram-booking method
// (listing lines 272-283) creates them.
47 : fMaxNofTrainSamples(1500)
48 , fAnnCut(0.6)
49 , fNofWrongLikeCorrect(0)
50 , fNofCorrectLikeWrong(0)
51 , IM_correct()
52 , OA_correct()
53 , Angle_correct()
54 , Z_correct()
55 , Mom1_correct()
56 , Mom2_correct()
57 , IM_wrong()
58 , OA_wrong()
59 , Angle_wrong()
60 , Z_wrong()
61 , Mom1_wrong()
62 , Mom2_wrong()
63 , fHists()
64 , fhAnnOutput_correct(nullptr)
65 , fhAnnOutput_wrong(nullptr)
66 , fhCumProb_correct(nullptr)
67 , fhCumProb_wrong(nullptr)
68{
69}
70
72
74
// Records one candidate pair as a training sample for the ANN.
//
// IdForANN == 1 stores the pair in the "correct" sample vectors. Any other value
// stores it in the "wrong" sample vectors. Six features are kept per pair:
// InvariantMass, OpeningAngle, PlaneAngle_last, ZPos, and the magnitudes |p| of
// Momentum1 and Momentum2.
//
// When `event` equals 2000 and at least fMaxNofTrainSamples correct pairs have been
// collected, Draw() is called and all twelve sample vectors are cleared. The
// number of wrong pairs is not checked.
// NOTE(review): the listing jumps from line 107 to line 109, so one source line inside
// that branch is not shown here. Presumably it is the TrainAndTestAnn() call that
// fills the histograms Draw() reads. Confirm against the real source.
// NOTE(review): the trigger is the hard-coded event number 2000. If that event is
// never processed, or has fewer than fMaxNofTrainSamples correct pairs, training and
// drawing never happen.
75void CbmKresTrainAnn::Exec(int event, int IdForANN, double InvariantMass, double OpeningAngle, double PlaneAngle_last,
76 double ZPos, TVector3 Momentum1, TVector3 Momentum2)
77{
// Momentum magnitudes sqrt(x^2 + y^2 + z^2) of the two tracks.
78 double p1 =
79 TMath::Sqrt(Momentum1.X() * Momentum1.X() + Momentum1.Y() * Momentum1.Y() + Momentum1.Z() * Momentum1.Z());
80 double p2 =
81 TMath::Sqrt(Momentum2.X() * Momentum2.X() + Momentum2.Y() * Momentum2.Y() + Momentum2.Z() * Momentum2.Z());
82 if (IdForANN == 1) {
// The per-class size cap is commented out here, so every sample is stored.
// The training tree applies fMaxNofTrainSamples when it is filled.
83 //if (IM_correct.size() < fMaxNofTrainSamples){
84 IM_correct.push_back(InvariantMass);
85 OA_correct.push_back(OpeningAngle);
86 Angle_correct.push_back(PlaneAngle_last);
87 Z_correct.push_back(ZPos);
88 Mom1_correct.push_back(p1);
89 Mom2_correct.push_back(p2);
90 //}
91 }
92 else {
93 //if (IM_wrong.size() < fMaxNofTrainSamples){
94 IM_wrong.push_back(InvariantMass);
95 OA_wrong.push_back(OpeningAngle);
96 Angle_wrong.push_back(PlaneAngle_last);
97 Z_wrong.push_back(ZPos);
98 Mom1_wrong.push_back(p1);
99 Mom2_wrong.push_back(p2);
100 //}
101 }
102
// Progress print each time the number of correct samples reaches a multiple of 100.
// It is printed only right after a correct pair has been added.
103 if (IM_correct.size() % 100 == 0 && IdForANN == 1)
104 cout << "correct = " << IM_correct.size() << "; wrong = " << IM_wrong.size() << endl;
105
106 //if (IM_correct.size() >= fMaxNofTrainSamples && IM_wrong.size() >= fMaxNofTrainSamples) {
107 if (event == 2000 && IM_correct.size() >= fMaxNofTrainSamples) {
109 Draw();
110
// Drop every collected sample after drawing.
111 IM_correct.clear();
112 OA_correct.clear();
113 Angle_correct.clear();
114 Z_correct.clear();
115 Mom1_correct.clear();
116 Mom2_correct.clear();
117 IM_wrong.clear();
118 OA_wrong.clear();
119 Angle_wrong.clear();
120 Z_wrong.clear();
121 Mom1_wrong.clear();
122 Mom2_wrong.clear();
123 }
124}
125
// Body of the train-and-test method. Its signature (listing line 126) is not shown on
// this page; the commented-out log line suggests it is TrainAndTestAnn().
// 1. Copies the collected samples into a TTree. Each entry has six scaled inputs
//    x0..x5 (clipped from above at 1.0) and a target xOut: +1 for correct, -1 for wrong.
// 2. Trains a TMultiLayerPerceptron with layout 6 inputs : 12 hidden : 1 output for
//    300 epochs, then dumps the weights to a text file.
// 3. Evaluates the network on every collected sample. It fills the two ANN-output
//    histograms and counts pairs that fall on the wrong side of fAnnCut.
127{
128 //cout << "Do TrainAndTestAnn" << endl;
// NOTE(review): the tree is allocated with new and not deleted in this body. Unless
// the current ROOT directory takes ownership of it, this leaks on every call.
129 TTree* simu = new TTree("MonteCarlo", "MontecarloData");
130 Double_t x[6];
131 Double_t xOut;
132
133 simu->Branch("x0", &x[0], "x0/D");
134 simu->Branch("x1", &x[1], "x1/D");
135 simu->Branch("x2", &x[2], "x2/D");
136 simu->Branch("x3", &x[3], "x3/D");
137 simu->Branch("x4", &x[4], "x4/D");
138 simu->Branch("x5", &x[5], "x5/D");
139 simu->Branch("xOut", &xOut, "xOut/D");
140
// Input scaling: IM / 0.1, OA / 30, Angle / 30, Z / 100, |p| / 5.
// Values are clipped only from above, so negative inputs (e.g. ZPos < 0) pass unchanged.
// NOTE(review): the break comes after Fill(), so up to fMaxNofTrainSamples + 1 entries
// per class enter the tree (off by one compared with the member's name).
141 for (size_t i = 0; i < IM_correct.size(); i++) {
142 x[0] = IM_correct[i] / 0.1;
143 x[1] = OA_correct[i] / 30;
144 x[2] = Angle_correct[i] / 30;
145 x[3] = Z_correct[i] / 100;
146 x[4] = Mom1_correct[i] / 5;
147 x[5] = Mom2_correct[i] / 5;
148
149 if (x[0] > 1.0) x[0] = 1.0;
150 if (x[1] > 1.0) x[1] = 1.0;
151 if (x[2] > 1.0) x[2] = 1.0;
152 if (x[3] > 1.0) x[3] = 1.0;
153 if (x[4] > 1.0) x[4] = 1.0;
154 if (x[5] > 1.0) x[5] = 1.0;
155
156 xOut = 1.;
157 simu->Fill();
158 if (i >= fMaxNofTrainSamples) break;
159 }
160 for (size_t i = 0; i < IM_wrong.size(); i++) {
161 x[0] = IM_wrong[i] / 0.1;
162 x[1] = OA_wrong[i] / 30;
163 x[2] = Angle_wrong[i] / 30;
164 x[3] = Z_wrong[i] / 100;
165 x[4] = Mom1_wrong[i] / 5;
166 x[5] = Mom2_wrong[i] / 5;
167
168 if (x[0] > 1.0) x[0] = 1.0;
169 if (x[1] > 1.0) x[1] = 1.0;
170 if (x[2] > 1.0) x[2] = 1.0;
171 if (x[3] > 1.0) x[3] = 1.0;
172 if (x[4] > 1.0) x[4] = 1.0;
173 if (x[5] > 1.0) x[5] = 1.0;
174
175 xOut = -1.;
176 simu->Fill();
177 if (i >= fMaxNofTrainSamples) break;
178 }
179
// The training selection "Entry$+1" is non-zero for every entry. Presumably every tree
// entry therefore goes to the training set, and the complementary test set is empty.
// Verify this against the TMultiLayerPerceptron constructor docs.
180 TMultiLayerPerceptron network("x0,x1,x2,x3,x4,x5:12:xOut", simu, "Entry$+1");
181 network.Train(300, "text,update=10");
// NOTE(review): hard-coded path, resolved relative to the process's working directory.
182 network.DumpWeights("../../../analysis/conversion2/KresAnalysis_ann_weights.txt");
183
184
185 Double_t params[6];
// NOTE(review): listing lines 186-187 are not shown on this page. Their content,
// for example whether the misclassification counters are reset, cannot be checked here.
// If they are not reset anywhere, repeated calls accumulate counts.
188
// Evaluate every stored sample, including those beyond the training cap. The inputs use
// the same scaling and clipping as the tree above.
// Correct pairs below fAnnCut count as "correct like wrong".
189 for (size_t i = 0; i < IM_correct.size(); i++) {
190 params[0] = IM_correct[i] / 0.1;
191 params[1] = OA_correct[i] / 30;
192 params[2] = Angle_correct[i] / 30;
193 params[3] = Z_correct[i] / 100;
194 params[4] = Mom1_correct[i] / 5;
195 params[5] = Mom2_correct[i] / 5;
196
197 if (params[0] > 1.0) params[0] = 1.0;
198 if (params[1] > 1.0) params[1] = 1.0;
199 if (params[2] > 1.0) params[2] = 1.0;
200 if (params[3] > 1.0) params[3] = 1.0;
201 if (params[4] > 1.0) params[4] = 1.0;
202 if (params[5] > 1.0) params[5] = 1.0;
203
204 Double_t netEval = network.Evaluate(0, params);
205 fhAnnOutput_correct->Fill(netEval);
206 if (netEval < fAnnCut) fNofCorrectLikeWrong++;
207 }
// Wrong pairs at or above fAnnCut count as "wrong like correct".
208 for (size_t i = 0; i < IM_wrong.size(); i++) {
209 params[0] = IM_wrong[i] / 0.1;
210 params[1] = OA_wrong[i] / 30;
211 params[2] = Angle_wrong[i] / 30;
212 params[3] = Z_wrong[i] / 100;
213 params[4] = Mom1_wrong[i] / 5;
214 params[5] = Mom2_wrong[i] / 5;
215
216 if (params[0] > 1.0) params[0] = 1.0;
217 if (params[1] > 1.0) params[1] = 1.0;
218 if (params[2] > 1.0) params[2] = 1.0;
219 if (params[3] > 1.0) params[3] = 1.0;
220 if (params[4] > 1.0) params[4] = 1.0;
221 if (params[5] > 1.0) params[5] = 1.0;
222
223 Double_t netEval = network.Evaluate(0, params);
224 fhAnnOutput_wrong->Fill(netEval);
225 if (netEval >= fAnnCut) fNofWrongLikeCorrect++;
226 }
227}
228
229
231{
232 cout << "nof correct pairs = " << IM_correct.size() << endl;
233 cout << "nof wrong pairs = " << IM_wrong.size() << endl;
234 cout << "wrong like correct = " << fNofWrongLikeCorrect
235 << ", wrong supp = " << (Double_t) IM_wrong.size() / fNofWrongLikeCorrect << endl;
236 cout << "Correct like wrong = " << fNofCorrectLikeWrong
237 << ", correct lost eff = " << 100. * (Double_t) fNofCorrectLikeWrong / IM_correct.size() << endl;
238
239 Double_t cumProbFake = 0.;
240 Double_t cumProbTrue = 0.;
241 Int_t nofTrue = (Int_t) fhAnnOutput_correct->GetEntries();
242 Int_t nofFake = (Int_t) fhAnnOutput_wrong->GetEntries();
243
244 for (Int_t i = 1; i <= fhAnnOutput_wrong->GetNbinsX(); i++) {
245 cumProbTrue += fhAnnOutput_correct->GetBinContent(i);
246 fhCumProb_correct->SetBinContent(i, 1. - (Double_t) cumProbTrue / nofTrue);
247
248 cumProbFake += fhAnnOutput_wrong->GetBinContent(i);
249 fhCumProb_wrong->SetBinContent(i, (Double_t) cumProbFake / nofFake);
250 }
251
252
253 TCanvas* c1 = new TCanvas("ann_correct_ann_output", "ann_correct_ann_output", 400, 400);
254 c1->SetTitle("ann_correct_ann_output");
255 fhAnnOutput_correct->Draw();
256
257 TCanvas* c2 = new TCanvas("ann_wrong_ann_output", "ann_wrong_ann_output", 400, 400);
258 c2->SetTitle("ann_wrong_ann_output");
259 fhAnnOutput_wrong->Draw();
260
261 TCanvas* c3 = new TCanvas("ann_correct_cum_prob", "ann_correct_cum_prob", 400, 400);
262 c3->SetTitle("ann_correct_cum_prob");
263 fhCumProb_correct->Draw();
264
265 TCanvas* c4 = new TCanvas("ann_wrong_cum_prob", "ann_wrong_cum_prob", 400, 400);
266 c4->SetTitle("ann_wrong_cum_prob");
267 fhCumProb_wrong->Draw();
268}
269
270
272{
273
274 fhAnnOutput_correct = new TH1D("fhAnnOutput_correct", "ANN output;ANN output;Counter", 100, -1.2, 1.2);
275 fHists.push_back(fhAnnOutput_correct);
276 fhAnnOutput_wrong = new TH1D("fhAnnOutput_wrong", "ANN output;ANN output;Counter", 100, -1.2, 1.2);
277 fHists.push_back(fhAnnOutput_wrong);
278
279 fhCumProb_correct = new TH1D("fhCumProb_correct", "ANN output;ANN output;Cumulative probability", 100, -1.2, 1.2);
280 fHists.push_back(fhCumProb_correct);
281 fhCumProb_wrong = new TH1D("fhCumProb_wrong", "ANN output;ANN output;Cumulative probability", 100, -1.2, 1.2);
282 fHists.push_back(fhCumProb_wrong);
283}
Helper functions for drawing 1D and 2D histograms and graphs.
vector< double > IM_wrong
vector< double > Mom2_correct
virtual ~CbmKresTrainAnn()
vector< double > OA_wrong
unsigned int fMaxNofTrainSamples
vector< double > Z_wrong
vector< double > Mom1_correct
vector< double > Angle_wrong
vector< double > Mom1_wrong
vector< double > OA_correct
vector< double > Mom2_wrong
vector< double > Angle_correct
vector< TH1 * > fHists
void Exec(int event, int IdForANN, double InvariantMass, double OpeningAngle, double PlaneAngle_last, double ZPos, TVector3 Momentum1, TVector3 Momentum2)
vector< double > Z_correct
vector< double > IM_correct
Hash for CbmL1LinkKey.