Skip to main content

vector/sources/util/grpc/
mod.rs

1use std::{
2    convert::Infallible,
3    net::SocketAddr,
4    pin::Pin,
5    sync::{
6        Arc,
7        atomic::{AtomicUsize, Ordering},
8    },
9    task::{Context, Poll},
10    time::Duration,
11};
12
13use futures::{FutureExt, StreamExt, future::BoxFuture};
14use http::{HeaderMap, Request, Response};
15use hyper::{Body, body::HttpBody};
16use pin_project::pin_project;
17use tokio::{
18    io::{AsyncRead, AsyncWrite, ReadBuf},
19    net::TcpStream,
20    time::{Sleep, sleep},
21};
22use tonic::{
23    body::BoxBody,
24    server::NamedService,
25    transport::server::{Connected, Routes, Server},
26};
27use tower::{Layer, Service};
28use tower_http::{
29    classify::{GrpcErrorsAsFailures, SharedClassifier},
30    trace::TraceLayer,
31};
32use tracing::Span;
33
34use crate::{
35    internal_events::{GrpcServerRequestReceived, GrpcServerResponseSent},
36    shutdown::{ShutdownSignal, ShutdownSignalToken},
37    tls::{MaybeTlsIncomingStream, MaybeTlsSettings},
38};
39use vector_lib::configurable::configurable_component;
40
41mod decompression;
42pub use self::decompression::{DecompressionAndMetrics, DecompressionAndMetricsLayer};
43
/// Peer addresses of connections that were wrapped with a maximum connection lifetime.
///
/// Only compiled for tests; `MaxConnectionAgeIo::new` appends to it whenever a lifetime is set.
#[cfg(test)]
static MAX_CONNECTION_AGE_CONNECTION_OBSERVATIONS: std::sync::Mutex<Vec<SocketAddr>> =
    std::sync::Mutex::new(Vec::new());

/// Test-only accessors for the recorded max-connection-age observations.
#[cfg(all(test, feature = "sources-vector", feature = "sinks-vector"))]
pub(crate) mod test_support {
    use std::net::SocketAddr;

    use super::MAX_CONNECTION_AGE_CONNECTION_OBSERVATIONS;

    /// Removes every recorded observation.
    pub(crate) fn reset_max_connection_age_connection_observations() {
        let mut observations = MAX_CONNECTION_AGE_CONNECTION_OBSERVATIONS.lock().unwrap();
        observations.clear();
    }

    /// Returns a copy of the peer addresses recorded so far.
    pub(crate) fn max_connection_age_connection_observations() -> Vec<SocketAddr> {
        let observations = MAX_CONNECTION_AGE_CONNECTION_OBSERVATIONS.lock().unwrap();
        observations.to_vec()
    }
}
68
69/// Configuration of gRPC server keepalive parameters.
70#[configurable_component]
71#[derive(Clone, Debug, Default, PartialEq, Eq)]
72#[serde(deny_unknown_fields)]
73pub struct GrpcKeepaliveConfig {
74    /// The maximum amount of time a connection may exist before the server closes it.
75    ///
76    /// When unset, connections are not closed based on age.
77    #[serde(default)]
78    #[configurable(metadata(docs::examples = 300))]
79    #[configurable(metadata(docs::type_unit = "seconds"))]
80    #[configurable(metadata(docs::human_name = "Maximum Connection Age"))]
81    pub max_connection_age_secs: Option<u64>,
82
83    /// The grace period added to `max_connection_age_secs` before the server closes the connection.
84    ///
85    /// This setting only applies when `max_connection_age_secs` is set.
86    #[serde(default)]
87    #[configurable(metadata(docs::examples = 30))]
88    #[configurable(metadata(docs::type_unit = "seconds"))]
89    #[configurable(metadata(docs::human_name = "Maximum Connection Age Grace"))]
90    pub max_connection_age_grace_secs: Option<u64>,
91}
92
93impl GrpcKeepaliveConfig {
94    fn max_connection_lifetime(&self) -> Option<Duration> {
95        self.max_connection_age_secs.map(|max_connection_age_secs| {
96            let age = Duration::from_secs(max_connection_age_secs);
97            let grace = self
98                .max_connection_age_grace_secs
99                .map(Duration::from_secs)
100                .unwrap_or_default();
101
102            age.checked_add(grace).unwrap_or(Duration::MAX)
103        })
104    }
105}
106
107struct MaxConnectionAgeIo {
108    inner: MaybeTlsIncomingStream<TcpStream>,
109    state: MaxConnectionAgeState,
110}
111
112impl MaxConnectionAgeIo {
113    fn new(inner: MaybeTlsIncomingStream<TcpStream>, lifetime: Option<Duration>) -> Self {
114        #[cfg(test)]
115        if lifetime.is_some() {
116            MAX_CONNECTION_AGE_CONNECTION_OBSERVATIONS
117                .lock()
118                .unwrap()
119                .push(inner.peer_addr());
120        }
121
122        Self {
123            inner,
124            state: MaxConnectionAgeState::new(lifetime),
125        }
126    }
127}
128
129struct MaxConnectionAgeState {
130    deadline: Option<Pin<Box<Sleep>>>,
131    read_expired: bool,
132    active_requests: Arc<AtomicUsize>,
133}
134
135impl MaxConnectionAgeState {
136    fn new(lifetime: Option<Duration>) -> Self {
137        Self {
138            deadline: lifetime.map(|lifetime| Box::pin(sleep(lifetime))),
139            read_expired: false,
140            active_requests: Arc::new(AtomicUsize::new(0)),
141        }
142    }
143
144    fn is_read_expired(&mut self, cx: &mut Context<'_>) -> bool {
145        if self.read_expired {
146            return true;
147        }
148
149        self.read_expired = self
150            .deadline
151            .as_mut()
152            .is_some_and(|deadline| deadline.as_mut().poll(cx).is_ready());
153
154        self.read_expired
155    }
156
157    fn is_write_expired(&mut self, cx: &mut Context<'_>) -> bool {
158        self.is_read_expired(cx) && self.active_requests.load(Ordering::Acquire) == 0
159    }
160
161    fn active_requests(&self) -> Arc<AtomicUsize> {
162        Arc::clone(&self.active_requests)
163    }
164
165    #[cfg(test)]
166    fn is_read_expired_for_test(&mut self, cx: &mut Context<'_>) -> bool {
167        self.is_read_expired(cx)
168    }
169
170    #[cfg(test)]
171    fn is_write_expired_for_test(&mut self, cx: &mut Context<'_>) -> bool {
172        self.is_write_expired(cx)
173    }
174}
175
176impl AsyncRead for MaxConnectionAgeIo {
177    fn poll_read(
178        self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        buf: &mut ReadBuf<'_>,
181    ) -> Poll<std::io::Result<()>> {
182        let this = self.get_mut();
183        if this.state.is_read_expired(cx) {
184            Poll::Ready(Ok(()))
185        } else {
186            Pin::new(&mut this.inner).poll_read(cx, buf)
187        }
188    }
189}
190
191impl AsyncWrite for MaxConnectionAgeIo {
192    fn poll_write(
193        self: Pin<&mut Self>,
194        cx: &mut Context<'_>,
195        buf: &[u8],
196    ) -> Poll<std::io::Result<usize>> {
197        let this = self.get_mut();
198        if this.state.is_write_expired(cx) {
199            Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()))
200        } else {
201            Pin::new(&mut this.inner).poll_write(cx, buf)
202        }
203    }
204
205    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
206        let this = self.get_mut();
207        if this.state.is_write_expired(cx) {
208            Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()))
209        } else {
210            Pin::new(&mut this.inner).poll_flush(cx)
211        }
212    }
213
214    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
215        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
216    }
217}
218
219impl Connected for MaxConnectionAgeIo {
220    type ConnectInfo = MaxConnectionAgeConnectInfo;
221
222    fn connect_info(&self) -> Self::ConnectInfo {
223        MaxConnectionAgeConnectInfo {
224            active_requests: self.state.active_requests(),
225        }
226    }
227}
228
229#[derive(Clone, Debug)]
230struct MaxConnectionAgeConnectInfo {
231    active_requests: Arc<AtomicUsize>,
232}
233
/// Stateless marker layer whose `Layer` implementation produces `MaxConnectionAgeService`.
#[derive(Clone)]
struct MaxConnectionAgeLayer;

impl MaxConnectionAgeLayer {
    /// Creates the layer; it carries no configuration.
    const fn new() -> Self {
        MaxConnectionAgeLayer
    }
}
242
243impl<S> Layer<S> for MaxConnectionAgeLayer {
244    type Service = MaxConnectionAgeService<S>;
245
246    fn layer(&self, service: S) -> Self::Service {
247        MaxConnectionAgeService { service }
248    }
249}
250
251#[derive(Clone)]
252struct MaxConnectionAgeService<S> {
253    service: S,
254}
255
256impl<S> NamedService for MaxConnectionAgeService<S>
257where
258    S: NamedService,
259{
260    const NAME: &'static str = S::NAME;
261}
262
/// Scope guard that counts one in-flight request for as long as it is alive.
struct ActiveRequestGuard {
    /// Counter that was incremented on creation and is decremented on drop.
    active_requests: Arc<AtomicUsize>,
}

impl ActiveRequestGuard {
    /// Increments `active_requests` and returns a guard that undoes the increment when dropped.
    fn new(active_requests: Arc<AtomicUsize>) -> Self {
        let guard = Self { active_requests };
        guard.active_requests.fetch_add(1, Ordering::AcqRel);
        guard
    }
}

impl Drop for ActiveRequestGuard {
    /// Releases the slot this guard was holding in the counter.
    fn drop(&mut self) {
        let counter = &self.active_requests;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}
279
280#[pin_project]
281struct MaxConnectionAgeBody<B> {
282    #[pin]
283    inner: B,
284    _guard: Option<ActiveRequestGuard>,
285}
286
287impl<B> MaxConnectionAgeBody<B> {
288    const fn new(inner: B, guard: Option<ActiveRequestGuard>) -> Self {
289        Self {
290            inner,
291            _guard: guard,
292        }
293    }
294}
295
296impl<B> HttpBody for MaxConnectionAgeBody<B>
297where
298    B: HttpBody,
299{
300    type Data = B::Data;
301    type Error = B::Error;
302
303    fn poll_data(
304        self: Pin<&mut Self>,
305        cx: &mut Context<'_>,
306    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
307        self.project().inner.poll_data(cx)
308    }
309
310    fn poll_trailers(
311        self: Pin<&mut Self>,
312        cx: &mut Context<'_>,
313    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
314        self.project().inner.poll_trailers(cx)
315    }
316
317    fn is_end_stream(&self) -> bool {
318        self.inner.is_end_stream()
319    }
320
321    fn size_hint(&self) -> hyper::body::SizeHint {
322        self.inner.size_hint()
323    }
324}
325
326impl<S, B> Service<Request<Body>> for MaxConnectionAgeService<S>
327where
328    S: Service<Request<Body>, Response = Response<B>> + Clone + Send + 'static,
329    S::Future: Send + 'static,
330    B: HttpBody + Send + 'static,
331{
332    type Response = Response<MaxConnectionAgeBody<B>>;
333    type Error = S::Error;
334    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
335
336    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
337        self.service.poll_ready(cx)
338    }
339
340    fn call(&mut self, req: Request<Body>) -> Self::Future {
341        let guard = req
342            .extensions()
343            .get::<MaxConnectionAgeConnectInfo>()
344            .map(|connect_info| ActiveRequestGuard::new(Arc::clone(&connect_info.active_requests)));
345        let future = self.service.call(req);
346
347        async move {
348            future
349                .await
350                .map(|response| response.map(|body| MaxConnectionAgeBody::new(body, guard)))
351        }
352        .boxed()
353    }
354}
355
356pub async fn run_grpc_server<S>(
357    address: SocketAddr,
358    tls_settings: MaybeTlsSettings,
359    service: S,
360    keepalive: GrpcKeepaliveConfig,
361    shutdown: ShutdownSignal,
362) -> crate::Result<()>
363where
364    S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
365        + NamedService
366        + Clone
367        + Send
368        + 'static,
369    S::Future: Send + 'static,
370{
371    let span = Span::current();
372    let (tx, rx) = tokio::sync::oneshot::channel::<ShutdownSignalToken>();
373    let listener = tls_settings.bind(&address).await?;
374    let max_connection_lifetime = keepalive.max_connection_lifetime();
375    let stream = listener
376        .accept_stream()
377        .map(move |stream| stream.map(|io| MaxConnectionAgeIo::new(io, max_connection_lifetime)));
378
379    info!(%address, "Building gRPC server.");
380
381    Server::builder()
382        .layer(MaxConnectionAgeLayer::new())
383        .layer(build_grpc_trace_layer(span.clone()))
384        // This layer explicitly decompresses payloads, if compressed, and reports the number of message bytes we've
385        // received if the message is processed successfully, aka `BytesReceived`. We do this because otherwise the only
386        // access we have is either the event-specific bytes (the in-memory representation) or the raw bytes over the
387        // wire prior to decompression... and if that case, any bytes at all, not just the ones we successfully process.
388        //
389        // The weaving of `tonic`, `axum`, `tower`, and `hyper` is fairly complex and there currently exists no way to
390        // use independent `tower` layers when the request body itself (the body type, not the actual bytes) must be
391        // modified or wrapped.. so instead of a cleaner design, we're opting here to bake it all together until the
392        // crates are sufficiently flexible for us to craft a better design.
393        .layer(DecompressionAndMetricsLayer)
394        .add_service(service)
395        .serve_with_incoming_shutdown(stream, shutdown.map(|token| tx.send(token).unwrap()))
396        .await?;
397
398    drop(rx.await);
399
400    Ok(())
401}
402
403// This is a bit of a ugly hack to allow us to run two services on the same port.
404// I just don't know how to convert the generic type with associated types into a Vec<Box<trait object>>.
405pub async fn run_grpc_server_with_routes(
406    address: SocketAddr,
407    tls_settings: MaybeTlsSettings,
408    routes: Routes,
409    keepalive: GrpcKeepaliveConfig,
410    shutdown: ShutdownSignal,
411) -> crate::Result<()> {
412    let span = Span::current();
413    let (tx, rx) = tokio::sync::oneshot::channel::<ShutdownSignalToken>();
414    let listener = tls_settings.bind(&address).await?;
415    let max_connection_lifetime = keepalive.max_connection_lifetime();
416    let stream = listener
417        .accept_stream()
418        .map(move |stream| stream.map(|io| MaxConnectionAgeIo::new(io, max_connection_lifetime)));
419
420    info!(%address, "Building gRPC server.");
421
422    Server::builder()
423        .layer(MaxConnectionAgeLayer::new())
424        .layer(build_grpc_trace_layer(span.clone()))
425        .layer(DecompressionAndMetricsLayer)
426        .add_routes(routes)
427        .serve_with_incoming_shutdown(stream, shutdown.map(|token| tx.send(token).unwrap()))
428        .await?;
429
430    drop(rx.await);
431
432    Ok(())
433}
434
435/// Builds a [TraceLayer] configured for a gRPC server.
436///
437/// This layer emits gPRC specific telemetry for messages received/sent and handler duration.
438pub fn build_grpc_trace_layer(
439    span: Span,
440) -> TraceLayer<
441    SharedClassifier<GrpcErrorsAsFailures>,
442    impl Fn(&Request<Body>) -> Span + Clone,
443    impl Fn(&Request<Body>, &Span) + Clone,
444    impl Fn(&Response<BoxBody>, Duration, &Span) + Clone,
445    (),
446    (),
447    (),
448> {
449    TraceLayer::new_for_grpc()
450        .make_span_with(move |request: &Request<Body>| {
451            // The path is defined as “/” {service name} “/” {method name}.
452            let mut path = request.uri().path().split('/');
453            let service = path.nth(1).unwrap_or("_unknown");
454            let method = path.next().unwrap_or("_unknown");
455
456            // This is an error span so that the labels are always present for metrics.
457            error_span!(
458               parent: &span,
459               "grpc-request",
460               grpc_service = service,
461               grpc_method = method,
462            )
463        })
464        .on_request(Box::new(|_request: &Request<Body>, _span: &Span| {
465            emit!(GrpcServerRequestReceived);
466        }))
467        .on_response(
468            |response: &Response<BoxBody>, latency: Duration, _span: &Span| {
469                emit!(GrpcServerResponseSent { response, latency });
470            },
471        )
472        .on_failure(())
473        .on_body_chunk(())
474        .on_eos(())
475}
476
#[cfg(test)]
mod tests {
    use std::future::{Ready, ready};

    use super::*;

    /// Minimal inner service that is always ready and answers with an empty body.
    #[derive(Clone)]
    struct EmptyBodyService;

    impl Service<Request<Body>> for EmptyBodyService {
        type Response = Response<Body>;
        type Error = Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: Request<Body>) -> Self::Future {
            ready(Ok(Response::new(Body::empty())))
        }
    }

    /// Builds an empty request tagged as belonging to the connection that owns `active_requests`.
    fn request_on_connection(active_requests: &Arc<AtomicUsize>) -> Request<Body> {
        let mut request = Request::new(Body::empty());
        request
            .extensions_mut()
            .insert(MaxConnectionAgeConnectInfo {
                active_requests: Arc::clone(active_requests),
            });
        request
    }

    /// Reads the current in-flight request count.
    fn in_flight(active_requests: &AtomicUsize) -> usize {
        active_requests.load(Ordering::Acquire)
    }

    #[tokio::test]
    async fn max_connection_age_service_tracks_response_body_until_drop() {
        let active_requests = Arc::new(AtomicUsize::new(0));
        let mut service = MaxConnectionAgeService {
            service: EmptyBodyService,
        };
        let request = request_on_connection(&active_requests);

        assert_eq!(in_flight(&active_requests), 0);

        let response = service
            .call(request)
            .await
            .expect("service call should succeed");

        // The request stays counted while the response (and its body) is alive.
        assert_eq!(in_flight(&active_requests), 1);

        drop(response);

        assert_eq!(in_flight(&active_requests), 0);
    }

    #[tokio::test]
    async fn max_connection_age_service_tracks_active_requests_per_connection() {
        let first_connection = Arc::new(AtomicUsize::new(0));
        let second_connection = Arc::new(AtomicUsize::new(0));
        let mut service = MaxConnectionAgeService {
            service: EmptyBodyService,
        };
        let first_request = request_on_connection(&first_connection);
        let second_request = request_on_connection(&second_connection);

        let first_response = service
            .call(first_request)
            .await
            .expect("first service call should succeed");

        assert_eq!(in_flight(&first_connection), 1);
        assert_eq!(in_flight(&second_connection), 0);

        let second_response = service
            .call(second_request)
            .await
            .expect("second service call should succeed");

        assert_eq!(in_flight(&first_connection), 1);
        assert_eq!(in_flight(&second_connection), 1);

        drop(second_response);

        assert_eq!(in_flight(&first_connection), 1);
        assert_eq!(in_flight(&second_connection), 0);

        drop(first_response);

        assert_eq!(in_flight(&first_connection), 0);
        assert_eq!(in_flight(&second_connection), 0);
    }

    #[tokio::test]
    async fn max_connection_age_state_stops_reads_at_deadline_before_writes() {
        let mut state = MaxConnectionAgeState::new(Some(Duration::from_millis(1)));
        let in_flight_requests = state.active_requests();

        // Simulate one request that is still being served when the deadline passes.
        in_flight_requests.fetch_add(1, Ordering::AcqRel);

        sleep(Duration::from_millis(10)).await;

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        assert!(state.is_read_expired_for_test(&mut cx));
        assert!(!state.is_write_expired_for_test(&mut cx));

        // Once the last request finishes, writes are cut off as well.
        in_flight_requests.fetch_sub(1, Ordering::AcqRel);

        assert!(state.is_write_expired_for_test(&mut cx));
    }
}