1#![allow(missing_docs)]
2use std::{
3 collections::HashMap,
4 fmt,
5 net::SocketAddr,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use futures::future::BoxFuture;
11use headers::{Authorization, HeaderMapExt};
12use http::{
13 HeaderMap, Request, Response, Uri, Version, header::HeaderValue, request::Builder,
14 uri::InvalidUri,
15};
16use hyper::{
17 body::{Body, HttpBody},
18 client,
19 client::{Client, HttpConnector},
20};
21use hyper_openssl::HttpsConnector;
22use hyper_proxy::ProxyConnector;
23use rand::Rng;
24use serde_with::serde_as;
25use snafu::{ResultExt, Snafu};
26use tokio::time::Instant;
27use tower::{Layer, Service};
28use tower_http::{
29 classify::{ServerErrorsAsFailures, SharedClassifier},
30 trace::TraceLayer,
31};
32use tracing::{Instrument, Span};
33use vector_lib::{configurable::configurable_component, sensitive_string::SensitiveString};
34
35#[cfg(feature = "aws-core")]
36use crate::aws::AwsAuthentication;
37use crate::{
38 config::ProxyConfig,
39 internal_events::{HttpServerRequestReceived, HttpServerResponseSent, http_client},
40 tls::{MaybeTlsSettings, TlsError, tls_connector_builder},
41};
42
/// Raw numeric HTTP status codes.
pub mod status {
    /// `403 Forbidden`.
    pub const FORBIDDEN: u16 = 403;
    /// `404 Not Found`.
    pub const NOT_FOUND: u16 = 404;
    /// `429 Too Many Requests`.
    pub const TOO_MANY_REQUESTS: u16 = 429;
}
48
/// Errors produced while building or using an [`HttpClient`].
#[derive(Debug, Snafu)]
// Context selectors (e.g. `CallRequestSnafu`) are visible crate-wide.
#[snafu(visibility(pub(crate)))]
pub enum HttpError {
    /// Building the TLS connector from the TLS settings failed.
    #[snafu(display("Failed to build TLS connector: {}", source))]
    BuildTlsConnector { source: TlsError },
    /// Constructing the OpenSSL-backed HTTPS connector failed.
    #[snafu(display("Failed to build HTTPS connector: {}", source))]
    MakeHttpsConnector { source: openssl::error::ErrorStack },
    /// Applying the proxy configuration failed because of an invalid URI.
    #[snafu(display("Failed to build Proxy connector: {}", source))]
    MakeProxyConnector { source: InvalidUri },
    /// Sending the request or receiving its response failed.
    #[snafu(display("Failed to make HTTP(S) request: {}", source))]
    CallRequest { source: hyper::Error },
    /// Building an HTTP request failed.
    #[snafu(display("Failed to build HTTP request: {}", source))]
    BuildRequest { source: http::Error },
}
63
64impl HttpError {
65 pub const fn is_retriable(&self) -> bool {
66 match self {
67 HttpError::BuildRequest { .. } | HttpError::MakeProxyConnector { .. } => false,
68 HttpError::CallRequest { .. }
69 | HttpError::BuildTlsConnector { .. }
70 | HttpError::MakeHttpsConnector { .. } => true,
71 }
72 }
73}
74
/// Future returned by `HttpClient`'s `Service::call` for `Body` requests.
pub type HttpClientFuture = <HttpClient as Service<http::Request<Body>>>::Future;
/// HTTPS connector wrapped in a proxy connector; the transport under [`HttpClient`].
type HttpProxyConnector = ProxyConnector<HttpsConnector<HttpConnector>>;

/// HTTP(S) client that routes requests through the configured proxy, fills
/// in default headers, and emits internal events for each request.
pub struct HttpClient<B = Body> {
    /// Underlying `hyper` client, built on a clone of `proxy_connector`.
    client: Client<HttpProxyConnector, B>,
    /// `User-Agent` value applied to requests that do not already set one.
    user_agent: HeaderValue,
    // Kept alongside `client` so per-request proxy headers can be looked up
    // by `maybe_add_proxy_headers`.
    proxy_connector: HttpProxyConnector,
}
83
84impl<B> HttpClient<B>
85where
86 B: fmt::Debug + HttpBody + Send + 'static,
87 B::Data: Send,
88 B::Error: Into<crate::Error>,
89{
90 pub fn new(
91 tls_settings: impl Into<MaybeTlsSettings>,
92 proxy_config: &ProxyConfig,
93 ) -> Result<HttpClient<B>, HttpError> {
94 HttpClient::new_with_custom_client(tls_settings, proxy_config, &mut Client::builder())
95 }
96
97 pub fn new_with_custom_client(
98 tls_settings: impl Into<MaybeTlsSettings>,
99 proxy_config: &ProxyConfig,
100 client_builder: &mut client::Builder,
101 ) -> Result<HttpClient<B>, HttpError> {
102 let proxy_connector = build_proxy_connector(tls_settings.into(), proxy_config)?;
103 let client = client_builder.build(proxy_connector.clone());
104
105 let app_name = crate::get_app_name();
106 let version = crate::get_version();
107 let user_agent = HeaderValue::from_str(&format!("{app_name}/{version}"))
108 .expect("Invalid header value for user-agent!");
109
110 Ok(HttpClient {
111 client,
112 user_agent,
113 proxy_connector,
114 })
115 }
116
117 pub fn send(
118 &self,
119 mut request: Request<B>,
120 ) -> BoxFuture<'static, Result<http::Response<Body>, HttpError>> {
121 let span = tracing::info_span!("http");
122 let _enter = span.enter();
123
124 default_request_headers(&mut request, &self.user_agent);
125 self.maybe_add_proxy_headers(&mut request);
126
127 emit!(http_client::AboutToSendHttpRequest { request: &request });
128
129 let response = self.client.request(request);
130
131 let fut = async move {
132 let before = std::time::Instant::now();
135
136 let response_result = response.await;
138
139 let roundtrip = before.elapsed();
142
143 let response = response_result
145 .inspect_err(|error| {
146 emit!(http_client::GotHttpWarning { error, roundtrip });
148 })
149 .context(CallRequestSnafu)?;
150
151 emit!(http_client::GotHttpResponse {
153 response: &response,
154 roundtrip
155 });
156 Ok(response)
157 }
158 .instrument(span.clone().or_current());
159
160 Box::pin(fut)
161 }
162
163 fn maybe_add_proxy_headers(&self, request: &mut Request<B>) {
164 if let Some(proxy_headers) = self.proxy_connector.http_headers(request.uri()) {
165 for (k, v) in proxy_headers {
166 let request_headers = request.headers_mut();
167 if !request_headers.contains_key(k) {
168 request_headers.insert(k, v.into());
169 }
170 }
171 }
172 }
173}
174
175pub fn build_proxy_connector(
176 tls_settings: MaybeTlsSettings,
177 proxy_config: &ProxyConfig,
178) -> Result<ProxyConnector<HttpsConnector<HttpConnector>>, HttpError> {
179 let tls = tls_connector_builder(&tls_settings)
181 .context(BuildTlsConnectorSnafu)?
182 .build();
183 let https = build_tls_connector(tls_settings)?;
184 let mut proxy = ProxyConnector::new(https).unwrap();
185 proxy.set_tls(Some(tls));
188 proxy_config
189 .configure(&mut proxy)
190 .context(MakeProxyConnectorSnafu)?;
191 Ok(proxy)
192}
193
194pub fn build_tls_connector(
195 tls_settings: MaybeTlsSettings,
196) -> Result<HttpsConnector<HttpConnector>, HttpError> {
197 let mut http = HttpConnector::new();
198 http.enforce_http(false);
199
200 let tls = tls_connector_builder(&tls_settings).context(BuildTlsConnectorSnafu)?;
201 let mut https = HttpsConnector::with_connector(http, tls).context(MakeHttpsConnectorSnafu)?;
202
203 let settings = tls_settings.tls().cloned();
204 https.set_callback(move |c, _uri| {
205 if let Some(settings) = &settings {
206 settings.apply_connect_configuration(c)
207 } else {
208 Ok(())
209 }
210 });
211 Ok(https)
212}
213
214fn default_request_headers<B>(request: &mut Request<B>, user_agent: &HeaderValue) {
215 if !request.headers().contains_key("User-Agent") {
216 request
217 .headers_mut()
218 .insert("User-Agent", user_agent.clone());
219 }
220
221 if !request.headers().contains_key("Accept-Encoding") {
222 request
225 .headers_mut()
226 .insert("Accept-Encoding", HeaderValue::from_static("identity"));
227 }
228}
229
230impl<B> Service<Request<B>> for HttpClient<B>
231where
232 B: fmt::Debug + HttpBody + Send + 'static,
233 B::Data: Send,
234 B::Error: Into<crate::Error> + Send,
235{
236 type Response = http::Response<Body>;
237 type Error = HttpError;
238 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
239
240 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
241 Poll::Ready(Ok(()))
242 }
243
244 fn call(&mut self, request: Request<B>) -> Self::Future {
245 self.send(request)
246 }
247}
248
249impl<B> Clone for HttpClient<B> {
250 fn clone(&self) -> Self {
251 Self {
252 client: self.client.clone(),
253 user_agent: self.user_agent.clone(),
254 proxy_connector: self.proxy_connector.clone(),
255 }
256 }
257}
258
259impl<B> fmt::Debug for HttpClient<B> {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 f.debug_struct("HttpClient")
262 .field("client", &self.client)
263 .field("user_agent", &self.user_agent)
264 .finish()
265 }
266}
267
/// Authentication strategy for outgoing HTTP requests.
///
/// For the `basic`, `bearer`, and `custom` strategies, the credentials are
/// sent in the `Authorization` request header.
// NOTE(review): the `aws` strategy adds no header in `Auth::apply_headers_map`;
// it is presumably handled by request signing elsewhere — confirm.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(deny_unknown_fields, rename_all = "snake_case", tag = "strategy")]
#[configurable(metadata(docs::enum_tag_description = "The authentication strategy to use."))]
pub enum Auth {
    /// Basic authentication.
    ///
    /// The username and password are sent using the HTTP `Basic` scheme.
    Basic {
        /// The basic authentication username.
        #[configurable(metadata(docs::examples = "${USERNAME}"))]
        #[configurable(metadata(docs::examples = "username"))]
        user: String,

        /// The basic authentication password.
        #[configurable(metadata(docs::examples = "${PASSWORD}"))]
        #[configurable(metadata(docs::examples = "password"))]
        password: SensitiveString,
    },

    /// Bearer authentication.
    ///
    /// The token is sent using the HTTP `Bearer` scheme.
    Bearer {
        /// The bearer authentication token.
        token: SensitiveString,
    },

    /// AWS authentication.
    #[cfg(feature = "aws-core")]
    Aws {
        /// The AWS authentication configuration.
        auth: AwsAuthentication,

        /// The AWS service name.
        service: String,
    },

    /// Custom authentication.
    ///
    /// The value is sent verbatim as the `Authorization` header value.
    Custom {
        /// The full value of the `Authorization` header.
        #[configurable(metadata(docs::examples = "${AUTH_HEADER_VALUE}"))]
        #[configurable(metadata(docs::examples = "CUSTOM_PREFIX ${TOKEN}"))]
        value: String,
    },
}
320
/// Combines two optional authentication configurations into one.
pub trait MaybeAuth: Sized {
    /// Returns whichever of `self` and `other` is set.
    ///
    /// # Errors
    ///
    /// The `Option<Auth>` implementation returns an error when both are set.
    fn choose_one(&self, other: &Self) -> crate::Result<Self>;
}
324
325impl MaybeAuth for Option<Auth> {
326 fn choose_one(&self, other: &Self) -> crate::Result<Self> {
327 if self.is_some() && other.is_some() {
328 Err("Two authorization credentials was provided.".into())
329 } else {
330 Ok(self.clone().or_else(|| other.clone()))
331 }
332 }
333}
334
335impl Auth {
336 pub fn apply<B>(&self, req: &mut Request<B>) {
337 self.apply_headers_map(req.headers_mut())
338 }
339
340 pub fn apply_builder(&self, mut builder: Builder) -> Builder {
341 if let Some(map) = builder.headers_mut() {
342 self.apply_headers_map(map)
343 }
344 builder
345 }
346
347 pub fn apply_headers_map(&self, map: &mut HeaderMap) {
348 match &self {
349 Auth::Basic { user, password } => {
350 let auth = Authorization::basic(user.as_str(), password.inner());
351 map.typed_insert(auth);
352 }
353 Auth::Bearer { token } => match Authorization::bearer(token.inner()) {
354 Ok(auth) => map.typed_insert(auth),
355 Err(error) => error!(message = "Invalid bearer token.", token = %token, %error),
356 },
357 Auth::Custom { value } => {
358 match HeaderValue::from_str(value) {
361 Ok(header_val) => {
362 map.insert(http::header::AUTHORIZATION, header_val);
363 }
364 Err(error) => {
365 error!(message = "Invalid custom auth header value.", value = %value, %error)
366 }
367 }
368 }
369 #[cfg(feature = "aws-core")]
370 _ => {}
371 }
372 }
373}
374
375pub fn get_http_scheme_from_uri(uri: &Uri) -> &'static str {
376 uri.scheme_str().map_or("http", |scheme| match scheme {
379 "http" => "http",
380 "https" => "https",
381 s => panic!("invalid URI scheme for HTTP client: {s}"),
386 })
387}
388
389pub fn build_http_trace_layer<T, U>(
393 span: Span,
394) -> TraceLayer<
395 SharedClassifier<ServerErrorsAsFailures>,
396 impl Fn(&Request<T>) -> Span + Clone,
397 impl Fn(&Request<T>, &Span) + Clone,
398 impl Fn(&Response<U>, Duration, &Span) + Clone,
399 (),
400 (),
401 (),
402> {
403 TraceLayer::new_for_http()
404 .make_span_with(move |request: &Request<T>| {
405 error_span!(
407 parent: &span,
408 "http-request",
409 method = %request.method(),
410 path = %request.uri().path(),
411 )
412 })
413 .on_request(Box::new(|_request: &Request<T>, _span: &Span| {
414 emit!(HttpServerRequestReceived);
415 }))
416 .on_response(|response: &Response<U>, latency: Duration, _span: &Span| {
417 emit!(HttpServerResponseSent { response, latency });
418 })
419 .on_failure(())
420 .on_body_chunk(())
421 .on_eos(())
422}
423
/// HTTP server keepalive settings.
// NOTE(review): these values are presumably passed to `MaxConnectionAgeLayer::new`
// by the HTTP server components — confirm at the call sites.
#[serde_as]
#[configurable_component]
#[derive(Clone, Debug, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct KeepaliveConfig {
    /// The maximum age of a connection, in seconds.
    ///
    /// Defaults to 300 seconds.
    // NOTE(review): what `None` means is decided by the consumer of this
    // config — confirm.
    #[serde(default = "default_max_connection_age")]
    #[configurable(metadata(docs::examples = 600))]
    #[configurable(metadata(docs::type_unit = "seconds"))]
    #[configurable(metadata(docs::human_name = "Maximum Connection Age"))]
    pub max_connection_age_secs: Option<u64>,

    /// The factor by which to randomly jitter the maximum connection age.
    ///
    /// Must be between 0.0 and 1.0. Defaults to 0.1.
    #[serde(default = "default_max_connection_age_jitter_factor")]
    #[configurable(validation(range(min = 0.0, max = 1.0)))]
    pub max_connection_age_jitter_factor: f64,
}
453
/// Serde default for `KeepaliveConfig::max_connection_age_secs`: five minutes.
const fn default_max_connection_age() -> Option<u64> {
    const FIVE_MINUTES_IN_SECS: u64 = 300;
    Some(FIVE_MINUTES_IN_SECS)
}

/// Serde default for `KeepaliveConfig::max_connection_age_jitter_factor`: 10%.
const fn default_max_connection_age_jitter_factor() -> f64 {
    const TEN_PERCENT: f64 = 0.1;
    TEN_PERCENT
}
461
462impl Default for KeepaliveConfig {
463 fn default() -> Self {
464 Self {
465 max_connection_age_secs: default_max_connection_age(),
466 max_connection_age_jitter_factor: default_max_connection_age_jitter_factor(),
467 }
468 }
469}
470
/// Tower [`Layer`] that wraps services in [`MaxConnectionAgeService`].
///
/// The start time is captured when the layer is created and is shared by
/// every service it produces. The tests build one layer per accepted
/// connection.
pub struct MaxConnectionAgeLayer {
    /// When the layer was created; connection age is measured from here.
    start_reference: Instant,
    /// Maximum age after jitter has been applied.
    max_connection_age: Duration,
    /// Remote address of the connection, used for logging.
    peer_addr: SocketAddr,
}
485
486impl MaxConnectionAgeLayer {
487 pub fn new(max_connection_age: Duration, jitter_factor: f64, peer_addr: SocketAddr) -> Self {
488 Self {
489 start_reference: Instant::now(),
490 max_connection_age: Self::jittered_duration(max_connection_age, jitter_factor),
491 peer_addr,
492 }
493 }
494
495 fn jittered_duration(duration: Duration, jitter_factor: f64) -> Duration {
496 let jitter_factor = jitter_factor.clamp(0.0, 1.0);
498 let mut rng = rand::rng();
500 let random_jitter_factor = rng.random_range(-jitter_factor..=jitter_factor) + 1.;
501 duration.mul_f64(random_jitter_factor)
502 }
503}
504
505impl<S> Layer<S> for MaxConnectionAgeLayer
506where
507 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
508 S::Future: Send + 'static,
509{
510 type Service = MaxConnectionAgeService<S>;
511
512 fn layer(&self, service: S) -> Self::Service {
513 MaxConnectionAgeService {
514 service,
515 start_reference: self.start_reference,
516 max_connection_age: self.max_connection_age,
517 peer_addr: self.peer_addr,
518 }
519 }
520}
521
/// Service produced by [`MaxConnectionAgeLayer`].
///
/// Forwards requests to `service`. On HTTP/0.9, HTTP/1.0 and HTTP/1.1
/// requests, it adds `Connection: close` to the response once the time since
/// `start_reference` reaches `max_connection_age`.
#[derive(Clone)]
pub struct MaxConnectionAgeService<S> {
    /// The wrapped service.
    service: S,
    /// Instant from which the connection's age is measured.
    start_reference: Instant,
    /// Age at or beyond which responses ask the client to close the connection.
    max_connection_age: Duration,
    /// Remote address, used only in the debug log.
    peer_addr: SocketAddr,
}
538
539impl<S, E> Service<Request<Body>> for MaxConnectionAgeService<S>
540where
541 S: Service<Request<Body>, Response = Response<Body>, Error = E> + Clone + Send + 'static,
542 S::Future: Send + 'static,
543{
544 type Response = S::Response;
545 type Error = E;
546 type Future = BoxFuture<'static, Result<Self::Response, E>>;
547
548 fn poll_ready(
549 &mut self,
550 cx: &mut std::task::Context<'_>,
551 ) -> std::task::Poll<Result<(), Self::Error>> {
552 self.service.poll_ready(cx)
553 }
554
555 fn call(&mut self, req: Request<Body>) -> Self::Future {
556 let start_reference = self.start_reference;
557 let max_connection_age = self.max_connection_age;
558 let peer_addr = self.peer_addr;
559 let version = req.version();
560 let future = self.service.call(req);
561 Box::pin(async move {
562 let mut response = future.await?;
563 match version {
564 Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11
565 if start_reference.elapsed() >= max_connection_age =>
566 {
567 debug!(
568 message = "Closing connection due to max connection age.",
569 ?max_connection_age,
570 connection_age = ?start_reference.elapsed(),
571 ?peer_addr,
572 );
573 response.headers_mut().insert(
576 hyper::header::CONNECTION,
577 hyper::header::HeaderValue::from_static("close"),
578 );
579 }
580 Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => (),
581 Version::HTTP_2 => (),
583 Version::HTTP_3 => (),
585 _ => (),
586 }
587 Ok(response)
588 })
589 }
590}
591
/// The type of a parameter value.
#[configurable_component]
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ParamType {
    /// A plain string. This is the default.
    #[default]
    String,
    /// A VRL value.
    Vrl,
}
603
604impl ParamType {
605 fn is_default(&self) -> bool {
606 *self == Self::default()
607 }
608}
609
/// A parameter value, given either as a bare string or as a table with a
/// `value` and an optional `type`.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum ParameterValue {
    /// A bare string value; never treated as VRL.
    String(String),
    /// A value with an explicit type.
    Typed {
        /// The parameter value.
        value: String,
        /// The type of `value`. Defaults to `string` and is omitted when
        /// serializing if it is the default.
        #[serde(
            default,
            skip_serializing_if = "ParamType::is_default",
            rename = "type"
        )]
        r#type: ParamType,
    },
}
631
632impl ParameterValue {
633 pub const fn is_vrl(&self) -> bool {
635 match self {
636 ParameterValue::String(_) => false,
637 ParameterValue::Typed { r#type, .. } => matches!(r#type, ParamType::Vrl),
638 }
639 }
640
641 #[allow(clippy::missing_const_for_fn)]
643 pub fn value(&self) -> &str {
644 match self {
645 ParameterValue::String(s) => s,
646 ParameterValue::Typed { value, .. } => value,
647 }
648 }
649
650 pub fn into_value(self) -> String {
652 match self {
653 ParameterValue::String(s) => s,
654 ParameterValue::Typed { value, .. } => value,
655 }
656 }
657}
658
/// One or more values for a single query parameter.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
#[configurable(metadata(docs::enum_tag_description = "Query parameter value"))]
pub enum QueryParameterValue {
    /// A single value.
    SingleParam(ParameterValue),
    /// A list of values.
    MultiParams(Vec<ParameterValue>),
}
670
671impl QueryParameterValue {
672 pub fn iter(&self) -> impl Iterator<Item = &ParameterValue> {
674 match self {
675 QueryParameterValue::SingleParam(param) => std::slice::from_ref(param).iter(),
676 QueryParameterValue::MultiParams(params) => params.iter(),
677 }
678 }
679
680 fn into_vec(self) -> Vec<ParameterValue> {
682 match self {
683 QueryParameterValue::SingleParam(param) => vec![param],
684 QueryParameterValue::MultiParams(params) => params,
685 }
686 }
687}
688
689impl IntoIterator for QueryParameterValue {
691 type Item = ParameterValue;
692 type IntoIter = std::vec::IntoIter<ParameterValue>;
693
694 fn into_iter(self) -> Self::IntoIter {
695 self.into_vec().into_iter()
696 }
697}
698
/// Query parameters, presumably keyed by parameter name.
pub type QueryParameters = HashMap<String, QueryParameterValue>;
700
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use hyper::{Server, server::conn::AddrStream, service::make_service_fn};
    use proptest::prelude::*;
    use tower::ServiceBuilder;

    use super::*;
    use crate::test_util::addr::next_addr;

    // Both default headers are added when the request sets neither.
    #[test]
    fn test_default_request_headers_defaults() {
        let user_agent = HeaderValue::from_static("vector");
        let mut request = Request::post("http://example.com").body(()).unwrap();
        default_request_headers(&mut request, &user_agent);
        assert_eq!(
            request.headers().get("Accept-Encoding"),
            Some(&HeaderValue::from_static("identity")),
        );
        assert_eq!(request.headers().get("User-Agent"), Some(&user_agent));
    }

    // Headers already present on the request are left as they are.
    #[test]
    fn test_default_request_headers_does_not_overwrite() {
        let mut request = Request::post("http://example.com")
            .header("Accept-Encoding", "gzip")
            .header("User-Agent", "foo")
            .body(())
            .unwrap();
        default_request_headers(&mut request, &HeaderValue::from_static("vector"));
        assert_eq!(
            request.headers().get("Accept-Encoding"),
            Some(&HeaderValue::from_static("gzip")),
        );
        assert_eq!(
            request.headers().get("User-Agent"),
            Some(&HeaderValue::from_static("foo"))
        );
    }

    proptest! {
        // Property:
        // - with no jitter, the duration is unchanged;
        // - otherwise it lies within `[1 - jitter_factor, 1 + jitter_factor]`
        //   times the input;
        // - a zero input stays zero.
        #[test]
        fn test_jittered_duration(duration_in_secs in 0u64..120, jitter_factor in 0.0..1.0) {
            let duration = Duration::from_secs(duration_in_secs);
            let jittered_duration = MaxConnectionAgeLayer::jittered_duration(duration, jitter_factor);

            if jitter_factor == 0.0 {
                prop_assert_eq!(
                    jittered_duration,
                    duration,
                    "jittered_duration {:?} should be equal to duration {:?}",
                    jittered_duration,
                    duration,
                );
            } else if duration_in_secs > 0 {
                let lower_bound = duration.mul_f64(1.0 - jitter_factor);
                let upper_bound = duration.mul_f64(1.0 + jitter_factor);
                prop_assert!(
                    jittered_duration >= lower_bound && jittered_duration <= upper_bound,
                    "jittered_duration {:?} should be between {:?} and {:?}",
                    jittered_duration,
                    lower_bound,
                    upper_bound,
                );
            } else {
                prop_assert_eq!(
                    jittered_duration,
                    Duration::from_secs(0),
                    "jittered_duration {:?} should be equal to zero",
                    jittered_duration,
                );
            }
        }
    }

    // With tokio's clock paused:
    // - no `Connection` header is added before the 1s maximum age;
    // - `close` is added once the age reaches exactly 1s (the check uses `>=`).
    #[tokio::test]
    async fn test_max_connection_age_service() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(1);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        // Age 0s.
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // Age 0.5s.
        tokio::time::advance(Duration::from_millis(500)).await;
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // Age 1s: the maximum age has been reached.
        tokio::time::advance(Duration::from_millis(500)).await;
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close"))
        );
    }

    // HTTP/2 requests never get a `Connection` header, even with a zero
    // maximum age.
    #[tokio::test]
    async fn test_max_connection_age_service_http2() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let mut req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        *req.version_mut() = Version::HTTP_2;
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }

    // HTTP/3 requests never get a `Connection` header, even with a zero
    // maximum age.
    #[tokio::test]
    async fn test_max_connection_age_service_http3() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let mut req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        *req.version_mut() = Version::HTTP_3;
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }

    // With a zero maximum age, the very first HTTP/1.1 response is marked
    // `Connection: close`.
    #[tokio::test]
    async fn test_max_connection_age_service_zero_duration() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_millis(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close"))
        );
    }

    // End-to-end with a real hyper server and `HttpClient`, using real time
    // (about 1s of sleeping). A layer with no jitter is built per accepted
    // connection.
    #[tokio::test]
    async fn test_max_connection_age_service_with_hyper_server() {
        let max_connection_age = Duration::from_secs(1);
        // NOTE(review): `_guard` presumably keeps `addr` reserved for this test.
        let (_guard, addr) = next_addr();
        let make_svc = make_service_fn(move |conn: &AddrStream| {
            let svc = ServiceBuilder::new()
                .layer(MaxConnectionAgeLayer::new(
                    max_connection_age,
                    0.,
                    conn.remote_addr(),
                ))
                .service(tower::service_fn(|_req: Request<Body>| async {
                    Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
                }));
            futures_util::future::ok::<_, Infallible>(svc)
        });

        tokio::spawn(async move {
            Server::bind(&addr).serve(make_svc).await.unwrap();
        });

        // Give the spawned server a moment to start listening.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();

        // Young connection: no `Connection` header.
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // After the maximum age, the server asks the client to close.
        tokio::time::sleep(Duration::from_secs(1)).await;
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close")),
        );

        // Presumably served on a fresh connection whose age starts from zero.
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }
}