1use std::{
2 convert::Infallible,
3 net::SocketAddr,
4 pin::Pin,
5 sync::{
6 Arc,
7 atomic::{AtomicUsize, Ordering},
8 },
9 task::{Context, Poll},
10 time::Duration,
11};
12
13use futures::{FutureExt, StreamExt, future::BoxFuture};
14use http::{HeaderMap, Request, Response};
15use hyper::{Body, body::HttpBody};
16use pin_project::pin_project;
17use tokio::{
18 io::{AsyncRead, AsyncWrite, ReadBuf},
19 net::TcpStream,
20 time::{Sleep, sleep},
21};
22use tonic::{
23 body::BoxBody,
24 server::NamedService,
25 transport::server::{Connected, Routes, Server},
26};
27use tower::{Layer, Service};
28use tower_http::{
29 classify::{GrpcErrorsAsFailures, SharedClassifier},
30 trace::TraceLayer,
31};
32use tracing::Span;
33
34use crate::{
35 internal_events::{GrpcServerRequestReceived, GrpcServerResponseSent},
36 shutdown::{ShutdownSignal, ShutdownSignalToken},
37 tls::{MaybeTlsIncomingStream, MaybeTlsSettings},
38};
39use vector_lib::configurable::configurable_component;
40
41mod decompression;
42pub use self::decompression::{DecompressionAndMetrics, DecompressionAndMetricsLayer};
43
/// Test-only record of the peer address of every connection that was wrapped with a maximum
/// connection age.
///
/// `MaxConnectionAgeIo::new` appends to this list whenever a lifetime is configured; the list is
/// cleared only through `test_support::reset_max_connection_age_connection_observations`.
#[cfg(test)]
static MAX_CONNECTION_AGE_CONNECTION_OBSERVATIONS: std::sync::Mutex<Vec<SocketAddr>> =
    std::sync::Mutex::new(Vec::new());

#[cfg(all(test, feature = "sources-vector", feature = "sinks-vector"))]
pub(crate) mod test_support {
    //! Accessors that let integration tests inspect which connections were subjected to a
    //! maximum connection age.

    use std::net::SocketAddr;

    use super::MAX_CONNECTION_AGE_CONNECTION_OBSERVATIONS as OBSERVATIONS;

    /// Forgets every peer address recorded so far.
    pub(crate) fn reset_max_connection_age_connection_observations() {
        let mut observed = OBSERVATIONS.lock().unwrap();
        observed.clear();
    }

    /// Returns a snapshot of the recorded peer addresses, in the order they were recorded.
    pub(crate) fn max_connection_age_connection_observations() -> Vec<SocketAddr> {
        let observed = OBSERVATIONS.lock().unwrap();
        observed.to_vec()
    }
}
68
/// gRPC connection keepalive configuration.
#[configurable_component]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct GrpcKeepaliveConfig {
    /// The maximum amount of time, in seconds, that a client connection may stay open.
    ///
    /// Once a connection has been open this long (plus `max_connection_age_grace_secs`, if set),
    /// the server stops reading from it; writes continue only while requests on that connection
    /// are still being answered. If unset, connections are not limited by age.
    #[serde(default)]
    #[configurable(metadata(docs::examples = 300))]
    #[configurable(metadata(docs::type_unit = "seconds"))]
    #[configurable(metadata(docs::human_name = "Maximum Connection Age"))]
    pub max_connection_age_secs: Option<u64>,

    /// Additional time, in seconds, added to `max_connection_age_secs` before a connection expires.
    ///
    /// Has no effect unless `max_connection_age_secs` is also set.
    #[serde(default)]
    #[configurable(metadata(docs::examples = 30))]
    #[configurable(metadata(docs::type_unit = "seconds"))]
    #[configurable(metadata(docs::human_name = "Maximum Connection Age Grace"))]
    pub max_connection_age_grace_secs: Option<u64>,
}
92
93impl GrpcKeepaliveConfig {
94 fn max_connection_lifetime(&self) -> Option<Duration> {
95 self.max_connection_age_secs.map(|max_connection_age_secs| {
96 let age = Duration::from_secs(max_connection_age_secs);
97 let grace = self
98 .max_connection_age_grace_secs
99 .map(Duration::from_secs)
100 .unwrap_or_default();
101
102 age.checked_add(grace).unwrap_or(Duration::MAX)
103 })
104 }
105}
106
107struct MaxConnectionAgeIo {
108 inner: MaybeTlsIncomingStream<TcpStream>,
109 state: MaxConnectionAgeState,
110}
111
112impl MaxConnectionAgeIo {
113 fn new(inner: MaybeTlsIncomingStream<TcpStream>, lifetime: Option<Duration>) -> Self {
114 #[cfg(test)]
115 if lifetime.is_some() {
116 MAX_CONNECTION_AGE_CONNECTION_OBSERVATIONS
117 .lock()
118 .unwrap()
119 .push(inner.peer_addr());
120 }
121
122 Self {
123 inner,
124 state: MaxConnectionAgeState::new(lifetime),
125 }
126 }
127}
128
129struct MaxConnectionAgeState {
130 deadline: Option<Pin<Box<Sleep>>>,
131 read_expired: bool,
132 active_requests: Arc<AtomicUsize>,
133}
134
135impl MaxConnectionAgeState {
136 fn new(lifetime: Option<Duration>) -> Self {
137 Self {
138 deadline: lifetime.map(|lifetime| Box::pin(sleep(lifetime))),
139 read_expired: false,
140 active_requests: Arc::new(AtomicUsize::new(0)),
141 }
142 }
143
144 fn is_read_expired(&mut self, cx: &mut Context<'_>) -> bool {
145 if self.read_expired {
146 return true;
147 }
148
149 self.read_expired = self
150 .deadline
151 .as_mut()
152 .is_some_and(|deadline| deadline.as_mut().poll(cx).is_ready());
153
154 self.read_expired
155 }
156
157 fn is_write_expired(&mut self, cx: &mut Context<'_>) -> bool {
158 self.is_read_expired(cx) && self.active_requests.load(Ordering::Acquire) == 0
159 }
160
161 fn active_requests(&self) -> Arc<AtomicUsize> {
162 Arc::clone(&self.active_requests)
163 }
164
165 #[cfg(test)]
166 fn is_read_expired_for_test(&mut self, cx: &mut Context<'_>) -> bool {
167 self.is_read_expired(cx)
168 }
169
170 #[cfg(test)]
171 fn is_write_expired_for_test(&mut self, cx: &mut Context<'_>) -> bool {
172 self.is_write_expired(cx)
173 }
174}
175
176impl AsyncRead for MaxConnectionAgeIo {
177 fn poll_read(
178 self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &mut ReadBuf<'_>,
181 ) -> Poll<std::io::Result<()>> {
182 let this = self.get_mut();
183 if this.state.is_read_expired(cx) {
184 Poll::Ready(Ok(()))
185 } else {
186 Pin::new(&mut this.inner).poll_read(cx, buf)
187 }
188 }
189}
190
191impl AsyncWrite for MaxConnectionAgeIo {
192 fn poll_write(
193 self: Pin<&mut Self>,
194 cx: &mut Context<'_>,
195 buf: &[u8],
196 ) -> Poll<std::io::Result<usize>> {
197 let this = self.get_mut();
198 if this.state.is_write_expired(cx) {
199 Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()))
200 } else {
201 Pin::new(&mut this.inner).poll_write(cx, buf)
202 }
203 }
204
205 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
206 let this = self.get_mut();
207 if this.state.is_write_expired(cx) {
208 Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()))
209 } else {
210 Pin::new(&mut this.inner).poll_flush(cx)
211 }
212 }
213
214 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
215 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
216 }
217}
218
219impl Connected for MaxConnectionAgeIo {
220 type ConnectInfo = MaxConnectionAgeConnectInfo;
221
222 fn connect_info(&self) -> Self::ConnectInfo {
223 MaxConnectionAgeConnectInfo {
224 active_requests: self.state.active_requests(),
225 }
226 }
227}
228
229#[derive(Clone, Debug)]
230struct MaxConnectionAgeConnectInfo {
231 active_requests: Arc<AtomicUsize>,
232}
233
234#[derive(Clone)]
235struct MaxConnectionAgeLayer;
236
237impl MaxConnectionAgeLayer {
238 const fn new() -> Self {
239 Self
240 }
241}
242
243impl<S> Layer<S> for MaxConnectionAgeLayer {
244 type Service = MaxConnectionAgeService<S>;
245
246 fn layer(&self, service: S) -> Self::Service {
247 MaxConnectionAgeService { service }
248 }
249}
250
251#[derive(Clone)]
252struct MaxConnectionAgeService<S> {
253 service: S,
254}
255
256impl<S> NamedService for MaxConnectionAgeService<S>
257where
258 S: NamedService,
259{
260 const NAME: &'static str = S::NAME;
261}
262
/// Guard that counts one in-flight request on a connection.
///
/// Creating the guard increments the shared counter; dropping it decrements the counter again.
struct ActiveRequestGuard {
    active_requests: Arc<AtomicUsize>,
}

impl ActiveRequestGuard {
    /// Registers one more active request on `active_requests`.
    fn new(active_requests: Arc<AtomicUsize>) -> Self {
        let guard = Self { active_requests };
        guard.active_requests.fetch_add(1, Ordering::AcqRel);
        guard
    }
}

impl Drop for ActiveRequestGuard {
    fn drop(&mut self) {
        // Undo the increment made in `new`.
        self.active_requests.fetch_sub(1, Ordering::AcqRel);
    }
}
279
280#[pin_project]
281struct MaxConnectionAgeBody<B> {
282 #[pin]
283 inner: B,
284 _guard: Option<ActiveRequestGuard>,
285}
286
287impl<B> MaxConnectionAgeBody<B> {
288 const fn new(inner: B, guard: Option<ActiveRequestGuard>) -> Self {
289 Self {
290 inner,
291 _guard: guard,
292 }
293 }
294}
295
296impl<B> HttpBody for MaxConnectionAgeBody<B>
297where
298 B: HttpBody,
299{
300 type Data = B::Data;
301 type Error = B::Error;
302
303 fn poll_data(
304 self: Pin<&mut Self>,
305 cx: &mut Context<'_>,
306 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
307 self.project().inner.poll_data(cx)
308 }
309
310 fn poll_trailers(
311 self: Pin<&mut Self>,
312 cx: &mut Context<'_>,
313 ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
314 self.project().inner.poll_trailers(cx)
315 }
316
317 fn is_end_stream(&self) -> bool {
318 self.inner.is_end_stream()
319 }
320
321 fn size_hint(&self) -> hyper::body::SizeHint {
322 self.inner.size_hint()
323 }
324}
325
326impl<S, B> Service<Request<Body>> for MaxConnectionAgeService<S>
327where
328 S: Service<Request<Body>, Response = Response<B>> + Clone + Send + 'static,
329 S::Future: Send + 'static,
330 B: HttpBody + Send + 'static,
331{
332 type Response = Response<MaxConnectionAgeBody<B>>;
333 type Error = S::Error;
334 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
335
336 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
337 self.service.poll_ready(cx)
338 }
339
340 fn call(&mut self, req: Request<Body>) -> Self::Future {
341 let guard = req
342 .extensions()
343 .get::<MaxConnectionAgeConnectInfo>()
344 .map(|connect_info| ActiveRequestGuard::new(Arc::clone(&connect_info.active_requests)));
345 let future = self.service.call(req);
346
347 async move {
348 future
349 .await
350 .map(|response| response.map(|body| MaxConnectionAgeBody::new(body, guard)))
351 }
352 .boxed()
353 }
354}
355
356pub async fn run_grpc_server<S>(
357 address: SocketAddr,
358 tls_settings: MaybeTlsSettings,
359 service: S,
360 keepalive: GrpcKeepaliveConfig,
361 shutdown: ShutdownSignal,
362) -> crate::Result<()>
363where
364 S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
365 + NamedService
366 + Clone
367 + Send
368 + 'static,
369 S::Future: Send + 'static,
370{
371 let span = Span::current();
372 let (tx, rx) = tokio::sync::oneshot::channel::<ShutdownSignalToken>();
373 let listener = tls_settings.bind(&address).await?;
374 let max_connection_lifetime = keepalive.max_connection_lifetime();
375 let stream = listener
376 .accept_stream()
377 .map(move |stream| stream.map(|io| MaxConnectionAgeIo::new(io, max_connection_lifetime)));
378
379 info!(%address, "Building gRPC server.");
380
381 Server::builder()
382 .layer(MaxConnectionAgeLayer::new())
383 .layer(build_grpc_trace_layer(span.clone()))
384 .layer(DecompressionAndMetricsLayer)
394 .add_service(service)
395 .serve_with_incoming_shutdown(stream, shutdown.map(|token| tx.send(token).unwrap()))
396 .await?;
397
398 drop(rx.await);
399
400 Ok(())
401}
402
403pub async fn run_grpc_server_with_routes(
406 address: SocketAddr,
407 tls_settings: MaybeTlsSettings,
408 routes: Routes,
409 keepalive: GrpcKeepaliveConfig,
410 shutdown: ShutdownSignal,
411) -> crate::Result<()> {
412 let span = Span::current();
413 let (tx, rx) = tokio::sync::oneshot::channel::<ShutdownSignalToken>();
414 let listener = tls_settings.bind(&address).await?;
415 let max_connection_lifetime = keepalive.max_connection_lifetime();
416 let stream = listener
417 .accept_stream()
418 .map(move |stream| stream.map(|io| MaxConnectionAgeIo::new(io, max_connection_lifetime)));
419
420 info!(%address, "Building gRPC server.");
421
422 Server::builder()
423 .layer(MaxConnectionAgeLayer::new())
424 .layer(build_grpc_trace_layer(span.clone()))
425 .layer(DecompressionAndMetricsLayer)
426 .add_routes(routes)
427 .serve_with_incoming_shutdown(stream, shutdown.map(|token| tx.send(token).unwrap()))
428 .await?;
429
430 drop(rx.await);
431
432 Ok(())
433}
434
435pub fn build_grpc_trace_layer(
439 span: Span,
440) -> TraceLayer<
441 SharedClassifier<GrpcErrorsAsFailures>,
442 impl Fn(&Request<Body>) -> Span + Clone,
443 impl Fn(&Request<Body>, &Span) + Clone,
444 impl Fn(&Response<BoxBody>, Duration, &Span) + Clone,
445 (),
446 (),
447 (),
448> {
449 TraceLayer::new_for_grpc()
450 .make_span_with(move |request: &Request<Body>| {
451 let mut path = request.uri().path().split('/');
453 let service = path.nth(1).unwrap_or("_unknown");
454 let method = path.next().unwrap_or("_unknown");
455
456 error_span!(
458 parent: &span,
459 "grpc-request",
460 grpc_service = service,
461 grpc_method = method,
462 )
463 })
464 .on_request(Box::new(|_request: &Request<Body>, _span: &Span| {
465 emit!(GrpcServerRequestReceived);
466 }))
467 .on_response(
468 |response: &Response<BoxBody>, latency: Duration, _span: &Span| {
469 emit!(GrpcServerResponseSent { response, latency });
470 },
471 )
472 .on_failure(())
473 .on_body_chunk(())
474 .on_eos(())
475}
476
#[cfg(test)]
mod tests {
    use std::future::{Ready, ready};

    use super::*;

    /// Inner service that answers every request with an empty body.
    #[derive(Clone)]
    struct EmptyBodyService;

    impl Service<Request<Body>> for EmptyBodyService {
        type Response = Response<Body>;
        type Error = Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: Request<Body>) -> Self::Future {
            ready(Ok(Response::new(Body::empty())))
        }
    }

    /// Wraps `EmptyBodyService` the same way `MaxConnectionAgeLayer` does.
    fn wrapped_service() -> MaxConnectionAgeService<EmptyBodyService> {
        MaxConnectionAgeService {
            service: EmptyBodyService,
        }
    }

    /// Builds a request whose connect info points at `active_requests`, as if it had arrived
    /// on the connection that owns that counter.
    fn request_on_connection(active_requests: &Arc<AtomicUsize>) -> Request<Body> {
        let mut request = Request::new(Body::empty());
        request
            .extensions_mut()
            .insert(MaxConnectionAgeConnectInfo {
                active_requests: Arc::clone(active_requests),
            });
        request
    }

    /// Current value of a connection's active-request counter.
    fn active(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::Acquire)
    }

    #[tokio::test]
    async fn max_connection_age_service_tracks_response_body_until_drop() {
        let active_requests = Arc::new(AtomicUsize::new(0));
        let mut service = wrapped_service();
        let request = request_on_connection(&active_requests);

        assert_eq!(active(&active_requests), 0);

        // The request stays counted for as long as its response body is alive.
        let response = service
            .call(request)
            .await
            .expect("service call should succeed");
        assert_eq!(active(&active_requests), 1);

        drop(response);
        assert_eq!(active(&active_requests), 0);
    }

    #[tokio::test]
    async fn max_connection_age_service_tracks_active_requests_per_connection() {
        let first_connection = Arc::new(AtomicUsize::new(0));
        let second_connection = Arc::new(AtomicUsize::new(0));
        let mut service = wrapped_service();
        let first_request = request_on_connection(&first_connection);
        let second_request = request_on_connection(&second_connection);

        let first_response = service
            .call(first_request)
            .await
            .expect("first service call should succeed");
        assert_eq!(active(&first_connection), 1);
        assert_eq!(active(&second_connection), 0);

        let second_response = service
            .call(second_request)
            .await
            .expect("second service call should succeed");
        assert_eq!(active(&first_connection), 1);
        assert_eq!(active(&second_connection), 1);

        // Dropping one connection's response must not affect the other connection's count.
        drop(second_response);
        assert_eq!(active(&first_connection), 1);
        assert_eq!(active(&second_connection), 0);

        drop(first_response);
        assert_eq!(active(&first_connection), 0);
        assert_eq!(active(&second_connection), 0);
    }

    #[tokio::test]
    async fn max_connection_age_state_stops_reads_at_deadline_before_writes() {
        let mut state = MaxConnectionAgeState::new(Some(Duration::from_millis(1)));
        let active_requests = state.active_requests();
        active_requests.fetch_add(1, Ordering::AcqRel);

        // Wait well past the 1ms lifetime.
        sleep(Duration::from_millis(10)).await;

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        // With a request still in flight, reads stop but writes remain allowed.
        assert!(state.is_read_expired_for_test(&mut cx));
        assert!(!state.is_write_expired_for_test(&mut cx));

        // Once the last request finishes, writes expire too.
        active_requests.fetch_sub(1, Ordering::AcqRel);
        assert!(state.is_write_expired_for_test(&mut cx));
    }
}