17#include <TClonesArray.h>
18#include <TStopwatch.h>
24#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
/// Set up the ONNX Runtime environment, load the CNN model from fOnnxFilePath,
/// and create the (reused) input tensor that wraps fInput's storage.
26void CbmRichMCbmDenoiseCnn::Init()
// ORT environment; log level WARNING, logger id taken from GetName().
28 fOrtEnv = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, GetName());
30 fOrtSessionOptions = std::make_unique<Ort::SessionOptions>();
// Single-threaded, sequential execution with all graph optimizations enabled.
33 fOrtSessionOptions->SetIntraOpNumThreads(1);
34 fOrtSessionOptions->SetInterOpNumThreads(1);
35 fOrtSessionOptions->SetExecutionMode(ORT_SEQUENTIAL);
36 fOrtSessionOptions->SetGraphOptimizationLevel(ORT_ENABLE_ALL);
// Load the model file into an inference session.
38 fOrtSession = std::make_unique<Ort::Session>(*fOrtEnv, fOnnxFilePath.c_str(), *fOrtSessionOptions);
// Default run options (no tag / settings).
40 fOrtRunOptions = std::make_unique<Ort::RunOptions>(
nullptr);
41 fOrtAllocatorInfo = std::make_unique<Ort::MemoryInfo>(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU));
// The input tensor is created over fInput's buffer (CreateTensor with a user
// buffer does not copy), so later writes to fInput are what the model sees.
// fInput must therefore not be resized/reallocated after this call.
42 fOrtInput = std::make_unique<Ort::Value>(Ort::Value::CreateTensor<float>(
43 *fOrtAllocatorInfo, fInput.data(), fInput.size(), fInputShape.data(), fInputShape.size()));
/// Scan the RICH hits (of `event` if given, otherwise all of `richHits`) for
/// time windows seeded by each hit, and run the CNN on every window that
/// contains at least fMinHitsInTimeWindow hits.
48void CbmRichMCbmDenoiseCnn::Process(
CbmEvent* event,
const TClonesArray* richHits)
// Number of hits to scan: event-local count, or the whole array when no event.
// NOTE(review): richHits is dereferenced without a null check when event == nullptr.
50 int nHits =
event ?
static_cast<int>(
event->GetNofData(
ECbmDataType::kRichHit)) : richHits->GetEntriesFast();
// Reused buffer of hit indices belonging to the current time window.
51 std::vector<int> timeWindowHitIndices;
52 for (
int i = 0; i < nHits; i++) {
53 timeWindowHitIndices.clear();
// Map the loop position to an index into richHits (via the event when present).
54 int seedIdx =
event ?
static_cast<int>(
event->GetIndex(
ECbmDataType::kRichHit,
static_cast<uint32_t
>(i))) : i;
// seedHit is defined on a line not visible here (presumably the hit at seedIdx).
56 if (!seedHit)
continue;
57 timeWindowHitIndices.push_back(seedIdx);
58 int hitsInTimeWindow = 1;
// Collect following hits whose time difference (dt, computed on hidden lines)
// to the seed is below fTimeWindowLength.
59 for (
int j = i + 1; j < nHits; j++) {
64 if (dt < fTimeWindowLength) {
66 timeWindowHitIndices.push_back(hitIdx);
73 if (hitsInTimeWindow >= fMinHitsInTimeWindow) {
74 ProcessTimeWindow(timeWindowHitIndices, seedHit->
GetTime(), richHits);
// Skip the hits consumed by this window (the loop's i++ adds the last one).
// NOTE(review): this assumes the window hits are exactly the next
// hitsInTimeWindow-1 positions, i.e. hits are time-ordered and the inner loop
// stops at the first out-of-window hit — confirm against the hidden lines.
75 i += hitsInTimeWindow - 1;
/// Fill the CNN input grid with the hits of one time window, run inference,
/// and classify each hit as noise when its score is below fClassificationThreshold.
/// @param timeWindowHitIndices indices into richHits of the hits in the window
/// @param seedHitTime          time of the seed hit; hit times are encoded relative to it
/// @param richHits             the RICH hit array
83void CbmRichMCbmDenoiseCnn::ProcessTimeWindow(
const std::vector<int>& timeWindowHitIndices,
const double& seedHitTime,
84 const TClonesArray* richHits)
// NOTE(review): no reset of fInput is visible here (lines 85-86 are hidden);
// confirm fInput is zeroed per window, otherwise stale pixels from earlier
// windows are fed to the model.
87 std::vector<int> gridIndices;
88 gridIndices.reserve(50);
89 for (
const auto& hitIdx : timeWindowHitIndices) {
// `hit` is obtained from richHits on a hidden line (presumably at hitIdx).
91 const auto gridIdx = AddressToGridIndex(hit->
GetAddress());
92 if (gridIdx < 0 || gridIdx >=
static_cast<int>(fInput.size())) {
93 LOG(error) << GetName() <<
"::ProcessTimeWindow: Invalid grid index: " << gridIdx;
96 gridIndices.push_back(gridIdx);
// Encode the hit time relative to the seed as (dt + 1) / (window + 1). For
// 0 <= dt < fTimeWindowLength this is strictly positive, presumably so that
// 0 marks an empty pixel — TODO confirm. Two hits on the same pixel overwrite
// each other; the last one wins.
98 fInput[gridIdx] =
static_cast<float>((hit->
GetTime() - seedHitTime + 1.0) / (fTimeWindowLength + 1.0));
// One score per entry of gridIndices.
101 const auto output = Inference(gridIndices);
// NOTE(review): output has gridIndices.size() entries but is indexed by the
// position in timeWindowHitIndices. If the invalid-index branch above skips a
// hit (hidden lines 94-95), this reads past the end of `output` and assigns
// scores to the wrong hits — iterate over aligned (hit, gridIdx) pairs instead.
103 for (std::size_t i = 0; i < timeWindowHitIndices.size(); i++) {
105 bool isNoise = output[i] < fClassificationThreshold;
/// Run the CNN on the current contents of fInput and return the model output
/// at each of the given flat grid indices (same order as gridIndices).
110std::vector<float> CbmRichMCbmDenoiseCnn::Inference(
const std::vector<int>& gridIndices)
// One input tensor (fOrtInput, backed by fInput) and one requested output.
112 auto output = fOrtSession->Run(*fOrtRunOptions.get(), fInputNames.data(), fOrtInput.get(), 1, fOutputNames.data(), 1);
// Raw float view of the first output; valid only while `output` is alive,
// so the scores are copied out below before this function returns.
113 float* intarr = output.front().GetTensorMutableData<
float>();
// Gather the scores of the pixels that were filled. This assumes the output
// tensor uses the same flat layout as the input grid — TODO confirm.
115 std::vector<float> out(gridIndices.size());
116 std::transform(gridIndices.begin(), gridIndices.end(), out.begin(),
117 [intarr](
int gridIdx) { return intarr[gridIdx]; });
/// Convert a RICH pixel address to a flat index into the CNN input grid.
/// Each PMT contributes an 8x8 block of pixels; the grid row stride is 32.
121int CbmRichMCbmDenoiseCnn::AddressToGridIndex(
int address)
// NOTE(review): a null check on pixel_data may live on hidden lines 124-125;
// if not, an unknown address dereferences a null pointer below.
123 CbmRichPixelData* pixel_data = fCbmRichDigiMapManager->GetPixelDataByAddress(address);
// PMT position within the PMT array: bits 4-7 of the PMT id give X, bits 0-3 give Y.
126 int pmtUID = pixel_data->
fPmtId;
127 int pmtIndX = (pmtUID >> 4) & 0xF;
128 int pmtIndY = pmtUID & 0xF;
// Pixel position inside the PMT: fPixelId % 8 is the column, fPixelId / 8 is the row.
131 int globalIndX = 8 * pmtIndX + pixel_data->
fPixelId % 8;
132 int globalIndY = 8 * pmtIndY + pixel_data->
fPixelId / 8;
// Row-major flattening with 32 columns (4 PMTs x 8 pixels).
// NOTE(review): pmtIndX > 3 gives globalIndX >= 32, which aliases into the next
// row; the caller's bounds check against fInput.size() does not catch this.
134 return 32 * globalIndY + globalIndX;
Class characterising one event by a collection of links (indices) to data objects,...
int32_t GetAddress() const
static CbmRichDigiMapManager & GetInstance()
Return Instance of CbmRichDigiMapManager.
void SetIsNoiseNN(bool isNoiseNN)