diff --git a/httpd.cc b/httpd.cc index 6660f5b7c0..ccd78b1547 100644 --- a/httpd.cc +++ b/httpd.cc @@ -7,13 +7,18 @@ #include #include #include +#include sstring to_sstring(const std::csub_match& sm) { return sstring(sm.first, sm.second); } +static std::string tchar = "[-!#$%&'\\*\\+.^_`|~0-9A-Za-z]"; +static std::string token = tchar + "+"; static constexpr auto re_opt = std::regex::ECMAScript | std::regex::optimize; static std::regex start_line_re { "([A-Z]+) (\\S+) HTTP/([0-9]\\.[0-9])\\r\\n", re_opt }; +static std::regex header_re { "(" + token + ")\\s*:\\s*([.*\\S])\\s*\\r\\n", re_opt }; +static std::regex header_cont_re { "\\s+(.*\\S)\\s*\\r\\n", re_opt }; class http_server { std::vector> _listeners; @@ -40,6 +45,12 @@ public: input_stream_buffer _read_buf; static constexpr size_t limit = 4096; using tmp_buf = temporary_buffer; + sstring _method; + sstring _url; + sstring _version; + sstring _response; + sstring _last_header_name; + std::unordered_map _headers; public: connection(http_server& server, std::unique_ptr&& fd, socket_address addr) : _server(server), _fd(std::move(fd)), _addr(addr), _read_buf(*_fd, 8192) {} @@ -49,13 +60,46 @@ public: std::cmatch match; if (!std::regex_match(start_line.begin(), start_line.end(), match, start_line_re)) { std::cout << "no match\n"; - delete this; - return; + return bad(); } - sstring method = to_sstring(match[1]); - sstring url = to_sstring(match[2]); - sstring version = to_sstring(match[3]); - std::cout << "start line: " << method << " | " << url << " | " << version << "\n"; + _method = to_sstring(match[1]); + _url = to_sstring(match[2]); + _version = to_sstring(match[3]); + if (_method != "GET") { + return bad(); + } + std::cout << "start line: " << _method << " | " << _url << " | " << _version << "\n"; + _read_buf.read_until(limit, '\n').then([this] (future header) { + parse_header(std::move(header)); + }); + }); + } + void parse_header(future f_header) { + auto header = f_header.get(); + if (header.size() == 2 && header[0] == '\r' && header[1] == '\n') { + return; + } + std::cmatch match; + if (std::regex_match(header.begin(), header.end(), match, header_re)) { + sstring name = to_sstring(match[1]); + sstring value = to_sstring(match[2]); + std::cout << "found header: " << name << "=" << value << ".\n"; + _headers[name] = std::move(value); + _last_header_name = std::move(name); + } else if (std::regex_match(header.begin(), header.end(), match, header_cont_re)) { + _headers[_last_header_name] += " "; + _headers[_last_header_name] += to_sstring(match[1]); + } else { + return bad(); + } + _read_buf.read_until(limit, '\n').then([this] (future header) { + parse_header(std::move(header)); + }); + } + void bad() { + _response = "400 BAD REQUEST\r\n\r\n"; + _fd->write_all(_response.begin(), _response.size()).then([this] (future n) mutable { + delete this; }); } }; diff --git a/reactor.hh b/reactor.hh index 10bb811c19..8da3404221 100644 --- a/reactor.hh +++ b/reactor.hh @@ -199,10 +199,10 @@ public: } template - void then(Func, Enable); + void then(Func, Enable) &&; template - void then(Func&& func, std::enable_if_t, void>::value, void*> = nullptr) { + void then(Func&& func, std::enable_if_t, void>::value, void*> = nullptr) && { auto state = _state; state->schedule([fut = std::move(*this), func = std::forward(func)] () mutable { func(std::move(fut)); diff --git a/sstring.hh b/sstring.hh index 4f2e27d01a..fd7ef72c52 100644 --- a/sstring.hh +++ b/sstring.hh @@ -15,6 +15,7 @@ #include #include #include +#include template class basic_sstring { @@ -43,6 +44,7 @@ class basic_sstring { char* str() { return is_internal() ? u.internal.str : u.external.str; } + struct initialized_later {}; public: basic_sstring() noexcept { u.internal.size = 0; @@ -54,7 +56,7 @@ public: } else { u.internal.size = -1; u.external.str = new char[x.u.external.size + 1]; - std::copy(x.u.str, x.u.str + x.u.extenal.size + 1, u.external.str); + std::copy(x.u.external.str, x.u.external.str + x.u.external.size + 1, u.external.str); u.external.size = x.u.external.size; } } @@ -63,6 +65,20 @@ public: x.u.internal.size = 0; x.u.internal.str[0] = '\0'; } + basic_sstring(initialized_later, size_t size) { + if (size_type(size) != size) { + throw std::overflow_error("sstring overflow"); + } + if (size + 1 <= sizeof(u.internal.str)) { + u.internal.str[size] = '\0'; + u.internal.size = size; + } else { + u.internal.size = -1; + u.external.str = new char[size + 1]; + u.external.size = size; + u.external.str[size] = '\0'; + } + } basic_sstring(const char_type* x, size_t size) { if (size_type(size) != size) { throw std::overflow_error("sstring overflow"); @@ -93,8 +109,11 @@ public: swap(tmp); } basic_sstring& operator=(basic_sstring&& x) noexcept { - reset(); - swap(x); + if (this != &x) { + swap(x); + x.reset(); + } + return *this; } operator std::string() const { return str(); @@ -120,6 +139,25 @@ public: const char* c_str() const { return str(); } + const char_type* begin() const { return str(); } + const char_type* end() const { return str() + size(); } + char_type* begin() { return str(); } + char_type* end() { return str() + size(); } + bool operator==(const basic_sstring& x) const { + return size() == x.size() && std::equal(begin(), end(), x.begin()); + } + bool operator!=(const basic_sstring& x) const { + return !operator==(x); + } + basic_sstring operator+(const basic_sstring& x) const { + basic_sstring ret(initialized_later(), size() + x.size()); + std::copy(begin(), end(), ret.begin()); + std::copy(x.begin(), x.end(), ret.begin() + size()); + return ret; + } + basic_sstring& operator+=(const basic_sstring& x) { + return *this = *this + x; + } }; template @@ -136,6 +174,22 @@ operator<<(std::basic_ostream& os, const basic_sstring +struct hash> { + size_t operator()(const basic_sstring& s) const { + size_t ret = 0; + for (auto c : s) { + ret = (ret << 6) | (ret >> (sizeof(ret) * 8 - 6)); + ret ^= c; + } + return ret; + } +}; + +} + using sstring = basic_sstring; #endif /* SSTRING_HH_ */