1- //  Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ //  Copyright 2023-2024 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22// 
33//  Redistribution and use in source and binary forms, with or without
44//  modification, are permitted provided that the following conditions
3232
3333namespace  triton  { namespace  backend  { namespace  python  {
3434
35- Metric::Metric (const  std::string& labels, void * metric_family_address)
36-     : labels_(labels), operation_value_(0 ), metric_address_(nullptr ),
37-       metric_family_address_ (metric_family_address), is_cleared_(false )
35+ Metric::Metric (
36+     const  std::string& labels, std::optional<const  std::vector<double >> buckets,
37+     void * metric_family_address)
38+     : labels_(labels), buckets_(buckets), operation_value_(0 ),
39+       metric_address_ (nullptr ), metric_family_address_(metric_family_address),
40+       is_cleared_(false )
3841{
3942#ifdef  TRITON_PB_STUB
4043  SendCreateMetricRequest ();
@@ -62,6 +65,20 @@ Metric::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
6265  custom_metric_shm_ptr_->metric_family_address  = metric_family_address_;
6366  custom_metric_shm_ptr_->metric_address  = metric_address_;
6467
68+   //  Histogram specific case
69+   if  (buckets_.has_value ()) {
70+     auto  buckets_size = buckets_.value ().size () * sizeof (double );
71+     std::unique_ptr<PbMemory> buckets_shm = PbMemory::Create (
72+         shm_pool, TRITONSERVER_MemoryType::TRITONSERVER_MEMORY_CPU, 0 ,
73+         buckets_size, reinterpret_cast <char *>(buckets_.value ().data ()),
74+         false  /*  copy_gpu */  );
75+     custom_metric_shm_ptr_->buckets_shm_handle  = buckets_shm->ShmHandle ();
76+     buckets_shm_ = std::move (buckets_shm);
77+   } else  {
78+     custom_metric_shm_ptr_->buckets_shm_handle  = 0 ;
79+     buckets_shm_ = nullptr ;
80+   }
81+ 
6582  //  Save the references to shared memory.
6683  custom_metric_shm_ = std::move (custom_metric_shm);
6784  labels_shm_ = std::move (labels_shm);
@@ -80,17 +97,40 @@ Metric::LoadFromSharedMemory(
8097  std::unique_ptr<PbString> labels_shm = PbString::LoadFromSharedMemory (
8198      shm_pool, custom_metric_shm_ptr->labels_shm_handle );
8299
83-   return  std::unique_ptr<Metric>(new  Metric (custom_metric_shm, labels_shm));
100+   std::unique_ptr<PbMemory> buckets_shm = nullptr ;
101+   if  (custom_metric_shm_ptr->buckets_shm_handle  != 0 ) {
102+     buckets_shm = PbMemory::LoadFromSharedMemory (
103+         shm_pool, custom_metric_shm_ptr->buckets_shm_handle ,
104+         false  /*  open_cuda_handle */  );
105+   }
106+ 
107+   return  std::unique_ptr<Metric>(
108+       new  Metric (custom_metric_shm, labels_shm, buckets_shm));
84109}
85110
86111Metric::Metric (
87112    AllocatedSharedMemory<MetricShm>& custom_metric_shm,
88-     std::unique_ptr<PbString>& labels_shm)
113+     std::unique_ptr<PbString>& labels_shm,
114+     std::unique_ptr<PbMemory>& buckets_shm)
89115    : custom_metric_shm_(std::move(custom_metric_shm)),
90-       labels_shm_ (std::move(labels_shm))
116+       labels_shm_ (std::move(labels_shm)), buckets_shm_(std::move(buckets_shm)) 
91117{
92118  custom_metric_shm_ptr_ = custom_metric_shm_.data_ .get ();
119+ 
120+   //  FIXME: This constructor is called during each
121+   //  set/increment/observe/get_value call. It only needs the pointers.
93122  labels_ = labels_shm_->String ();
123+   if  (buckets_shm_ != nullptr ) {  //  Histogram
124+     size_t  bucket_size = buckets_shm_->ByteSize () / sizeof (double );
125+     std::vector<double > buckets;
126+     buckets.reserve (bucket_size);
127+     for  (size_t  i = 0 ; i < bucket_size; ++i) {
128+       buckets.emplace_back (
129+           reinterpret_cast <double *>(buckets_shm_->DataPtr ())[i]);
130+     }
131+     buckets_ = std::move (buckets);
132+   }
133+ 
94134  operation_value_ = custom_metric_shm_ptr_->operation_value ;
95135  metric_family_address_ = custom_metric_shm_ptr_->metric_family_address ;
96136  metric_address_ = custom_metric_shm_ptr_->metric_address ;
@@ -161,6 +201,24 @@ Metric::SendSetValueRequest(const double& value)
161201  }
162202}
163203
204+ void 
205+ Metric::SendObserveRequest (const  double & value)
206+ {
207+   try  {
208+     CheckIfCleared ();
209+     std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance ();
210+     operation_value_ = value;
211+     SaveToSharedMemory (stub->ShmPool ());
212+     AllocatedSharedMemory<CustomMetricsMessage> custom_metrics_shm;
213+     stub->SendMessage <CustomMetricsMessage>(
214+         custom_metrics_shm, PYTHONSTUB_MetricRequestObserve, shm_handle_);
215+   }
216+   catch  (const  PythonBackendException& pb_exception) {
217+     throw  PythonBackendException (
218+         " Failed to observe metric value: "   + std::string (pb_exception.what ()));
219+   }
220+ }
221+ 
164222double 
165223Metric::SendGetValueRequest ()
166224{
@@ -222,14 +280,35 @@ Metric::InitializeTritonMetric()
222280{
223281  std::vector<const  TRITONSERVER_Parameter*> labels_params;
224282  ParseLabels (labels_params, labels_);
283+   TRITONSERVER_MetricKind kind;
284+   THROW_IF_TRITON_ERROR (TRITONSERVER_GetMetricFamilyKind (
285+       reinterpret_cast <TRITONSERVER_MetricFamily*>(metric_family_address_),
286+       &kind));
287+   TRITONSERVER_MetricArgs* args = nullptr ;
288+   switch  (kind) {
289+     case  TRITONSERVER_METRIC_KIND_COUNTER:
290+     case  TRITONSERVER_METRIC_KIND_GAUGE:
291+       break ;
292+     case  TRITONSERVER_METRIC_KIND_HISTOGRAM: {
293+       const  std::vector<double >& buckets = buckets_.value ();
294+       THROW_IF_TRITON_ERROR (TRITONSERVER_MetricArgsNew (&args));
295+       THROW_IF_TRITON_ERROR (TRITONSERVER_MetricArgsSetHistogram (
296+           args, buckets.data (), buckets.size ()));
297+       break ;
298+     }
299+     default :
300+       break ;
301+   }
302+ 
225303  TRITONSERVER_Metric* triton_metric = nullptr ;
226-   THROW_IF_TRITON_ERROR (TRITONSERVER_MetricNew (
304+   THROW_IF_TRITON_ERROR (TRITONSERVER_MetricNewWithArgs (
227305      &triton_metric,
228306      reinterpret_cast <TRITONSERVER_MetricFamily*>(metric_family_address_),
229-       labels_params.data (), labels_params.size ()));
307+       labels_params.data (), labels_params.size (), args ));
230308  for  (const  auto  label : labels_params) {
231309    TRITONSERVER_ParameterDelete (const_cast <TRITONSERVER_Parameter*>(label));
232310  }
311+   THROW_IF_TRITON_ERROR (TRITONSERVER_MetricArgsDelete (args));
233312  return  reinterpret_cast <void *>(triton_metric);
234313}
235314
@@ -262,6 +341,8 @@ Metric::HandleMetricOperation(
262341    Increment (operation_value_);
263342  } else  if  (command_type == PYTHONSTUB_MetricRequestSet) {
264343    SetValue (operation_value_);
344+   } else  if  (command_type == PYTHONSTUB_MetricRequestObserve) {
345+     Observe (operation_value_);
265346  } else  {
266347    throw  PythonBackendException (" Unknown metric operation"  );
267348  }
@@ -281,6 +362,13 @@ Metric::SetValue(const double& value)
281362  THROW_IF_TRITON_ERROR (TRITONSERVER_MetricSet (triton_metric, value));
282363}
283364
365+ void 
366+ Metric::Observe (const  double & value)
367+ {
368+   auto  triton_metric = reinterpret_cast <TRITONSERVER_Metric*>(metric_address_);
369+   THROW_IF_TRITON_ERROR (TRITONSERVER_MetricObserve (triton_metric, value));
370+ }
371+ 
284372double 
285373Metric::GetValue ()
286374{
0 commit comments