diff options
| author | Fuwn <[email protected]> | 2026-01-21 08:52:06 +0000 |
|---|---|---|
| committer | Fuwn <[email protected]> | 2026-01-21 08:52:06 +0000 |
| commit | 4a8034812bd24ea1d3b7ac1f1aaf3ea1faf85588 (patch) | |
| tree | 8157168f810d399d7f14dc06b1b15954cc9ddb40 | |
| parent | fix(router): Use warn! macro instead of println! for stream errors (diff) | |
| download | windmark-4a8034812bd24ea1d3b7ac1f1aaf3ea1faf85588.tar.xz windmark-4a8034812bd24ea1d3b7ac1f1aaf3ea1faf85588.zip | |
perf(router): Reduce per-connection overhead with shared RequestHandler
| -rw-r--r-- | src/router.rs | 561 |
1 files changed, 278 insertions, 283 deletions
diff --git a/src/router.rs b/src/router.rs index 495463c..b0e9d84 100644 --- a/src/router.rs +++ b/src/router.rs @@ -114,6 +114,268 @@ pub struct Router { listener_address: String, } +struct RequestHandler { + routes: matchit::Router<Arc<AsyncMutex<Box<dyn RouteResponse>>>>, + error_handler: Arc<AsyncMutex<Box<dyn ErrorResponse>>>, + headers: Arc<Mutex<Vec<Box<dyn Partial>>>>, + footers: Arc<Mutex<Vec<Box<dyn Partial>>>>, + pre_route_callback: Arc<Mutex<Box<dyn PreRouteHook>>>, + post_route_callback: Arc<Mutex<Box<dyn PostRouteHook>>>, + character_set: String, + languages: Vec<String>, + async_modules: Arc<AsyncMutex<Vec<Box<dyn AsyncModule + Send>>>>, + modules: Arc<Mutex<Vec<Box<dyn Module + Send>>>>, + options: HashSet<RouterOption>, +} + +impl RequestHandler { + #[allow( + clippy::too_many_lines, + clippy::significant_drop_in_scrutinee, + clippy::cognitive_complexity + )] + async fn handle(&self, stream: &mut Stream) -> Result<(), Box<dyn Error>> { + let mut buffer = [0u8; 1024]; + let mut url = Url::parse("gemini://fuwn.me/")?; + let mut footer = String::new(); + let mut header = String::new(); + + while let Ok(size) = stream.read(&mut buffer).await { + let request = or_error!( + stream, + std::str::from_utf8(&buffer[0..size]).map(ToString::to_string), + "59 The server (Windmark) received a bad request: {}" + ); + let request_trimmed = request + .find("\r\n") + .map_or(&request[..], |pos| &request[..pos]); + + url = or_error!( + stream, + Url::parse(request_trimmed), + "59 The server (Windmark) received a bad request: {}" + ); + + if request.contains("\r\n") { + break; + } + } + + if url.path().is_empty() { + url.set_path("/"); + } + + let mut path = url.path().to_string(); + + if self + .options + .contains(&RouterOption::AllowCaseInsensitiveLookup) + { + path = path.to_lowercase(); + } + + let mut route = self.routes.at(&path); + + if route.is_err() { + if self + .options + .contains(&RouterOption::RemoveExtraTrailingSlash) + && path.ends_with('/') + && path != "/" + { + let trimmed = path.trim_end_matches('/'); + + if trimmed != path { + path = trimmed.to_string(); + route = self.routes.at(&path); + } + } else if self + .options + .contains(&RouterOption::AddMissingTrailingSlash) + && !path.ends_with('/') + { + let mut path_with_slash = String::with_capacity(path.len() + 1); + + path_with_slash.push_str(&path); + path_with_slash.push('/'); + + if self.routes.at(&path_with_slash).is_ok() { + path = path_with_slash; + route = self.routes.at(&path); + } + } + } + + let peer_certificate = stream.ssl().peer_certificate(); + let url_clone = url.clone(); + let hook_context = HookContext::new( + stream.get_ref().peer_addr(), + url_clone.clone(), + route.as_ref().ok().map(|route| route.params.clone()), + peer_certificate.clone(), + ); + let hook_context_clone = hook_context.clone(); + + for module in &mut *self.async_modules.lock().await { + module.on_pre_route(hook_context_clone.clone()).await; + } + + let hook_context_clone = hook_context.clone(); + + if let Ok(mut modules) = self.modules.lock() { + for module in &mut *modules { + module.on_pre_route(hook_context_clone.clone()); + } + } + + if let Ok(mut callback) = self.pre_route_callback.lock() { + callback.call(hook_context.clone()); + } + + let mut content = if let Ok(ref route) = route { + let route_context = RouteContext::new( + stream.get_ref().peer_addr(), + url_clone, + &route.params, + peer_certificate, + ); + + { + let mut headers = self.headers.lock().expect("headers lock poisoned"); + + for partial_header in &mut *headers { + writeln!( + &mut header, + "{}", + partial_header.call(route_context.clone()), + ) + .expect("failed to write header"); + } + } + + { + let mut footers = self.footers.lock().expect("footers lock poisoned"); + let length = footers.len(); + + for (i, partial_footer) in footers.iter_mut().enumerate() { + let _ = write!( + &mut footer, + "{}{}", + partial_footer.call(route_context.clone()), + if length > 1 && i != length - 1 { + "\n" + } else { + "" + }, + ); + } + } + + let mut lock = (*route.value).lock().await; + let handler = lock.call(route_context); + + handler.await + } else { + (*self.error_handler) + .lock() + .await + .call(ErrorContext::new( + stream.get_ref().peer_addr(), + url_clone, + peer_certificate, + )) + .await + }; + + let hook_context_clone = hook_context.clone(); + + for module in &mut *self.async_modules.lock().await { + module.on_post_route(hook_context_clone.clone()).await; + } + + let hook_context_clone = hook_context.clone(); + + if let Ok(mut modules) = self.modules.lock() { + for module in &mut *modules { + module.on_post_route(hook_context_clone.clone()); + } + } + + if let Ok(mut callback) = self.post_route_callback.lock() { + callback.call(hook_context, &mut content); + } + + let status_code = + if content.status == 21 || content.status == 22 || content.status == 23 { + 20 + } else { + content.status + }; + let status_line = match content.status { + 20 => { + let mime = content.mime.as_deref().unwrap_or("text/gemini"); + let charset = content + .character_set + .as_deref() + .unwrap_or(&self.character_set); + let lang = content + .languages + .as_ref() + .map_or_else(|| self.languages.join(","), |l| l.join(",")); + + format!("{status_code} {mime}; charset={charset}; lang={lang}") + } + 21 => { + format!( + "{} {}", + status_code, + content.mime.as_deref().unwrap_or_default() + ) + } + #[cfg(feature = "auto-deduce-mime")] + 22 => { + format!( + "{} {}", + status_code, + content.mime.as_deref().unwrap_or_default() + ) + } + _ => { + format!("{} {}", status_code, content.content) + } + }; + let body = match content.status { + 20 => { + let mut body = String::with_capacity( + header.len() + content.content.len() + footer.len() + 1, + ); + + body.push_str(&header); + body.push_str(&content.content); + body.push('\n'); + body.push_str(&footer); + + body + } + 21 | 22 => content.content, + _ => String::new(), + }; + let mut response = + String::with_capacity(status_line.len() + body.len() + 2); + + response.push_str(&status_line); + response.push_str("\r\n"); + response.push_str(&body); + stream.write_all(response.as_bytes()).await?; + #[cfg(feature = "tokio")] + stream.shutdown().await?; + #[cfg(feature = "async-std")] + stream.get_mut().shutdown(std::net::Shutdown::Both)?; + + Ok(()) + } +} + impl Router { /// Create a new `Router` /// @@ -343,20 +605,24 @@ impl Router { #[cfg(feature = "logger")] info!("windmark is listening for connections"); + let handler = Arc::new(RequestHandler { + routes: self.routes.clone(), + error_handler: self.error_handler.clone(), + headers: self.headers.clone(), + footers: self.footers.clone(), + pre_route_callback: self.pre_route_callback.clone(), + post_route_callback: self.post_route_callback.clone(), + character_set: self.character_set.clone(), + languages: self.languages.clone(), + async_modules: self.async_modules.clone(), + modules: self.modules.clone(), + options: self.options.clone(), + }); + loop { match listener.accept().await { Ok((stream, _)) => { - let routes = self.routes.clone(); - let error_handler = self.error_handler.clone(); - let headers = self.headers.clone(); - let footers = self.footers.clone(); - let async_modules = self.async_modules.clone(); - let modules = self.modules.clone(); - let pre_route_callback = self.pre_route_callback.clone(); - let post_route_callback = self.post_route_callback.clone(); - let character_set = self.character_set.clone(); - let languages = self.languages.clone(); - let options = self.options.clone(); + let handler = Arc::clone(&handler); let acceptor = self.ssl_acceptor.clone(); #[cfg(feature = "tokio")] let spawner = tokio::spawn; @@ -384,32 +650,7 @@ impl Router { warn!("stream accept error: {e:?}"); } - let router_instance = Self { - routes, - error_handler, - private_key_file_name: String::new(), - private_key_content: None, - certificate_file_name: String::new(), - certificate_content: None, - headers, - footers, - ssl_acceptor: acceptor, - #[cfg(feature = "logger")] - default_logger: false, - #[cfg(feature = "logger")] - log_filter: String::new(), - pre_route_callback, - post_route_callback, - character_set, - languages, - port: 0, - async_modules, - modules, - options, - listener_address: String::new(), - }; - - if let Err(e) = router_instance.handle(&mut stream).await { + if let Err(e) = handler.handle(&mut stream).await { error!("handle error: {e}"); } } @@ -422,252 +663,6 @@ impl Router { } } - #[allow( - clippy::too_many_lines, - clippy::significant_drop_in_scrutinee, - clippy::cognitive_complexity - )] - async fn handle(&self, stream: &mut Stream) -> Result<(), Box<dyn Error>> { - let mut buffer = [0u8; 1024]; - let mut url = Url::parse("gemini://fuwn.me/")?; - let mut footer = String::new(); - let mut header = String::new(); - - while let Ok(size) = stream.read(&mut buffer).await { - let request = or_error!( - stream, - std::str::from_utf8(&buffer[0..size]).map(ToString::to_string), - "59 The server (Windmark) received a bad request: {}" - ); - let request_trimmed = request - .find("\r\n") - .map_or(&request[..], |pos| &request[..pos]); - - url = or_error!( - stream, - Url::parse(request_trimmed), - "59 The server (Windmark) received a bad request: {}" - ); - - if request.contains("\r\n") { - break; - } - } - - if url.path().is_empty() { - url.set_path("/"); - } - - let mut path = url.path().to_string(); - - if self - .options - .contains(&RouterOption::AllowCaseInsensitiveLookup) - { - path = path.to_lowercase(); - } - - let mut route = self.routes.at(&path); - - if route.is_err() { - if self - .options - .contains(&RouterOption::RemoveExtraTrailingSlash) - && path.ends_with('/') - && path != "/" - { - let trimmed = path.trim_end_matches('/'); - - if trimmed != path { - path = trimmed.to_string(); - route = self.routes.at(&path); - } - } else if self - .options - .contains(&RouterOption::AddMissingTrailingSlash) - && !path.ends_with('/') - { - let mut path_with_slash = String::with_capacity(path.len() + 1); - - path_with_slash.push_str(&path); - path_with_slash.push('/'); - - if self.routes.at(&path_with_slash).is_ok() { - path = path_with_slash; - route = self.routes.at(&path); - } - } - } - - let peer_certificate = stream.ssl().peer_certificate(); - let url_clone = url.clone(); - let hook_context = HookContext::new( - stream.get_ref().peer_addr(), - url_clone.clone(), - route.as_ref().ok().map(|route| route.params.clone()), - peer_certificate.clone(), - ); - let hook_context_clone = hook_context.clone(); - - for module in &mut *self.async_modules.lock().await { - module.on_pre_route(hook_context_clone.clone()).await; - } - - let hook_context_clone = hook_context.clone(); - - if let Ok(mut modules) = self.modules.lock() { - for module in &mut *modules { - module.on_pre_route(hook_context_clone.clone()); - } - } - - if let Ok(mut callback) = self.pre_route_callback.lock() { - callback.call(hook_context.clone()); - } - - let mut content = if let Ok(ref route) = route { - let route_context = RouteContext::new( - stream.get_ref().peer_addr(), - url_clone, - &route.params, - peer_certificate, - ); - - { - let mut headers = self.headers.lock().expect("headers lock poisoned"); - - for partial_header in &mut *headers { - writeln!( - &mut header, - "{}", - partial_header.call(route_context.clone()), - ) - .expect("failed to write header"); - } - } - - { - let mut footers = self.footers.lock().expect("footers lock poisoned"); - let length = footers.len(); - - for (i, partial_footer) in footers.iter_mut().enumerate() { - let _ = write!( - &mut footer, - "{}{}", - partial_footer.call(route_context.clone()), - if length > 1 && i != length - 1 { - "\n" - } else { - "" - }, - ); - } - } - - let mut lock = (*route.value).lock().await; - let handler = lock.call(route_context); - - handler.await - } else { - (*self.error_handler) - .lock() - .await - .call(ErrorContext::new( - stream.get_ref().peer_addr(), - url_clone, - peer_certificate, - )) - .await - }; - - let hook_context_clone = hook_context.clone(); - - for module in &mut *self.async_modules.lock().await { - module.on_post_route(hook_context_clone.clone()).await; - } - - let hook_context_clone = hook_context.clone(); - - if let Ok(mut modules) = self.modules.lock() { - for module in &mut *modules { - module.on_post_route(hook_context_clone.clone()); - } - } - - if let Ok(mut callback) = self.post_route_callback.lock() { - callback.call(hook_context, &mut content); - } - - let status_code = - if content.status == 21 || content.status == 22 || content.status == 23 { - 20 - } else { - content.status - }; - let status_line = match content.status { - 20 => { - let mime = content.mime.as_deref().unwrap_or("text/gemini"); - let charset = content - .character_set - .as_deref() - .unwrap_or(&self.character_set); - let lang = content - .languages - .as_ref() - .map_or_else(|| self.languages.join(","), |l| l.join(",")); - - format!("{status_code} {mime}; charset={charset}; lang={lang}") - } - 21 => { - format!( - "{} {}", - status_code, - content.mime.as_deref().unwrap_or_default() - ) - } - #[cfg(feature = "auto-deduce-mime")] - 22 => { - format!( - "{} {}", - status_code, - content.mime.as_deref().unwrap_or_default() - ) - } - _ => { - format!("{} {}", status_code, content.content) - } - }; - let body = match content.status { - 20 => { - let mut body = String::with_capacity( - header.len() + content.content.len() + footer.len() + 1, - ); - - body.push_str(&header); - body.push_str(&content.content); - body.push('\n'); - body.push_str(&footer); - - body - } - 21 | 22 => content.content, - _ => String::new(), - }; - let mut response = - String::with_capacity(status_line.len() + body.len() + 2); - - response.push_str(&status_line); - response.push_str("\r\n"); - response.push_str(&body); - stream.write_all(response.as_bytes()).await?; - #[cfg(feature = "tokio")] - stream.shutdown().await?; - #[cfg(feature = "async-std")] - stream.get_mut().shutdown(std::net::Shutdown::Both)?; - - Ok(()) - } - fn create_acceptor(&mut self) -> Result<(), Box<dyn Error>> { let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; |