1- // Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ // Copyright 2023-2024 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22//
33// Redistribution and use in source and binary forms, with or without
44// modification, are permitted provided that the following conditions
3232
3333namespace triton { namespace backend { namespace python {
3434
35- Metric::Metric (const std::string& labels, void * metric_family_address)
36- : labels_(labels), operation_value_(0 ), metric_address_(nullptr ),
37- metric_family_address_ (metric_family_address), is_cleared_(false )
35+ Metric::Metric (
36+ const std::string& labels, std::optional<const std::vector<double >> buckets,
37+ void * metric_family_address)
38+ : labels_(labels), buckets_(buckets), operation_value_(0 ),
39+ metric_address_ (nullptr ), metric_family_address_(metric_family_address),
40+ is_cleared_(false )
3841{
3942#ifdef TRITON_PB_STUB
4043 SendCreateMetricRequest ();
@@ -62,6 +65,20 @@ Metric::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
6265 custom_metric_shm_ptr_->metric_family_address = metric_family_address_;
6366 custom_metric_shm_ptr_->metric_address = metric_address_;
6467
68+ // Histogram specific case
69+ if (buckets_.has_value ()) {
70+ auto buckets_size = buckets_.value ().size () * sizeof (double );
71+ std::unique_ptr<PbMemory> buckets_shm = PbMemory::Create (
72+ shm_pool, TRITONSERVER_MemoryType::TRITONSERVER_MEMORY_CPU, 0 ,
73+ buckets_size, reinterpret_cast <char *>(buckets_.value ().data ()),
74+ false /* copy_gpu */ );
75+ custom_metric_shm_ptr_->buckets_shm_handle = buckets_shm->ShmHandle ();
76+ buckets_shm_ = std::move (buckets_shm);
77+ } else {
78+ custom_metric_shm_ptr_->buckets_shm_handle = 0 ;
79+ buckets_shm_ = nullptr ;
80+ }
81+
6582 // Save the references to shared memory.
6683 custom_metric_shm_ = std::move (custom_metric_shm);
6784 labels_shm_ = std::move (labels_shm);
@@ -80,17 +97,40 @@ Metric::LoadFromSharedMemory(
8097 std::unique_ptr<PbString> labels_shm = PbString::LoadFromSharedMemory (
8198 shm_pool, custom_metric_shm_ptr->labels_shm_handle );
8299
83- return std::unique_ptr<Metric>(new Metric (custom_metric_shm, labels_shm));
100+ std::unique_ptr<PbMemory> buckets_shm = nullptr ;
101+ if (custom_metric_shm_ptr->buckets_shm_handle != 0 ) {
102+ buckets_shm = PbMemory::LoadFromSharedMemory (
103+ shm_pool, custom_metric_shm_ptr->buckets_shm_handle ,
104+ false /* open_cuda_handle */ );
105+ }
106+
107+ return std::unique_ptr<Metric>(
108+ new Metric (custom_metric_shm, labels_shm, buckets_shm));
84109}
85110
86111Metric::Metric (
87112 AllocatedSharedMemory<MetricShm>& custom_metric_shm,
88- std::unique_ptr<PbString>& labels_shm)
113+ std::unique_ptr<PbString>& labels_shm,
114+ std::unique_ptr<PbMemory>& buckets_shm)
89115 : custom_metric_shm_(std::move(custom_metric_shm)),
90- labels_shm_ (std::move(labels_shm))
116+ labels_shm_ (std::move(labels_shm)), buckets_shm_(std::move(buckets_shm))
91117{
92118 custom_metric_shm_ptr_ = custom_metric_shm_.data_ .get ();
119+
120+ // FIXME: This constructor is called during each
121+ // set/increment/observe/get_value call. It only needs the pointers.
93122 labels_ = labels_shm_->String ();
123+ if (buckets_shm_ != nullptr ) { // Histogram
124+ size_t bucket_size = buckets_shm_->ByteSize () / sizeof (double );
125+ std::vector<double > buckets;
126+ buckets.reserve (bucket_size);
127+ for (size_t i = 0 ; i < bucket_size; ++i) {
128+ buckets.emplace_back (
129+ reinterpret_cast <double *>(buckets_shm_->DataPtr ())[i]);
130+ }
131+ buckets_ = std::move (buckets);
132+ }
133+
94134 operation_value_ = custom_metric_shm_ptr_->operation_value ;
95135 metric_family_address_ = custom_metric_shm_ptr_->metric_family_address ;
96136 metric_address_ = custom_metric_shm_ptr_->metric_address ;
@@ -161,6 +201,24 @@ Metric::SendSetValueRequest(const double& value)
161201 }
162202}
163203
204+ void
205+ Metric::SendObserveRequest (const double & value)
206+ {
207+ try {
208+ CheckIfCleared ();
209+ std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance ();
210+ operation_value_ = value;
211+ SaveToSharedMemory (stub->ShmPool ());
212+ AllocatedSharedMemory<CustomMetricsMessage> custom_metrics_shm;
213+ stub->SendMessage <CustomMetricsMessage>(
214+ custom_metrics_shm, PYTHONSTUB_MetricRequestObserve, shm_handle_);
215+ }
216+ catch (const PythonBackendException& pb_exception) {
217+ throw PythonBackendException (
218+ " Failed to observe metric value: " + std::string (pb_exception.what ()));
219+ }
220+ }
221+
164222double
165223Metric::SendGetValueRequest ()
166224{
@@ -222,14 +280,35 @@ Metric::InitializeTritonMetric()
222280{
223281 std::vector<const TRITONSERVER_Parameter*> labels_params;
224282 ParseLabels (labels_params, labels_);
283+ TRITONSERVER_MetricKind kind;
284+ THROW_IF_TRITON_ERROR (TRITONSERVER_GetMetricFamilyKind (
285+ reinterpret_cast <TRITONSERVER_MetricFamily*>(metric_family_address_),
286+ &kind));
287+ TRITONSERVER_MetricArgs* args = nullptr ;
288+ switch (kind) {
289+ case TRITONSERVER_METRIC_KIND_COUNTER:
290+ case TRITONSERVER_METRIC_KIND_GAUGE:
291+ break ;
292+ case TRITONSERVER_METRIC_KIND_HISTOGRAM: {
293+ const std::vector<double >& buckets = buckets_.value ();
294+ THROW_IF_TRITON_ERROR (TRITONSERVER_MetricArgsNew (&args));
295+ THROW_IF_TRITON_ERROR (TRITONSERVER_MetricArgsSetHistogram (
296+ args, buckets.data (), buckets.size ()));
297+ break ;
298+ }
299+ default :
300+ break ;
301+ }
302+
225303 TRITONSERVER_Metric* triton_metric = nullptr ;
226- THROW_IF_TRITON_ERROR (TRITONSERVER_MetricNew (
304+ THROW_IF_TRITON_ERROR (TRITONSERVER_MetricNewWithArgs (
227305 &triton_metric,
228306 reinterpret_cast <TRITONSERVER_MetricFamily*>(metric_family_address_),
229- labels_params.data (), labels_params.size ()));
307+ labels_params.data (), labels_params.size (), args ));
230308 for (const auto label : labels_params) {
231309 TRITONSERVER_ParameterDelete (const_cast <TRITONSERVER_Parameter*>(label));
232310 }
311+ THROW_IF_TRITON_ERROR (TRITONSERVER_MetricArgsDelete (args));
233312 return reinterpret_cast <void *>(triton_metric);
234313}
235314
@@ -262,6 +341,8 @@ Metric::HandleMetricOperation(
262341 Increment (operation_value_);
263342 } else if (command_type == PYTHONSTUB_MetricRequestSet) {
264343 SetValue (operation_value_);
344+ } else if (command_type == PYTHONSTUB_MetricRequestObserve) {
345+ Observe (operation_value_);
265346 } else {
266347 throw PythonBackendException (" Unknown metric operation" );
267348 }
@@ -281,6 +362,13 @@ Metric::SetValue(const double& value)
281362 THROW_IF_TRITON_ERROR (TRITONSERVER_MetricSet (triton_metric, value));
282363}
283364
365+ void
366+ Metric::Observe (const double & value)
367+ {
368+ auto triton_metric = reinterpret_cast <TRITONSERVER_Metric*>(metric_address_);
369+ THROW_IF_TRITON_ERROR (TRITONSERVER_MetricObserve (triton_metric, value));
370+ }
371+
284372double
285373Metric::GetValue ()
286374{
0 commit comments