Skip to content

Commit dcdc457

Browse files
vidhyav authored and meta-codesync[bot] committed
Added throughput and error metrics to endpoint calls. (#1816)
Summary: Pull Request resolved: #1816 As stated above. Reviewed By: pzhan9 Differential Revision: D86246819
1 parent 91b7014 commit dcdc457

File tree

2 files changed

+182
-54
lines changed

2 files changed

+182
-54
lines changed

python/monarch/_src/actor/endpoint.py

Lines changed: 94 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -34,70 +34,66 @@
3434
from monarch._rust_bindings.monarch_hyperactor.shape import Extent
3535

3636
from monarch._src.actor.future import Future
37-
from monarch._src.actor.telemetry import METER
38-
from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
39-
40-
from opentelemetry.metrics import Histogram
41-
42-
# Histogram for measuring endpoint call latency
43-
endpoint_call_latency_histogram: Histogram = METER.create_histogram(
44-
name="endpoint_call_latency.us",
45-
description="Latency of endpoint call operations in microseconds",
46-
)
47-
48-
# Histogram for measuring endpoint call_one latency
49-
endpoint_call_one_latency_histogram: Histogram = METER.create_histogram(
50-
name="endpoint_call_one_latency.us",
51-
description="Latency of endpoint call_one operations in microseconds",
37+
from monarch._src.actor.metrics import (
38+
endpoint_broadcast_error_counter,
39+
endpoint_broadcast_throughput_counter,
40+
endpoint_call_error_counter,
41+
endpoint_call_latency_histogram,
42+
endpoint_call_one_error_counter,
43+
endpoint_call_one_latency_histogram,
44+
endpoint_call_one_throughput_counter,
45+
endpoint_call_throughput_counter,
46+
endpoint_choose_error_counter,
47+
endpoint_choose_latency_histogram,
48+
endpoint_choose_throughput_counter,
49+
endpoint_stream_latency_histogram,
50+
endpoint_stream_throughput_counter,
5251
)
52+
from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
5353

54-
# Histogram for measuring endpoint stream latency per yield
55-
endpoint_stream_latency_histogram: Histogram = METER.create_histogram(
56-
name="endpoint_stream_latency.us",
57-
description="Latency of endpoint stream operations per yield in microseconds",
58-
)
59-
60-
# Histogram for measuring endpoint choose latency
61-
endpoint_choose_latency_histogram: Histogram = METER.create_histogram(
62-
name="endpoint_choose_latency.us",
63-
description="Latency of endpoint choose operations in microseconds",
64-
)
54+
from opentelemetry.metrics import Counter, Histogram
6555

6656
T = TypeVar("T")
6757

6858

69-
def _measure_latency(
59+
def _observe_latency_and_error(
7060
coro: Coroutine[Any, Any, T],
7161
start_time_ns: int,
7262
histogram: Histogram,
63+
error_counter: Counter,
7364
method_name: str,
7465
actor_count: int,
7566
) -> Coroutine[Any, Any, T]:
7667
"""
77-
Decorator to measure and record latency of an async operation.
68+
Observe and record latency and errors of an async operation.
7869
7970
Args:
80-
coro: The coroutine to measure
81-
histogram: The histogram to record metrics to
71+
coro: The coroutine to observe
72+
histogram: The histogram to record latency metrics to
73+
error_counter: The counter to record error metrics to
8274
method_name: Name of the method being called
8375
actor_count: Number of actors involved in the call
8476
8577
Returns:
86-
A wrapped coroutine that records latency metrics
78+
A wrapped coroutine that records error and latency metrics
8779
"""
8880

8981
async def _wrapper() -> T:
82+
error_occurred = False
9083
try:
9184
return await coro
85+
except Exception:
86+
error_occurred = True
87+
raise
9288
finally:
9389
duration_us = int((time.monotonic_ns() - start_time_ns) / 1_000)
94-
histogram.record(
95-
duration_us,
96-
attributes={
97-
"method": method_name,
98-
"actor_count": actor_count,
99-
},
100-
)
90+
attributes = {
91+
"method": method_name,
92+
"actor_count": actor_count,
93+
}
94+
histogram.record(duration_us, attributes=attributes)
95+
if error_occurred:
96+
error_counter.add(1, attributes=attributes)
10197

10298
return _wrapper()
10399

@@ -136,27 +132,37 @@ def _get_method_name(self) -> str:
136132
return method_specifier.name
137133
return "unknown"
138134

139-
def _with_latency_measurement(
140-
self, start_time_ns: int, histogram: Histogram, actor_count: int
135+
def _with_telemetry(
136+
self,
137+
start_time_ns: int,
138+
histogram: Histogram,
139+
error_counter: Counter,
140+
actor_count: int,
141141
) -> Any:
142142
"""
143-
Decorator factory to add latency measurement to async functions.
143+
Decorator factory to add telemetry (latency and error tracking) to async functions.
144144
145145
Args:
146-
histogram: The histogram to record metrics to
146+
histogram: The histogram to record latency metrics to
147+
error_counter: The counter to record error metrics to
147148
actor_count: Number of actors involved in the operation
148149
149150
Returns:
150-
A decorator that wraps async functions with latency measurement
151+
A decorator that wraps async functions with telemetry measurement
151152
"""
152153
method_name: str = self._get_method_name()
153154

154155
def decorator(func: Any) -> Any:
155156
@functools.wraps(func)
156157
def wrapper(*args: Any, **kwargs: Any) -> Any:
157158
coro = func(*args, **kwargs)
158-
return _measure_latency(
159-
coro, start_time_ns, histogram, method_name, actor_count
159+
return _observe_latency_and_error(
160+
coro,
161+
start_time_ns,
162+
histogram,
163+
error_counter,
164+
method_name,
165+
actor_count,
160166
)
161167

162168
return wrapper
@@ -210,14 +216,17 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
210216
211217
Load balanced RPC-style entrypoint for request/response messaging.
212218
"""
219+
# Track throughput at method entry
220+
method_name: str = self._get_method_name()
221+
endpoint_choose_throughput_counter.add(1, attributes={"method": method_name})
213222

214223
p, r_port = self._port(once=True)
215224
r: "PortReceiver[R]" = r_port
216225
start_time: int = time.monotonic_ns()
217226
# pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any]
218227
self._send(args, kwargs, port=p, selection="choose")
219228

220-
@self._with_latency_measurement(
229+
self._with_latency_measurement(
221230
start_time, endpoint_choose_latency_histogram, 1
222231
)
223232
async def process() -> R:
@@ -227,6 +236,10 @@ async def process() -> R:
227236
return Future(coro=process())
228237

229238
def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
239+
# Track throughput at method entry
240+
method_name: str = self._get_method_name()
241+
endpoint_call_one_throughput_counter.add(1, attributes={"method": method_name})
242+
230243
p, r_port = self._port(once=True)
231244
r: PortReceiver[R] = r_port
232245
start_time: int = time.monotonic_ns()
@@ -237,8 +250,11 @@ def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
237250
f"Can only use 'call_one' on a single Actor but this actor has shape {extent}"
238251
)
239252

240-
@self._with_latency_measurement(
241-
start_time, endpoint_call_one_latency_histogram, 1
253+
@self._with_telemetry(
254+
start_time,
255+
endpoint_call_one_latency_histogram,
256+
endpoint_call_one_error_counter,
257+
1,
242258
)
243259
async def process() -> R:
244260
result = await r.recv()
@@ -250,13 +266,19 @@ def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
250266
from monarch._src.actor.actor_mesh import RankedPortReceiver, ValueMesh
251267

252268
start_time: int = time.monotonic_ns()
269+
# Track throughput at method entry
270+
method_name: str = self._get_method_name()
271+
endpoint_call_throughput_counter.add(1, attributes={"method": method_name})
253272
p, unranked = self._port()
254273
r: RankedPortReceiver[R] = unranked.ranked()
255274
# pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any]
256275
extent: Extent = self._send(args, kwargs, port=p)
257276

258-
@self._with_latency_measurement(
259-
start_time, endpoint_call_latency_histogram, extent.nelements
277+
@self._with_telemetry(
278+
start_time,
279+
endpoint_call_latency_histogram,
280+
endpoint_call_error_counter,
281+
extent.nelements,
260282
)
261283
async def process() -> "ValueMesh[R]":
262284
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
@@ -283,14 +305,22 @@ def stream(
283305
This enables processing results from multiple actors incrementally as
284306
they become available. Returns an async generator of response values.
285307
"""
308+
# Track throughput at method entry
309+
method_name: str = self._get_method_name()
310+
endpoint_stream_throughput_counter.add(1, attributes={"method": method_name})
311+
286312
p, r_port = self._port()
287313
start_time: int = time.monotonic_ns()
288314
# pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any]
289315
extent: Extent = self._send(args, kwargs, port=p)
290316
r: "PortReceiver[R]" = r_port
291317

292-
latency_decorator: Any = self._with_latency_measurement(
293-
start_time, endpoint_stream_latency_histogram, extent.nelements
318+
# Note: stream doesn't track errors per-yield since errors propagate to caller
319+
latency_decorator: Any = self._with_telemetry(
320+
start_time,
321+
endpoint_stream_latency_histogram,
322+
endpoint_broadcast_error_counter, # Placeholder, errors not tracked per-yield
323+
extent.nelements,
294324
)
295325

296326
def _stream() -> Generator[Future[R], None, None]:
@@ -314,8 +344,18 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
314344
"""
315345
from monarch._src.actor.actor_mesh import send
316346

317-
# pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any]
318-
send(self, args, kwargs)
347+
method_name: str = self._get_method_name()
348+
attributes = {
349+
"method": method_name,
350+
"actor_count": 0, # broadcast doesn't track specific count
351+
}
352+
try:
353+
# pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any]
354+
send(self, args, kwargs)
355+
endpoint_broadcast_throughput_counter.add(1, attributes=attributes)
356+
except Exception:
357+
endpoint_broadcast_error_counter.add(1, attributes=attributes)
358+
raise
319359

320360
@abstractmethod
321361
def _rref(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> R: ...
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
"""
10+
Telemetry metrics for endpoint operations.
11+
12+
This module defines all histograms and counters used to track endpoint
13+
performance metrics including latency, errors, and throughput.
14+
"""
15+
16+
from monarch._src.actor.telemetry import METER
17+
from opentelemetry.metrics import Counter, Histogram
18+
19+
# Histogram for measuring endpoint call latency
20+
endpoint_call_latency_histogram: Histogram = METER.create_histogram(
21+
name="endpoint_call_latency.us",
22+
description="Latency of endpoint call operations in microseconds",
23+
)
24+
25+
# Histogram for measuring endpoint call_one latency
26+
endpoint_call_one_latency_histogram: Histogram = METER.create_histogram(
27+
name="endpoint_call_one_latency.us",
28+
description="Latency of endpoint call_one operations in microseconds",
29+
)
30+
31+
# Histogram for measuring endpoint stream latency per yield
32+
endpoint_stream_latency_histogram: Histogram = METER.create_histogram(
33+
name="endpoint_stream_latency.us",
34+
description="Latency of endpoint stream operations per yield in microseconds",
35+
)
36+
37+
# Histogram for measuring endpoint choose latency
38+
endpoint_choose_latency_histogram: Histogram = METER.create_histogram(
39+
name="endpoint_choose_latency.us",
40+
description="Latency of endpoint choose operations in microseconds",
41+
)
42+
43+
# Counters for measuring endpoint errors
44+
endpoint_call_error_counter: Counter = METER.create_counter(
45+
name="endpoint_call_error.count",
46+
description="Count of errors in endpoint call operations",
47+
)
48+
49+
endpoint_call_one_error_counter: Counter = METER.create_counter(
50+
name="endpoint_call_one_error.count",
51+
description="Count of errors in endpoint call_one operations",
52+
)
53+
54+
endpoint_choose_error_counter: Counter = METER.create_counter(
55+
name="endpoint_choose_error.count",
56+
description="Count of errors in endpoint choose operations",
57+
)
58+
59+
endpoint_broadcast_error_counter: Counter = METER.create_counter(
60+
name="endpoint_broadcast_error.count",
61+
description="Count of errors in endpoint broadcast operations",
62+
)
63+
64+
# Counters for measuring endpoint throughput (call counts)
65+
endpoint_call_throughput_counter: Counter = METER.create_counter(
66+
name="endpoint_call_throughput.count",
67+
description="Count of endpoint call invocations for throughput measurement",
68+
)
69+
70+
endpoint_call_one_throughput_counter: Counter = METER.create_counter(
71+
name="endpoint_call_one_throughput.count",
72+
description="Count of endpoint call_one invocations for throughput measurement",
73+
)
74+
75+
endpoint_choose_throughput_counter: Counter = METER.create_counter(
76+
name="endpoint_choose_throughput.count",
77+
description="Count of endpoint choose invocations for throughput measurement",
78+
)
79+
80+
endpoint_stream_throughput_counter: Counter = METER.create_counter(
81+
name="endpoint_stream_throughput.count",
82+
description="Count of endpoint stream invocations for throughput measurement",
83+
)
84+
85+
endpoint_broadcast_throughput_counter: Counter = METER.create_counter(
86+
name="endpoint_broadcast_throughput.count",
87+
description="Count of endpoint broadcast invocations for throughput measurement",
88+
)

0 commit comments

Comments (0)