CbmRoot
Loading...
Searching...
No Matches
CbmRichMCbmDenoiseCnn.cxx
Go to the documentation of this file.
1/* Copyright (C) 2024-2025 UGiessen, Giessen
2 SPDX-License-Identifier: GPL-3.0-only
3 Authors: Martin Beyer [committer] */
4
5#if HAVE_ONNXRUNTIME
6
8
9#include "CbmDigiManager.h"
10#include "CbmEvent.h"
11#include "CbmRichDetectorData.h"
12#include "CbmRichGeoHandler.h"
13#include "CbmRichHit.h"
14
15#include <Logger.h>
16
17#include <TClonesArray.h>
18#include <TStopwatch.h>
19
20#include <iostream>
21#include <sstream>
22#include <vector>
23
24#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
25
26void CbmRichMCbmDenoiseCnn::Init()
27{
28 fOrtEnv = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, GetName());
29
30 fOrtSessionOptions = std::make_unique<Ort::SessionOptions>();
31 // Thread numbers need to be explicitly set. With the current Ort version it can't detect
32 // the numbers on the cluster, should be fixed in a later version.
33 fOrtSessionOptions->SetIntraOpNumThreads(1);
34 fOrtSessionOptions->SetInterOpNumThreads(1);
35 fOrtSessionOptions->SetExecutionMode(ORT_SEQUENTIAL);
36 fOrtSessionOptions->SetGraphOptimizationLevel(ORT_ENABLE_ALL);
37
38 fOrtSession = std::make_unique<Ort::Session>(*fOrtEnv, fOnnxFilePath.c_str(), *fOrtSessionOptions);
39
40 fOrtRunOptions = std::make_unique<Ort::RunOptions>(nullptr);
41 fOrtAllocatorInfo = std::make_unique<Ort::MemoryInfo>(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU));
42 fOrtInput = std::make_unique<Ort::Value>(Ort::Value::CreateTensor<float>(
43 *fOrtAllocatorInfo, fInput.data(), fInput.size(), fInputShape.data(), fInputShape.size()));
44
45 fCbmRichGeoHandler = &CbmRichGeoHandler::GetInstance();
46}
47
48void CbmRichMCbmDenoiseCnn::Process(CbmEvent* event, const TClonesArray* richHits)
49{
50 int nHits = event ? event->GetNofData(ECbmDataType::kRichHit) : richHits->GetEntriesFast();
51 std::vector<int> timeWindowHitIndices;
52 for (int iHit = 0; iHit < nHits; iHit++) { // Sliding time window loop
53 timeWindowHitIndices.clear();
54 int seedIndex = event ? event->GetIndex(ECbmDataType::kRichHit, iHit) : iHit;
55 CbmRichHit* seedHit = static_cast<CbmRichHit*>(richHits->At(seedIndex));
56 if (!seedHit) continue;
57 const auto gridIdxSeed = AddressToGridIndex(seedHit->GetAddress());
58 if (gridIdxSeed < 0 || gridIdxSeed >= static_cast<int>(fInput.size())) {
59 LOG(error) << GetName() << "::Process: Invalid grid index for seed hit: " << gridIdxSeed
60 << ". Skipping seed hit.";
61 continue;
62 }
63 timeWindowHitIndices.push_back(seedIndex);
64 int hitsInTimeWindow = 1;
65 for (int jHit = iHit + 1; jHit < nHits; jHit++) { // Search for hits in time window
66 int jHitIndex = event ? event->GetIndex(ECbmDataType::kRichHit, jHit) : jHit;
67 CbmRichHit* hit = static_cast<CbmRichHit*>(richHits->At(jHitIndex));
68 if (!hit) continue;
69 const auto gridIdxHit = AddressToGridIndex(hit->GetAddress());
70 if (gridIdxHit < 0 || gridIdxHit >= static_cast<int>(fInput.size())) {
71 LOG(error) << GetName() << "::Process: Invalid grid index for hit: " << gridIdxHit << ". Skipping hit.";
72 continue;
73 }
74 double dt = hit->GetTime() - seedHit->GetTime();
75 if (dt < fTimeWindowLength) {
76 hitsInTimeWindow++;
77 timeWindowHitIndices.push_back(jHitIndex);
78 }
79 else {
80 break;
81 }
82 }
83
84 if (hitsInTimeWindow >= fMinHitsInTimeWindow) {
85 ProcessTimeWindow(timeWindowHitIndices, seedHit->GetTime(), richHits);
86 iHit += hitsInTimeWindow - 1; // Move to last hit inside time window
87 }
88 else {
89 seedHit->SetIsNoiseNN(true);
90 }
91 }
92}
93
94void CbmRichMCbmDenoiseCnn::ProcessTimeWindow(const std::vector<int>& timeWindowHitIndices, const double& seedHitTime,
95 const TClonesArray* richHits)
96{
97 fInput = {}; // Reset all input values to 0.0
98 std::vector<int> gridIndices;
99 gridIndices.reserve(50);
100 for (const auto& hitIndex : timeWindowHitIndices) {
101 CbmRichHit* hit = static_cast<CbmRichHit*>(richHits->At(hitIndex));
102 const auto gridIdx = AddressToGridIndex(hit->GetAddress());
103 if (gridIdx < 0 || gridIdx >= static_cast<int>(fInput.size())) {
104 // This should never happen since those hits are already filtered.
105 // Keep it included for safety, throwing fatal.
106 LOG(fatal) << GetName() << "::ProcessTimeWindow: Invalid grid index: " << gridIdx;
107 continue;
108 }
109 gridIndices.push_back(gridIdx);
110 // Shift time by 1ns to distinguish empty pixels from seed hit time
111 fInput[gridIdx] = static_cast<float>((hit->GetTime() - seedHitTime + 1.0) / (fTimeWindowLength + 1.0));
112 }
113
114 const auto output = Inference(gridIndices);
115
116 for (std::size_t i = 0; i < timeWindowHitIndices.size(); i++) {
117 CbmRichHit* hit = static_cast<CbmRichHit*>(richHits->At(timeWindowHitIndices[i]));
118 bool isNoise = output[i] < fClassificationThreshold;
119 hit->SetIsNoiseNN(isNoise);
120 }
121}
122
123std::vector<float> CbmRichMCbmDenoiseCnn::Inference(const std::vector<int>& gridIndices)
124{
125 auto output = fOrtSession->Run(*fOrtRunOptions.get(), fInputNames.data(), fOrtInput.get(), 1, fOutputNames.data(), 1);
126 float* intarr = output.front().GetTensorMutableData<float>();
127
128 std::vector<float> out(gridIndices.size());
129 std::transform(gridIndices.begin(), gridIndices.end(), out.begin(),
130 [intarr](int gridIdx) { return intarr[gridIdx]; });
131 return out;
132}
133
134int CbmRichMCbmDenoiseCnn::AddressToGridIndex(int address)
135{
136 CbmRichPixelData* pixel_data = fCbmRichGeoHandler->GetPixelDataByAddress(address);
137
138 // Calculate local X [0,7],Y [0,7] indices of one MAPMT
139 int pmtUID = pixel_data->fPmtId;
140 int pmtIndX = (pmtUID >> 4) & 0xF; // index ascending from -x to x
141 int pmtIndY = pmtUID & 0xF; // index ascending from y to -y
142
143 // Calculate global X [0,31],Y [0,72] indices
144 int globalIndX = 8 * pmtIndX + pixel_data->fPixelId % 8;
145 int globalIndY = 8 * pmtIndY + pixel_data->fPixelId / 8;
146
147 return 32 * globalIndY + globalIndX;
148}
149
150#endif // HAVE_ONNXRUNTIME
Class characterising one event by a collection of links (indices) to data objects,...
Definition CbmEvent.h:34
double GetTime() const
Definition CbmHit.h:79
int32_t GetAddress() const
Definition CbmHit.h:77
static CbmRichGeoHandler & GetInstance()
void SetIsNoiseNN(bool isNoiseNN)
Definition CbmRichHit.h:61