diff --git a/main.cpp b/main.cpp index c16ab2b..4b50114 100644 --- a/main.cpp +++ b/main.cpp @@ -1,6 +1,5 @@ #include "http.hpp" #include "router.hpp" -#include "tree.hpp" using namespace http; @@ -10,70 +9,35 @@ void HelloWorld(Request req, Response *res) { } int main() { - Tree t; + Router router(8181); - t.AddPath("/test/dummy", [](Request req, Response *res) { + // Allow all Methods + router.Handle("GET /helloWorld", HelloWorld); + router.Handle("GET /healthz", [](Request req, Response *res) { + res->SetStatusCode(statuscode::OK); + res->SetPayload(std::vector()); + res->SetContentType("text/plain"); + }); + + // Only allow GET + router.Handle("GET /echo/{name}", [](Request req, Response *res) { + std::string name = req.path.Get("name").value_or("No Name given"); + res->SetPayload("Hello " + name); + res->SetContentType("text/plain"); + }); + + // Only allow POST + router.Handle("POST /echo/{name}", [](Request req, Response *res) { + std::string name = req.path.Get("name").value_or("No Name given"); + res->SetPayload("Hello with Post" + name); + res->SetContentType("text/plain"); + }); + + router.Handle("GET /", [](Request req, Response *res) { res->SetPayload("Main"); res->SetContentType("text/plain"); }); - t.AddPath("/test/dummy/main", [](Request req, Response *res) { - res->SetPayload("Main"); - res->SetContentType("text/plain"); - }); - - t.AddPath("/test/dummy/main2", [](Request req, Response *res) { - res->SetPayload("Main"); - res->SetContentType("text/plain"); - }); - - t.AddPath("/test/dummy2/main", [](Request req, Response *res) { - res->SetPayload("Main"); - res->SetContentType("text/plain"); - }); - - t.AddPath("/var/main", [](Request req, Response *res) { - res->SetPayload("Main"); - res->SetContentType("text/plain"); - }); - - t.AddPath("/test/dummy2", [](Request req, Response *res) { - res->SetPayload("Main"); - res->SetContentType("text/plain"); - }); - - t.DebugPrint(); - - // Router router(8181); - // - // // Allow all Methods - // router.Handle("/helloWorld", HelloWorld); - // - // router.Handle("/healthz", [](Request req, Response *res) { - // res->SetStatusCode(statuscode::OK); - // res->SetPayload(std::vector()); - // res->SetContentType("text/plain"); - // }); - // - // // Only allow GET - // router.Handle("GET /echo/{name}", [](Request req, Response *res) { - // std::string name = req.path.Get("name").value_or("No Name given"); - // res->SetPayload("Hello " + name); - // res->SetContentType("text/plain"); - // }); - // - // // Only allow POST - // router.Handle("POST /echo/{name}", [](Request req, Response *res) { - // std::string name = req.path.Get("name").value_or("No Name given"); - // res->SetPayload("Hello with Post" + name); - // res->SetContentType("text/plain"); - // }); - // - // router.Handle("/", [](Request req, Response *res) { - // res->SetPayload("Main"); - // res->SetContentType("text/plain"); - // }); - // - // router.Start(); - // return 0; + router.Start(); + return 0; } diff --git a/router.cpp b/router.cpp index d3b8661..044023d 100644 --- a/router.cpp +++ b/router.cpp @@ -1,8 +1,10 @@ #include "router.hpp" #include "http.hpp" +#include "util.hpp" #include #include #include +#include #include #include #include @@ -93,52 +95,25 @@ void Router::ThreadLoop() { void Router::Handle(std::string pathPattern, std::function func) { - m_routes.insert_or_assign(pathPattern, func); + auto route = split(pathPattern, " "); + // TODO: UNSAFE CHECK BOUNDS + auto tree = m_routes[route[0]]; + if (!tree) { + tree = std::make_shared(Tree(route[0])); + m_routes.insert_or_assign(route[0], tree); + } + tree->AddPath(route[1], func); } -// This should be better -// Probably dont use map but a tree for it, then traverse tree for routing -// Also this isnt accurate Response Router::Route(Request req) { - for (const auto &[key, value] : m_routes) { - std::string pattern = key; + auto tree = m_routes[req.Method()]; + auto route = tree->Get(req.path.Base()); - int mPos = pattern.find(' '); - std::string method = pattern.substr(0, mPos); - - if (mPos != -1 && method != req.Method()) { - continue; - } - - pattern.erase(0, mPos + 1); - std::string patternCopy = pattern; - std::string path = req.path.Base(); - bool found = false; - int pos = 0; - while (pos != -1) { - found = true; - pos = pattern.find('/'); - std::string p = pattern.substr(0, pos); - - int uPos = path.find('/'); - std::string u = path.substr(0, uPos); - - if (!p.starts_with('{') && strcasecmp(p.data(), u.data()) != 0) { - found = false; - break; - } - - pattern.erase(0, pos + 1); - path.erase(0, uPos + 1); - } - - if (found) { - Response res(statuscode::OK); - req.path.Match(patternCopy); - value(req, &res); - return res; - } + if (!route.has_value()) { + return Response(statuscode::NOT_FOUND); } - return Response(statuscode::NOT_FOUND); + Response res(statuscode::OK); + route.value()(req, &res); + return res; } diff --git a/router.hpp b/router.hpp index bc31ed5..af5a148 100644 --- a/router.hpp +++ b/router.hpp @@ -3,9 +3,10 @@ #include "request.hpp" #include "response.hpp" +#include "tree.hpp" #include #include -#include +#include #include #include #include @@ -14,7 +15,7 @@ namespace http { class Router { private: - std::map> m_routes; + std::map> m_routes; int m_socket; sockaddr_in m_address; Response Route(Request req); diff --git a/tree.cpp b/tree.cpp index 1fcb3d2..6ddf8f4 100644 --- a/tree.cpp +++ b/tree.cpp @@ -1,31 +1,13 @@ #include "tree.hpp" +#include "util.hpp" #include #include #include #include +#include using namespace http; -std::vector split(std::string s, std::string delimiter) { - size_t pos_start = 0, pos_end, delim_len = delimiter.length(); - std::string token; - std::vector res; - - while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) { - token = s.substr(pos_start, pos_end - pos_start); - pos_start = pos_end + delim_len; - if (token != "") { - res.push_back(token); - } - } - - token = s.substr(pos_start); - if (token != "") { - res.push_back(token); - } - return res; -} - Node::Node(std::string sub) { m_subPath = sub; m_isDummy = true; @@ -41,16 +23,7 @@ Node::Node(std::string sub, bool isValue, m_isDummy = false; } -void create_dummy_nodes(std::shared_ptr **root, - std::vector rest) { - auto curr = *root; - for (auto next : rest) { - auto dummy = std::make_shared(Node{next}); - (*curr)->m_next.insert_or_assign(next, dummy); - curr = &dummy; - } - root = &curr; -} +Tree::Tree(std::string method) { m_method = method; } void addNode(std::shared_ptr const &parent, std::string path, std::vector rest, @@ -113,4 +86,41 @@ void printNode(std::shared_ptr node, size_t depth, size_t max_depth) { } } +std::optional> +traverse(std::shared_ptr const &parent, std::string path, + std::vector rest) { + + std::shared_ptr curr = parent->m_next[path]; + if (rest.size() == 0) { + if (curr != nullptr && !curr->m_isDummy) + return curr->m_function; + else + return std::nullopt; + } + + if (curr) { + auto newPath = rest.front(); + // Ineffiecient, use deque + rest.erase(rest.begin()); + return traverse(curr, newPath, rest); + } + + return std::nullopt; +} + +std::optional> +Tree::Get(std::string path) { + auto subs = split(path, "/"); + if (subs.size() == 0) { + if (!m_root->m_isDummy) + return m_root->m_function; + else + return std::nullopt; + } + + auto newPath = subs.front(); + subs.erase(subs.begin()); + return traverse(m_root, newPath, subs); +} + void Tree::DebugPrint() { printNode(m_root, 0, 10); } diff --git a/tree.hpp b/tree.hpp index 0759878..5ee5011 100644 --- a/tree.hpp +++ b/tree.hpp @@ -27,10 +27,12 @@ class Tree { private: std::shared_ptr m_root; std::string m_method; - size_t depth; + size_t m_depth; public: + Tree(std::string method); void AddPath(std::string, std::function); + std::optional> Get(std::string); void DebugPrint(); }; } // namespace http diff --git a/util.hpp b/util.hpp new file mode 100644 index 0000000..65fab92 --- /dev/null +++ b/util.hpp @@ -0,0 +1,27 @@ +#ifndef UTIL_H +#define UTIL_H + +#include +#include + +inline std::vector split(std::string s, std::string delimiter) { + size_t pos_start = 0, pos_end, delim_len = delimiter.length(); + std::string token; + std::vector res; + + while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) { + token = s.substr(pos_start, pos_end - pos_start); + pos_start = pos_end + delim_len; + if (token != "") { + res.push_back(token); + } + } + + token = s.substr(pos_start); + if (token != "") { + res.push_back(token); + } + return res; +} + +#endif