use hyper::{Body, StatusCode};
use log::trace;
use crate::helpers::http::PercentDecoded;
use crate::router::non_match::RouteNonMatch;
use crate::router::route::{Delegation, Route};
use crate::router::tree::segment::{SegmentMapping, SegmentType};
use crate::state::{request_id, State};
use std::cmp::Ordering;
use std::collections::HashMap;
/// One node of the routing tree, corresponding to a single path segment.
///
/// Matching walks `children` one request segment at a time; a node holding at
/// least one entry in `routes` is able to terminate a match.
pub struct Node {
    /// Nodes for the next path segment, kept sorted by `add_child`.
    children: Vec<Node>,
    /// Routes served by this node, in the order they were added.
    routes: Vec<Box<dyn Route<ResBody = Body> + Send + Sync>>,
    /// How this node matches a request segment (static, constrained, dynamic or glob).
    segment_type: SegmentType,
    /// The segment text; for non-static segments it is also the key under which
    /// matched request segments are recorded.
    segment: String,
}
impl Node {
pub fn new(segment: &str, segment_type: SegmentType) -> Self {
Node {
segment_type,
segment: segment.to_string(),
routes: vec![],
children: vec![],
}
}
pub fn add_child(&mut self, node: Node) -> &mut Self {
self.children.push(node);
self.children.sort();
self
}
pub fn add_route(&mut self, route: Box<dyn Route<ResBody = Body> + Send + Sync>) -> &mut Self {
self.routes.push(route);
self
}
pub fn borrow_child(&self, segment: &str, segment_type: SegmentType) -> Option<&Node> {
self.children
.iter()
.find(|n| n.segment_type == segment_type && n.segment == segment)
}
pub fn borrow_child_mut(
&mut self,
segment: &str,
segment_type: SegmentType,
) -> Option<&mut Node> {
self.children
.iter_mut()
.find(|n| n.segment_type == segment_type && n.segment == segment)
}
pub fn has_child(&self, segment: &str, segment_type: SegmentType) -> bool {
self.borrow_child(segment, segment_type).is_some()
}
pub fn is_routable(&self) -> bool {
!self.routes.is_empty()
}
pub fn match_node<'a>(
&'a self,
segments: &'a [PercentDecoded],
) -> Option<(&'a Node, SegmentMapping<'a>, usize)> {
let mut params = HashMap::new();
let mut processed = 0;
self.inner_match_node(segments, &mut params, &mut processed)
.map(|node| (node, params, processed))
}
pub fn segment<'a>(&'a self) -> &'a str {
&self.segment
}
pub fn select_route(
&self,
state: &State,
) -> Result<&Box<dyn Route<ResBody = Body> + Send + Sync>, RouteNonMatch> {
let mut err = Ok(());
for r in self.routes.iter() {
match r.is_match(state) {
Ok(()) => {
trace!("[{}] found matching route", request_id(state));
return Ok(r);
}
Err(e) => {
err = match err {
Err(e0) => Err(e.union(e0)),
Ok(()) => Err(e),
}
}
}
}
if let Err(e) = err {
trace!(
"[{}] no matching route, using error status code from route",
request_id(state)
);
return Err(e);
}
trace!(
"[{}] invalid state, no routes. sending internal server error",
request_id(state)
);
Err(RouteNonMatch::new(StatusCode::INTERNAL_SERVER_ERROR))
}
fn inner_match_node<'a>(
&'a self,
segments: &'a [PercentDecoded],
params: &mut SegmentMapping<'a>,
processed: &mut usize,
) -> Option<&'a Node> {
let next_segment = segments.split_first();
if next_segment.is_none() {
if !self.is_routable() {
return None;
}
return Some(self);
}
if let Some(route) = self.routes.first() {
if route.delegation() == Delegation::External {
return Some(self);
}
}
let (segment, remaining) = next_segment.unwrap();
*processed += 1;
for child in &self.children {
match child.segment_type {
SegmentType::Glob => {
params
.entry(&child.segment)
.or_insert_with(Vec::new)
.push(segment);
}
SegmentType::Static => {
if child.segment != segment.as_ref() {
continue;
}
}
SegmentType::Constrained { ref regex } => {
if !regex.is_match(segment.as_ref()) {
continue;
}
params.insert(&child.segment, vec![segment]);
}
SegmentType::Dynamic => {
params.insert(&child.segment, vec![segment]);
}
};
return child.inner_match_node(remaining, params, processed);
}
if let SegmentType::Glob = self.segment_type {
if let Some(path) = params.get_mut(self.segment()) {
path.push(segment);
}
return self.inner_match_node(remaining, params, processed);
}
None
}
}
impl PartialEq for Node {
    /// Nodes are equal when both their segment type and segment text agree;
    /// routes and children are deliberately not compared.
    fn eq(&self, other: &Node) -> bool {
        (&self.segment_type, &self.segment) == (&other.segment_type, &other.segment)
    }
}

// Equality only looks at the segment type and text, which is what the `Ord`
// impl orders by, so the relation is a full equivalence.
impl Eq for Node {}
impl Ord for Node {
    /// Orders nodes by segment type, breaking ties by segment text.
    ///
    /// `add_child` sorts children with this, which fixes the order in which
    /// siblings are tried during matching.
    fn cmp(&self, other: &Node) -> Ordering {
        self.segment_type
            .cmp(&other.segment_type)
            .then_with(|| self.segment.cmp(&other.segment))
    }
}

impl PartialOrd for Node {
    /// Always defined; delegates to the total order from `Ord`.
    fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::panic::RefUnwindSafe;

    use hyper::{HeaderMap, Method, Response};

    use crate::extractor::{NoopPathExtractor, NoopQueryStringExtractor};
    use crate::helpers::http::request::path::RequestPathSegments;
    use crate::helpers::http::PercentDecoded;
    use crate::pipeline::{finalize_pipeline_set, new_pipeline_set, PipelineSet};
    use crate::router::route::dispatch::DispatcherImpl;
    use crate::router::route::matcher::MethodOnlyRouteMatcher;
    use crate::router::route::{Delegation, Extractors, Route, RouteImpl};
    use crate::router::tree::regex::ConstrainedSegmentRegex;
    use crate::state::{set_request_id, State};

    /// Handler behind every test route: returns an empty response.
    fn handler(state: State) -> (State, Response<Body>) {
        (state, Response::new(Body::empty()))
    }

    /// Builds an internally delegated route that accepts only `methods` and
    /// dispatches to `handler` through `pipeline_set`.
    fn route_for<P>(
        methods: Vec<Method>,
        pipeline_set: PipelineSet<P>,
    ) -> Box<dyn Route<ResBody = Body> + Send + Sync>
    where
        P: Send + Sync + RefUnwindSafe + 'static,
    {
        let matcher = MethodOnlyRouteMatcher::new(methods);
        let dispatcher = DispatcherImpl::new(|| Ok(handler), (), pipeline_set);
        let extractors: Extractors<NoopPathExtractor, NoopQueryStringExtractor> = Extractors::new();
        let route = RouteImpl::new(
            matcher,
            Box::new(dispatcher),
            extractors,
            Delegation::Internal,
        );
        Box::new(route)
    }

    /// Shorthand for a GET-only route.
    fn get_route<P>(pipeline_set: PipelineSet<P>) -> Box<dyn Route<ResBody = Body> + Send + Sync>
    where
        P: Send + Sync + RefUnwindSafe + 'static,
    {
        route_for(vec![Method::GET], pipeline_set)
    }

    /// Builds the shared fixture tree:
    ///
    /// - `/seg1` (GET, HEAD)
    /// - `/seg2` (POST route, then PATCH route)
    /// - `/seg3/seg4` (GET)
    /// - `/resource/:id` with `id` constrained to `[0-9]+` (GET)
    /// - `/seg5/seg6` (GET) and `/seg5/:segdyn1/seg7` (GET)
    /// - `/*seg8/seg9/*seg10` (GET on `seg10`)
    fn test_structure() -> Node {
        let pipeline_set = finalize_pipeline_set(new_pipeline_set());
        let mut root = Node::new("/", SegmentType::Static);

        let mut seg1 = Node::new("seg1", SegmentType::Static);
        seg1.add_route(route_for(
            vec![Method::GET, Method::HEAD],
            pipeline_set.clone(),
        ));
        root.add_child(seg1);

        let mut seg2 = Node::new("seg2", SegmentType::Static);
        seg2.add_route(route_for(vec![Method::POST], pipeline_set.clone()));
        seg2.add_route(route_for(vec![Method::PATCH], pipeline_set.clone()));
        root.add_child(seg2);

        let mut seg3 = Node::new("seg3", SegmentType::Static);
        let mut seg4 = Node::new("seg4", SegmentType::Static);
        seg4.add_route(get_route(pipeline_set.clone()));
        seg3.add_child(seg4);
        root.add_child(seg3);

        let mut seg_resource = Node::new("resource", SegmentType::Static);
        let mut seg_id = Node::new(
            "id",
            SegmentType::Constrained {
                regex: Box::new(ConstrainedSegmentRegex::new("[0-9]+")),
            },
        );
        seg_id.add_route(get_route(pipeline_set.clone()));
        seg_resource.add_child(seg_id);
        root.add_child(seg_resource);

        let mut seg5 = Node::new("seg5", SegmentType::Static);
        let mut seg6 = Node::new("seg6", SegmentType::Static);
        seg6.add_route(get_route(pipeline_set.clone()));
        let mut segdyn1 = Node::new(":segdyn1", SegmentType::Dynamic);
        let mut seg7 = Node::new("seg7", SegmentType::Static);
        seg7.add_route(get_route(pipeline_set.clone()));
        segdyn1.add_child(seg7);
        seg5.add_child(seg6);
        seg5.add_child(segdyn1);
        root.add_child(seg5);

        let mut seg8 = Node::new("seg8", SegmentType::Glob);
        let mut seg9 = Node::new("seg9", SegmentType::Static);
        let mut seg10 = Node::new("seg10", SegmentType::Glob);
        seg10.add_route(get_route(pipeline_set));
        seg9.add_child(seg10);
        seg8.add_child(seg9);
        root.add_child(seg8);

        root
    }

    /// Asserts that `path` resolves to the node named `expected_segment` after
    /// consuming `expected_processed` segments.
    fn assert_traverses(root: &Node, path: &str, expected_segment: &str, expected_processed: usize) {
        let rs = RequestPathSegments::new(path);
        match root.match_node(rs.segments()) {
            Some((node, _params, processed)) => {
                assert_eq!(node.segment, expected_segment);
                assert_eq!(processed, expected_processed);
            }
            None => panic!("traversal should have succeeded here"),
        }
    }

    /// Asserts that `path` reaches a node whose routes all reject `state` with
    /// `405 Method Not Allowed`, advertising exactly `expected` (sorted by name).
    fn assert_allow_list(root: &Node, path: &str, state: &State, expected: Vec<Method>) {
        let rs = RequestPathSegments::new(path);
        match root.match_node(rs.segments()) {
            Some((node, _params, _processed)) => match node.select_route(state) {
                Err(e) => {
                    let (status, mut allow_list) = e.deconstruct();
                    assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
                    allow_list.sort_by(|a, b| a.as_ref().cmp(b.as_ref()));
                    assert_eq!(allow_list, expected);
                }
                Ok(_) => panic!("expected mismatched route to test allow header"),
            },
            None => panic!("traversal should have succeeded here"),
        }
    }

    #[test]
    fn manages_children() {
        let root = test_structure();
        assert!(root.borrow_child("seg1", SegmentType::Static).is_some());
        assert!(root.borrow_child("seg2", SegmentType::Static).is_some());
        assert!(root.borrow_child("seg1", SegmentType::Dynamic).is_none());
        assert!(root.borrow_child("seg0", SegmentType::Static).is_none());
    }

    #[test]
    fn traverses_children() {
        let root = test_structure();

        assert_traverses(&root, "/seg3/seg4", "seg4", 2);

        let rs = RequestPathSegments::new("/seg3/seg4/seg5");
        assert!(root.match_node(rs.segments()).is_none());

        assert_traverses(&root, "/seg5/seg6", "seg6", 2);
        assert_traverses(&root, "/seg5/someval/seg7", "seg7", 3);
        assert_traverses(&root, "/some/path/seg9/another/branch", "seg10", 5);
        assert_traverses(&root, "/resource/5001", "id", 2);
    }

    #[test]
    fn non_matching_routes_allow_list_tests() {
        let root = test_structure();
        let mut state = State::new();
        state.put(Method::OPTIONS);
        state.put(HeaderMap::new());
        set_request_id(&mut state);

        assert_allow_list(&root, "/seg2", &state, vec![Method::PATCH, Method::POST]);
        assert_allow_list(&root, "/resource/100", &state, vec![Method::GET]);
    }

    #[test]
    fn node_traversal_tests() {
        let pipeline_set = finalize_pipeline_set(new_pipeline_set());
        let mut root_node_builder = Node::new("/", SegmentType::Static);
        let mut activate_node_builder = Node::new("activate", SegmentType::Static);
        let mut workflow_node = Node::new("workflow", SegmentType::Static);

        workflow_node.add_route(get_route(pipeline_set));
        activate_node_builder.add_child(workflow_node);
        root_node_builder.add_child(activate_node_builder);
        let root_node = root_node_builder;

        match root_node.match_node(&[
            PercentDecoded::new("activate").unwrap(),
            PercentDecoded::new("workflow").unwrap(),
        ]) {
            Some((node, _params, processed)) => {
                assert!(node.is_routable());
                assert_eq!(processed, 2)
            }
            None => panic!(),
        }
    }
}