CbmRoot
Loading...
Searching...
No Matches
CbmRichMCbmDenoiseCnn.cxx
Go to the documentation of this file.
1/* Copyright (C) 2024 UGiessen, Giessen
2 SPDX-License-Identifier: GPL-3.0-only
3 Authors: Martin Beyer [committer] */
4
5#if HAVE_ONNXRUNTIME
6
8
9#include "CbmDigiManager.h"
10#include "CbmEvent.h"
11#include "CbmRichDetectorData.h"
13#include "CbmRichHit.h"
14
15#include <Logger.h>
16
17#include <TClonesArray.h>
18#include <TStopwatch.h>
19
20#include <iostream>
21#include <sstream>
22#include <vector>
23
24#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
25
26void CbmRichMCbmDenoiseCnn::Init()
27{
28 fOrtEnv = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, GetName());
29
30 fOrtSessionOptions = std::make_unique<Ort::SessionOptions>();
31 // Thread numbers need to be explicitly set. With the current Ort version it can't detect
32 // the numbers on the cluster, should be fixed in a later version.
33 fOrtSessionOptions->SetIntraOpNumThreads(1);
34 fOrtSessionOptions->SetInterOpNumThreads(1);
35 fOrtSessionOptions->SetExecutionMode(ORT_SEQUENTIAL);
36 fOrtSessionOptions->SetGraphOptimizationLevel(ORT_ENABLE_ALL);
37
38 fOrtSession = std::make_unique<Ort::Session>(*fOrtEnv, fOnnxFilePath.c_str(), *fOrtSessionOptions);
39
40 fOrtRunOptions = std::make_unique<Ort::RunOptions>(nullptr);
41 fOrtAllocatorInfo = std::make_unique<Ort::MemoryInfo>(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU));
42 fOrtInput = std::make_unique<Ort::Value>(Ort::Value::CreateTensor<float>(
43 *fOrtAllocatorInfo, fInput.data(), fInput.size(), fInputShape.data(), fInputShape.size()));
44
45 fCbmRichDigiMapManager = &CbmRichDigiMapManager::GetInstance();
46}
47
48void CbmRichMCbmDenoiseCnn::Process(CbmEvent* event, const TClonesArray* richHits)
49{
50 int nHits = event ? static_cast<int>(event->GetNofData(ECbmDataType::kRichHit)) : richHits->GetEntriesFast();
51 std::vector<int> timeWindowHitIndices;
52 for (int i = 0; i < nHits; i++) { // Sliding time window loop
53 timeWindowHitIndices.clear();
54 int seedIdx = event ? static_cast<int>(event->GetIndex(ECbmDataType::kRichHit, static_cast<uint32_t>(i))) : i;
55 CbmRichHit* seedHit = static_cast<CbmRichHit*>(richHits->At(seedIdx));
56 if (!seedHit) continue;
57 timeWindowHitIndices.push_back(seedIdx);
58 int hitsInTimeWindow = 1;
59 for (int j = i + 1; j < nHits; j++) { // Search for hits in time window
60 int hitIdx = event ? static_cast<int>(event->GetIndex(ECbmDataType::kRichHit, j)) : j;
61 CbmRichHit* hit = static_cast<CbmRichHit*>(richHits->At(hitIdx));
62 if (!hit) continue;
63 double dt = hit->GetTime() - seedHit->GetTime();
64 if (dt < fTimeWindowLength) {
65 hitsInTimeWindow++;
66 timeWindowHitIndices.push_back(hitIdx);
67 }
68 else {
69 break;
70 }
71 }
72
73 if (hitsInTimeWindow >= fMinHitsInTimeWindow) {
74 ProcessTimeWindow(timeWindowHitIndices, seedHit->GetTime(), richHits);
75 i += hitsInTimeWindow - 1; // Move to last hit inside time window
76 }
77 else {
78 seedHit->SetIsNoiseNN(true);
79 }
80 }
81}
82
83void CbmRichMCbmDenoiseCnn::ProcessTimeWindow(const std::vector<int>& timeWindowHitIndices, const double& seedHitTime,
84 const TClonesArray* richHits)
85{
86 fInput = {}; // Reset all input values to 0.0
87 std::vector<int> gridIndices;
88 gridIndices.reserve(50);
89 for (const auto& hitIdx : timeWindowHitIndices) {
90 CbmRichHit* hit = static_cast<CbmRichHit*>(richHits->At(hitIdx));
91 const auto gridIdx = AddressToGridIndex(hit->GetAddress());
92 if (gridIdx < 0 || gridIdx >= static_cast<int>(fInput.size())) {
93 LOG(error) << GetName() << "::ProcessTimeWindow: Invalid grid index: " << gridIdx;
94 continue;
95 }
96 gridIndices.push_back(gridIdx);
97 // Shift time by 1ns to distinguish empty pixels from seed hit time
98 fInput[gridIdx] = static_cast<float>((hit->GetTime() - seedHitTime + 1.0) / (fTimeWindowLength + 1.0));
99 }
100
101 const auto output = Inference(gridIndices);
102
103 for (std::size_t i = 0; i < timeWindowHitIndices.size(); i++) {
104 CbmRichHit* hit = static_cast<CbmRichHit*>(richHits->At(timeWindowHitIndices[i]));
105 bool isNoise = output[i] < fClassificationThreshold;
106 hit->SetIsNoiseNN(isNoise);
107 }
108}
109
110std::vector<float> CbmRichMCbmDenoiseCnn::Inference(const std::vector<int>& gridIndices)
111{
112 auto output = fOrtSession->Run(*fOrtRunOptions.get(), fInputNames.data(), fOrtInput.get(), 1, fOutputNames.data(), 1);
113 float* intarr = output.front().GetTensorMutableData<float>();
114
115 std::vector<float> out(gridIndices.size());
116 std::transform(gridIndices.begin(), gridIndices.end(), out.begin(),
117 [intarr](int gridIdx) { return intarr[gridIdx]; });
118 return out;
119}
120
121int CbmRichMCbmDenoiseCnn::AddressToGridIndex(int address)
122{
123 CbmRichPixelData* pixel_data = fCbmRichDigiMapManager->GetPixelDataByAddress(address);
124
125 // Calculate local X [0,7],Y [0,7] indices of one MAPMT
126 int pmtUID = pixel_data->fPmtId;
127 int pmtIndX = (pmtUID >> 4) & 0xF; // index ascending from -x to x
128 int pmtIndY = pmtUID & 0xF; // index ascending from y to -y
129
130 // Calculate global X [0,31],Y [0,72] indices
131 int globalIndX = 8 * pmtIndX + pixel_data->fPixelId % 8;
132 int globalIndY = 8 * pmtIndY + pixel_data->fPixelId / 8;
133
134 return 32 * globalIndY + globalIndX;
135}
136
137#endif // HAVE_ONNXRUNTIME
Class characterising one event by a collection of links (indices) to data objects,...
Definition CbmEvent.h:34
double GetTime() const
Definition CbmHit.h:76
int32_t GetAddress() const
Definition CbmHit.h:74
static CbmRichDigiMapManager & GetInstance()
Return instance of CbmRichDigiMapManager.
void SetIsNoiseNN(bool isNoiseNN)
Definition CbmRichHit.h:61