17#include <TClonesArray.h>
18#include <TStopwatch.h>
24#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
// Init: builds every ONNX Runtime object that Inference() later relies on.
// NOTE(review): this listing is a scraped fragment -- the leading numbers are the
// original line numbers, and some lines (braces, blank/comment lines) are missing.
26void CbmRichMCbmDenoiseCnn::Init()
// Runtime environment; logs at WARNING level and is tagged with this task's name.
28 fOrtEnv = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, GetName());
// Session options: one intra-op thread, one inter-op thread, sequential
// execution, and all graph optimizations enabled.
30 fOrtSessionOptions = std::make_unique<Ort::SessionOptions>();
33 fOrtSessionOptions->SetIntraOpNumThreads(1);
34 fOrtSessionOptions->SetInterOpNumThreads(1);
35 fOrtSessionOptions->SetExecutionMode(ORT_SEQUENTIAL);
36 fOrtSessionOptions->SetGraphOptimizationLevel(ORT_ENABLE_ALL);
// Load the model from fOnnxFilePath using the options above.
38 fOrtSession = std::make_unique<Ort::Session>(*fOrtEnv, fOnnxFilePath.c_str(), *fOrtSessionOptions);
// Empty RunOptions wrapper (built from nullptr). Session::Run therefore passes a
// null OrtRunOptions, which the ORT C API treats as default run options.
40 fOrtRunOptions = std::make_unique<Ort::RunOptions>(
nullptr);
// CPU memory description used when wrapping the input buffer below.
41 fOrtAllocatorInfo = std::make_unique<Ort::MemoryInfo>(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU));
// Input tensor over fInput's existing storage. This is the user-buffer overload
// of Ort::Value::CreateTensor, so nothing is copied and later writes into fInput
// are exactly what the model sees.
// NOTE(review): fInput must not be resized or reallocated after this point, or
// the tensor would point at freed memory.
42 fOrtInput = std::make_unique<Ort::Value>(Ort::Value::CreateTensor<float>(
43 *fOrtAllocatorInfo, fInput.data(), fInput.size(), fInputShape.data(), fInputShape.size()));
// Process: walks over the RICH hits and opens a time window at each seed hit.
// It collects the following hits whose time offset `dt` is below
// fTimeWindowLength. Windows holding at least fMinHitsInTimeWindow hits go to
// ProcessTimeWindow() for CNN noise classification.
// NOTE(review): nHits, seedHit, seedIndex, hit, dt and jHitIndex are defined on
// lines missing from this fragment. The window logic presumably assumes hits are
// visited in ascending time order -- confirm against the full source.
48void CbmRichMCbmDenoiseCnn::Process(
CbmEvent* event,
const TClonesArray* richHits)
// Reused for every seed; cleared at the top of each outer iteration.
51 std::vector<int> timeWindowHitIndices;
52 for (
int iHit = 0; iHit < nHits; iHit++) {
53 timeWindowHitIndices.clear();
// Null seed hits are skipped.
56 if (!seedHit)
continue;
// A seed hit that does not map onto the CNN input grid is reported. Per the
// message it is skipped; the skip statement itself is not shown here.
57 const auto gridIdxSeed = AddressToGridIndex(seedHit->
GetAddress());
58 if (gridIdxSeed < 0 || gridIdxSeed >=
static_cast<int>(fInput.size())) {
59 LOG(error) << GetName() <<
"::Process: Invalid grid index for seed hit: " << gridIdxSeed
60 <<
". Skipping seed hit.";
// The window always starts with the seed, so the count starts at 1.
63 timeWindowHitIndices.push_back(seedIndex);
64 int hitsInTimeWindow = 1;
// Candidate hits after the seed.
65 for (
int jHit = iHit + 1; jHit < nHits; jHit++) {
// Same grid-range check for candidates; out-of-range hits are reported and
// (per the message) skipped.
69 const auto gridIdxHit = AddressToGridIndex(hit->
GetAddress());
70 if (gridIdxHit < 0 || gridIdxHit >=
static_cast<int>(fInput.size())) {
71 LOG(error) << GetName() <<
"::Process: Invalid grid index for hit: " << gridIdxHit <<
". Skipping hit.";
// Only hits strictly closer in time than fTimeWindowLength join the window.
75 if (dt < fTimeWindowLength) {
77 timeWindowHitIndices.push_back(jHitIndex);
// Enough hits in the window: classify them.
84 if (hitsInTimeWindow >= fMinHitsInTimeWindow) {
85 ProcessTimeWindow(timeWindowHitIndices, seedHit->
GetTime(), richHits);
// Skip past the hits consumed by this window. Together with the loop's iHit++,
// the next seed is the hit after the last counted one.
// NOTE(review): this assumes the counted hits sit at consecutive positions
// starting at iHit. If an invalid-grid hit is skipped inside the window, a hit
// that was already classified may become the next seed -- verify.
86 iHit += hitsInTimeWindow - 1;
// ProcessTimeWindow: writes the hits of one time window into the CNN input
// buffer fInput, runs Inference(), and decides per hit whether it is noise.
// @param timeWindowHitIndices  hit indices collected by Process() (seed first);
//                              presumably used to look hits up in richHits.
// @param seedHitTime           seed hit time; hit times are encoded relative to it.
// @param richHits              hit container the indices refer to.
94void CbmRichMCbmDenoiseCnn::ProcessTimeWindow(
const std::vector<int>& timeWindowHitIndices,
const double& seedHitTime,
95 const TClonesArray* richHits)
// Grid cell of each window hit, in the same order as timeWindowHitIndices.
98 std::vector<int> gridIndices;
// Capacity hint only; larger windows still work.
99 gridIndices.reserve(50);
100 for (
const auto& hitIndex : timeWindowHitIndices) {
102 const auto gridIdx = AddressToGridIndex(hit->
GetAddress());
// Process() already drops hits with an out-of-range grid index, so reaching
// this branch means an inconsistency. It is logged at fatal level.
103 if (gridIdx < 0 || gridIdx >=
static_cast<int>(fInput.size())) {
106 LOG(fatal) << GetName() <<
"::ProcessTimeWindow: Invalid grid index: " << gridIdx;
109 gridIndices.push_back(gridIdx);
// Store the hit time relative to the seed, shifted by +1 and divided by
// (fTimeWindowLength + 1).
// NOTE(review): the +1 shift presumably keeps occupied cells non-zero so that 0
// can mean "empty pixel" -- confirm.
111 fInput[gridIdx] =
static_cast<float>((hit->
GetTime() - seedHitTime + 1.0) / (fTimeWindowLength + 1.0));
// One network score per entry of gridIndices.
114 const auto output = Inference(gridIndices);
// output[i] belongs to the i-th window hit, because gridIndices was filled in
// the same order. This holds as long as the fatal branch above never falls
// through. A score below fClassificationThreshold marks the hit as noise.
// NOTE(review): where isNoise is stored (presumably CbmRichHit::SetIsNoiseNN)
// is on lines not shown here.
116 for (std::size_t i = 0; i < timeWindowHitIndices.size(); i++) {
118 bool isNoise = output[i] < fClassificationThreshold;
// Inference: runs the ONNX session on the current contents of fInput (through
// fOrtInput, created in Init) and returns the network output at the requested
// grid cells.
// @param gridIndices  flat grid indices to read from the model output.
// @return             one float per entry of gridIndices, in the same order.
// NOTE(review): there is no bounds check on the gathered indices. This assumes
// the first output tensor uses the same flat layout and size as fInput.
// NOTE(review): any reset of fInput after inference is not visible in this
// fragment. Without one, cells written for an earlier window would still be set
// when later windows are evaluated -- confirm.
123std::vector<float> CbmRichMCbmDenoiseCnn::Inference(
const std::vector<int>& gridIndices)
// One input tensor and one output, addressed by fInputNames / fOutputNames.
125 auto output = fOrtSession->Run(*fOrtRunOptions.get(), fInputNames.data(), fOrtInput.get(), 1, fOutputNames.data(), 1);
// Raw pointer into the first output tensor; it is only valid while `output` lives.
126 float* intarr = output.front().GetTensorMutableData<
float>();
128 std::vector<float> out(gridIndices.size());
// Copy the scores at the requested cells before `output` goes out of scope.
129 std::transform(gridIndices.begin(), gridIndices.end(), out.begin(),
130 [intarr](
int gridIdx) { return intarr[gridIdx]; });
// AddressToGridIndex: maps a RICH pixel address to a flat index into the CNN
// input grid, which is 32 cells wide.
// The PMT id packs the PMT column in bits 4-7 and the PMT row in bits 0-3.
// Each PMT contributes a block 8 pixels wide: fPixelId % 8 is the column inside
// the PMT and fPixelId / 8 is the row (an 8x8 block, assuming pixel ids 0-63).
// @return 32 * globalIndY + globalIndX.
// NOTE(review): pixel_data is dereferenced below. Any null check would be on the
// lines missing from this fragment -- confirm.
// NOTE(review): a 32-wide grid only holds PMT columns 0-3. For pmtIndX >= 4,
// globalIndX exceeds 31 and the index wraps into a later row instead of being
// rejected; callers only check the result against fInput.size().
134int CbmRichMCbmDenoiseCnn::AddressToGridIndex(
int address)
136 CbmRichPixelData* pixel_data = fCbmRichGeoHandler->GetPixelDataByAddress(address);
139 int pmtUID = pixel_data->
fPmtId;
// Bits 4-7: PMT column; bits 0-3: PMT row.
140 int pmtIndX = (pmtUID >> 4) & 0xF;
141 int pmtIndY = pmtUID & 0xF;
// Global pixel coordinates: PMT offset (8 pixels per PMT) plus position inside the PMT.
144 int globalIndX = 8 * pmtIndX + pixel_data->
fPixelId % 8;
145 int globalIndY = 8 * pmtIndY + pixel_data->
fPixelId / 8;
// Row-major flattening over a 32-column grid.
147 return 32 * globalIndY + globalIndX;
Class characterising one event by a collection of links (indices) to data objects,...
int32_t GetAddress() const
static CbmRichGeoHandler & GetInstance()
void SetIsNoiseNN(bool isNoiseNN)