diff --git a/apps/memcache/ascii.rl b/apps/memcache/ascii.rl index edfd4e1c40..9941d33ab8 100644 --- a/apps/memcache/ascii.rl +++ b/apps/memcache/ascii.rl @@ -37,17 +37,30 @@ action advance_blob { crlf = '\r\n'; sp = ' '; u32 = digit+ >{ _u32 = 0; } ${ _u32 *= 10; _u32 += fc - '0'; }; +u64 = digit+ >{ _u64 = 0; } ${ _u64 *= 10; _u64 += fc - '0'; }; key = [^ ]+ >mark %{ _key = str(); }; flags = u32 %{ _flags = _u32; }; expiration = u32 %{ _expiration = _u32; }; size = u32 %{ _size = _u32; }; blob := any+ >start_blob $advance_blob; +maybe_noreply = (sp "noreply" @{ _noreply = true; })? >{ _noreply = false; }; +maybe_expiration = (sp expiration)? >{ _expiration = 0; }; +version_field = u64 %{ _version = _u64; }; -set = "set" sp key sp flags sp expiration sp size (crlf @{ fcall blob; } ) crlf @{ _state = state::cmd_set; }; +insertion_params = sp key sp flags sp expiration sp size maybe_noreply (crlf @{ fcall blob; } ) crlf; +set = "set" insertion_params @{ _state = state::cmd_set; }; +add = "add" insertion_params @{ _state = state::cmd_add; }; +replace = "replace" insertion_params @{ _state = state::cmd_replace; }; +cas = "cas" sp key sp flags sp expiration sp size sp version_field maybe_noreply (crlf @{ fcall blob; } ) crlf @{ _state = state::cmd_cas; }; get = "get" (sp key %{ _keys.push_back(std::move(_key)); })+ crlf @{ _state = state::cmd_get; }; -delete = "delete" sp key crlf @{ _state = state::cmd_delete; }; - -main := (set | get | delete) >eof{ _state = state::eof; }; +gets = "gets" (sp key %{ _keys.push_back(std::move(_key)); })+ crlf @{ _state = state::cmd_gets; }; +delete = "delete" sp key maybe_noreply crlf @{ _state = state::cmd_delete; }; +flush = "flush_all" maybe_expiration maybe_noreply crlf @{ _state = state::cmd_flush_all; }; +version = "version" crlf @{ _state = state::cmd_version; }; +stats = "stats" crlf @{ _state = state::cmd_stats; }; +incr = "incr" sp key sp u64 maybe_noreply crlf @{ _state = state::cmd_incr; }; +decr = "decr" sp key sp u64 maybe_noreply crlf @{ _state = state::cmd_decr; }; +main := (add | replace | set | get | gets | delete | flush | version | cas | stats | incr | decr) >eof{ _state = state::eof; }; prepush { prepush(); @@ -66,17 +79,29 @@ public: error, eof, cmd_set, + cmd_cas, + cmd_add, + cmd_replace, cmd_get, - cmd_delete + cmd_gets, + cmd_delete, + cmd_flush_all, + cmd_version, + cmd_stats, + cmd_incr, + cmd_decr, }; state _state; uint32_t _u32; + uint64_t _u64; sstring _key; uint32_t _flags; uint32_t _expiration; uint32_t _size; uint32_t _size_left; + uint64_t _version; sstring _blob; + bool _noreply; std::vector _keys; public: void init() { diff --git a/apps/memcache/memcache.cc b/apps/memcache/memcache.cc index c6cb9aefe9..e4f6f1b4af 100644 --- a/apps/memcache/memcache.cc +++ b/apps/memcache/memcache.cc @@ -1,4 +1,5 @@ #include +#include #include #include #include "core/app-template.hh" @@ -10,6 +11,9 @@ #include "net/api.hh" #include "net/packet-data-source.hh" #include "apps/memcache/ascii.hh" +#include + +#define VERSION_STRING "seastar v1.0" using namespace net; @@ -26,6 +30,26 @@ struct item_data { sstring _data; uint32_t _flag; clock_type::time_point _expiry; + + optional as_integral() { + auto str = _data.c_str(); + if (str[0] == '-') { + return {}; + } + + auto len = _data.size(); + + // Strip trailing space + while (len && str[len - 1] == ' ') { + len--; + } + + try { + return {boost::lexical_cast(str, len)}; + } catch (const boost::bad_lexical_cast& e) { + return {}; + } + } }; class item { @@ -37,25 +61,27 @@ private: bi::list_member_hook<> _expired_link; friend class cache; public: - item(item_data data) + item(item_data data, uint64_t version = 1) : _data(std::move(data)) - , _version(1) + , _version(version) , _expired(false) { } + item(const item&) = delete; + item(item&&) = delete; + clock_type::time_point get_timeout() { return _data._expiry; } - void update(item_data&& data) { - _data = std::move(data); - _version++; - } - item_data& data() { return _data; } + + uint64_t version() { + return _version; + } }; struct cache_stats { @@ -63,6 +89,19 @@ struct cache_stats { size_t _get_misses {}; size_t _set_adds {}; size_t _set_replaces {}; + size_t _cas_hits {}; + size_t _cas_misses {}; + size_t _cas_badval {}; + size_t _delete_misses {}; + size_t _delete_hits {}; + size_t _incr_misses {}; + size_t _incr_hits {}; + size_t _decr_misses {}; + size_t _decr_hits {}; +}; + +enum class cas_result { + not_found, stored, bad_version }; class cache { @@ -71,15 +110,24 @@ private: using cache_iterator = typename cache_type::iterator; cache_type _cache; timer_set _alive; - bi::list, &item::_expired_link>> _expired; + + // Contains items which are present in _cache but have expired + bi::list, &item::_expired_link>, + bi::constant_time_size> _expired; + timer _timer; cache_stats _stats; + timer _flush_timer; private: + void expire_now(item& it) { + it._expired = true; + _expired.push_back(it); + } + void expire() { _alive.expire(clock_type::now()); while (auto item = _alive.pop_expired()) { - item->_expired = true; - _expired.push_back(*item); + expire_now(*item); } _timer.arm(_alive.get_next_timeout()); } @@ -102,9 +150,10 @@ private: void add_overriding(cache_iterator i, item_data&& data) { auto& item_ref = *i->second; _alive.remove(item_ref); - item_ref.update(std::move(data)); - if (_alive.insert(item_ref)) { - _timer.rearm(item_ref.get_timeout()); + i->second = make_shared(std::move(data), item_ref._version + 1); + auto& new_ref = *i->second; + if (_alive.insert(new_ref)) { + _timer.rearm(new_ref.get_timeout()); } } @@ -120,6 +169,22 @@ private: public: cache() { _timer.set_callback([this] { expire(); }); + _flush_timer.set_callback([this] { flush_all(); }); + } + + void flush_all() { + _flush_timer.cancel(); + for (auto pair : _cache) { + auto& it = *pair.second; + if (!it._expired) { + _alive.remove(it); + expire_now(it); + } + } + } + + void flush_at(clock_type::time_point time_point) { + _flush_timer.rearm(time_point); } bool set(item_key&& key, item_data data) { @@ -141,6 +206,7 @@ public: return false; } + _stats._set_adds++; add_new(std::move(key), std::move(data)); return true; } @@ -151,6 +217,7 @@ public: return false; } + _stats._set_replaces++; add_overriding(i, std::move(data)); return true; } @@ -158,8 +225,10 @@ public: bool remove(const item_key& key) { auto i = find(key); if (i == _cache.end()) { + _stats._delete_misses++; return false; } + _stats._delete_hits++; auto& item_ref = *i->second; _alive.remove(item_ref); _cache.erase(i); @@ -176,30 +245,214 @@ public: return i->second; } + cas_result cas(const item_key& key, uint64_t version, item_data&& data) { + auto i = find(key); + if (i == _cache.end()) { + _stats._cas_misses++; + return cas_result::not_found; + } + auto& item_ref = *i->second; + if (item_ref._version != version) { + _stats._cas_badval++; + return cas_result::bad_version; + } + _stats._cas_hits++; + add_overriding(i, std::move(data)); + return cas_result::stored; + } + size_t size() { - return _cache.size(); + return _cache.size() - _expired.size(); } cache_stats& stats() { return _stats; } + + std::pair, bool> incr(const item_key& key, uint64_t delta) { + auto i = find(key); + if (i == _cache.end()) { + _stats._incr_misses++; + return {{}, false}; + } + auto& item_ref = *i->second; + _stats._incr_hits++; + auto value = item_ref._data.as_integral(); + if (!value) { + return {i->second, false}; + } + add_overriding(i, item_data{to_sstring(*value + delta), item_ref.data()._flag, item_ref.data()._expiry}); + return {i->second, true}; + } + + std::pair, bool> decr(const item_key& key, uint64_t delta) { + auto i = find(key); + if (i == _cache.end()) { + _stats._decr_misses++; + return {{}, false}; + } + auto& item_ref = *i->second; + _stats._decr_hits++; + auto value = item_ref._data.as_integral(); + if (!value) { + return {i->second, false}; + } + add_overriding(i, item_data{to_sstring(*value - std::min(*value, delta)), item_ref.data()._flag, item_ref.data()._expiry}); + return {i->second, true}; + } +}; + +struct system_stats { + uint32_t _curr_connections {}; + uint32_t _total_connections {}; + uint64_t _cmd_get {}; + uint64_t _cmd_set {}; + uint64_t _cmd_flush {}; + clock_type::time_point _start_time; }; class ascii_protocol { private: cache& _cache; + system_stats& _system_stats; memcache_ascii_parser _parser; private: static constexpr uint32_t seconds_in_a_month = 60 * 60 * 24 * 30; static constexpr const char *msg_crlf = "\r\n"; static constexpr const char *msg_error = "ERROR\r\n"; static constexpr const char *msg_stored = "STORED\r\n"; + static constexpr const char *msg_not_stored = "NOT_STORED\r\n"; static constexpr const char *msg_end = "END\r\n"; static constexpr const char *msg_value = "VALUE "; static constexpr const char *msg_deleted = "DELETED\r\n"; static constexpr const char *msg_not_found = "NOT_FOUND\r\n"; + static constexpr const char *msg_ok = "OK\r\n"; + static constexpr const char *msg_version = "VERSION " VERSION_STRING "\r\n"; + static constexpr const char *msg_exists = "EXISTS\r\n"; + static constexpr const char *msg_stat = "STAT "; + static constexpr const char *msg_error_non_numeric_value = "CLIENT_ERROR cannot increment or decrement non-numeric value\r\n"; +private: + template + future<> handle_get(output_stream& out) { + _system_stats._cmd_get++; + auto keys_p = make_shared>(std::move(_parser._keys)); + return do_for_each(keys_p->begin(), keys_p->end(), [this, &out, keys_p](auto&& key) mutable { + auto item = _cache.get(key); + if (!item) { + return make_ready_future<>(); + } + return out.write(msg_value) + .then([&out, &key] { + return out.write(key); + }).then([&out] { + return out.write(" "); + }).then([&out, item] { + return out.write(to_sstring(item->data()._flag)); + }).then([&out] { + return out.write(" "); + }).then([&out, item] { + return out.write(to_sstring(item->data()._data.size())); + }).then([&out, item] { + if (SendCasVersion) { + return out.write(" ").then([&out, item] { + return out.write(to_sstring(item->version())).then([&out] { + return out.write(msg_crlf); + }); + }); + } else { + return out.write(msg_crlf); + } + }).then([&out, item] { + return out.write(item->data()._data); + }).then([&out] { + return out.write(msg_crlf); + }); + }).then([&out] { + return out.write(msg_end); + }); + } + + template + static future<> print_stat(output_stream& out, const char* key, Value value) { + return out.write(msg_stat) + .then([&out, key] { return out.write(key); }) + .then([&out] { return out.write(" "); }) + .then([&out, value] { return out.write(to_sstring(value)); }) + .then([&out] { return out.write(msg_crlf); }); + } + + future<> print_stats(output_stream& out) { + auto now = clock_type::now(); + return print_stat(out, "pid", getpid()) + .then([this, now, &out] { + return print_stat(out, "uptime", + std::chrono::duration_cast(now - _system_stats._start_time).count()); + }).then([this, now, &out] { + return print_stat(out, "time", + std::chrono::duration_cast(now.time_since_epoch()).count()); + }).then([this, &out] { + return print_stat(out, "version", VERSION_STRING); + }).then([this, &out] { + return print_stat(out, "pointer_size", sizeof(void*)*8); + }).then([this, &out, v = _system_stats._curr_connections] { + return print_stat(out, "curr_connections", v); + }).then([this, &out, v = _system_stats._total_connections] { + return print_stat(out, "total_connections", v); + }).then([this, &out, v = _system_stats._curr_connections] { + return print_stat(out, "connection_structures", v); + }).then([this, &out, v = _system_stats._cmd_get] { + return print_stat(out, "cmd_get", v); + }).then([this, &out, v = _system_stats._cmd_set] { + return print_stat(out, "cmd_set", v); + }).then([this, &out, v = _system_stats._cmd_flush] { + return print_stat(out, "cmd_flush", v); + }).then([this, &out] { + return print_stat(out, "cmd_touch", 0); + }).then([this, &out, v = _cache.stats()._get_hits] { + return print_stat(out, "get_hits", v); + }).then([this, &out, v = _cache.stats()._get_misses] { + return print_stat(out, "get_misses", v); + }).then([this, &out, v = _cache.stats()._delete_misses] { + return print_stat(out, "delete_misses", v); + }).then([this, &out, v = _cache.stats()._delete_hits] { + return print_stat(out, "delete_hits", v); + }).then([this, &out, v = _cache.stats()._incr_misses] { + return print_stat(out, "incr_misses", v); + }).then([this, &out, v = _cache.stats()._incr_hits] { + return print_stat(out, "incr_hits", v); + }).then([this, &out, v = _cache.stats()._decr_misses] { + return print_stat(out, "decr_misses", v); + }).then([this, &out, v = _cache.stats()._decr_hits] { + return print_stat(out, "decr_hits", v); + }).then([this, &out, v = _cache.stats()._cas_misses] { + return print_stat(out, "cas_misses", v); + }).then([this, &out, v = _cache.stats()._cas_hits] { + return print_stat(out, "cas_hits", v); + }).then([this, &out, v = _cache.stats()._cas_badval] { + return print_stat(out, "cas_badval", v); + }).then([this, &out] { + return print_stat(out, "touch_hits", 0); + }).then([this, &out] { + return print_stat(out, "touch_misses", 0); + }).then([this, &out] { + return print_stat(out, "auth_cmds", 0); + }).then([this, &out] { + return print_stat(out, "auth_errors", 0); + }).then([this, &out] { + return print_stat(out, "threads", smp::count); + }).then([this, &out, v = _cache.size()] { + return print_stat(out, "curr_items", v); + }).then([this, &out, v = (_cache.stats()._set_replaces + _cache.stats()._set_adds + _cache.stats()._cas_hits)] { + return print_stat(out, "total_items", v); + }).then([&out] { + return out.write("END\r\n"); + }); + } public: - ascii_protocol(cache& cache) : _cache(cache) {} + ascii_protocol(cache& cache, system_stats& system_stats) + : _cache(cache) + , _system_stats(system_stats) + {} clock_type::time_point seconds_to_time_point(uint32_t seconds) { if (seconds == 0) { @@ -222,46 +475,124 @@ public: return out.write(msg_error); case memcache_ascii_parser::state::cmd_set: + _system_stats._cmd_set++; _cache.set(std::move(_parser._key), item_data{std::move(_parser._blob), _parser._flags, seconds_to_time_point(_parser._expiration)}); + if (_parser._noreply) { + return make_ready_future<>(); + } return out.write(msg_stored); - case memcache_ascii_parser::state::cmd_get: + case memcache_ascii_parser::state::cmd_cas: { - auto keys_p = make_shared>(std::move(_parser._keys)); - return do_for_each(keys_p->begin(), keys_p->end(), [this, &out, keys_p](auto&& key) mutable { - auto item = _cache.get(key); - if (!item) { - return make_ready_future<>(); - } - return out.write(msg_value) - .then([&out, &key] { - return out.write(key); - }).then([&out] { - return out.write(" "); - }).then([&out, item] { - return out.write(to_sstring(item->data()._flag)); - }).then([&out] { - return out.write(" "); - }).then([&out, item] { - return out.write(to_sstring(item->data()._data.size())); - }).then([&out] { - return out.write(msg_crlf); - }).then([&out, item] { - return out.write(item->data()._data); - }).then([&out] { - return out.write(msg_crlf); - }); - }).then([&out] { - return out.write(msg_end); + _system_stats._cmd_set++; + auto result = _cache.cas(_parser._key, _parser._version, + item_data{std::move(_parser._blob), _parser._flags, seconds_to_time_point(_parser._expiration)}); + if (_parser._noreply) { + return make_ready_future<>(); + } + switch (result) { + case cas_result::stored: + return out.write(msg_stored); + case cas_result::not_found: + return out.write(msg_not_found); + case cas_result::bad_version: + return out.write(msg_exists); + } + } + + case memcache_ascii_parser::state::cmd_add: + { + _system_stats._cmd_set++; + auto added = _cache.add(std::move(_parser._key), + item_data{std::move(_parser._blob), _parser._flags, seconds_to_time_point(_parser._expiration)}); + if (_parser._noreply) { + return make_ready_future<>(); + } + return out.write(added ? msg_stored : msg_not_stored); + } + + case memcache_ascii_parser::state::cmd_replace: + { + _system_stats._cmd_set++; + auto replaced = _cache.replace(std::move(_parser._key), + item_data{std::move(_parser._blob), _parser._flags, seconds_to_time_point(_parser._expiration)}); + if (_parser._noreply) { + return make_ready_future<>(); + } + return out.write(replaced ? msg_stored : msg_not_stored); + } + + case memcache_ascii_parser::state::cmd_get: + return handle_get(out); + + case memcache_ascii_parser::state::cmd_gets: + return handle_get(out); + + case memcache_ascii_parser::state::cmd_delete: + { + auto removed = _cache.remove(_parser._key); + if (_parser._noreply) { + return make_ready_future<>(); + } + return out.write(removed ? msg_deleted : msg_not_found); + } + + case memcache_ascii_parser::state::cmd_flush_all: + _system_stats._cmd_flush++; + if (_parser._expiration) { + _cache.flush_at(seconds_to_time_point(_parser._expiration)); + } else { + _cache.flush_all(); + } + if (_parser._noreply) { + return make_ready_future<>(); + } + return out.write(msg_ok); + + case memcache_ascii_parser::state::cmd_version: + return out.write(msg_version); + + case memcache_ascii_parser::state::cmd_stats: + return print_stats(out); + + case memcache_ascii_parser::state::cmd_incr: + { + auto result = _cache.incr(_parser._key, _parser._u64); + if (_parser._noreply) { + return make_ready_future<>(); + } + auto item = result.first; + if (!item) { + return out.write(msg_not_found); + } + auto incremented = result.second; + if (!incremented) { + return out.write(msg_error_non_numeric_value); + } + return out.write(item->data()._data).then([&out] { + return out.write(msg_crlf); }); } - case memcache_ascii_parser::state::cmd_delete: - if (_cache.remove(_parser._key)) { - return out.write(msg_deleted); + case memcache_ascii_parser::state::cmd_decr: + { + auto result = _cache.decr(_parser._key, _parser._u64); + if (_parser._noreply) { + return make_ready_future<>(); } - return out.write(msg_not_found); + auto item = result.first; + if (!item) { + return out.write(msg_not_found); + } + auto decremented = result.second; + if (!decremented) { + return out.write(msg_error_non_numeric_value); + } + return out.write(item->data()._data).then([&out] { + return out.write(msg_crlf); + }); + } }; return make_ready_future<>(); }); @@ -276,7 +607,7 @@ class udp_server { public: static const size_t default_max_datagram_size = 1400; private: - ascii_protocol& _proto; + ascii_protocol _proto; udp_channel _chan; uint16_t _port; size_t _max_datagram_size = default_max_datagram_size; @@ -294,8 +625,8 @@ private: } __attribute__((packed)); public: - udp_server(ascii_protocol& proto, uint16_t port = 11211) - : _proto(proto) + udp_server(cache& c, system_stats& system_stats, uint16_t port = 11211) + : _proto(c, system_stats) , _port(port) {} @@ -367,6 +698,7 @@ class tcp_server { private: shared_ptr _listener; cache& _cache; + system_stats& _system_stats; uint16_t _port; struct connection { connected_socket _socket; @@ -374,23 +706,36 @@ private: input_stream _in; output_stream _out; ascii_protocol _proto; - connection(connected_socket&& socket, socket_address addr, cache& c) + system_stats& _system_stats; + connection(connected_socket&& socket, socket_address addr, cache& c, system_stats& system_stats) : _socket(std::move(socket)) , _addr(addr) , _in(_socket.input()) , _out(_socket.output()) - , _proto(c) - {} + , _proto(c, system_stats) + , _system_stats(system_stats) + { + _system_stats._curr_connections++; + _system_stats._total_connections++; + } + ~connection() { + _system_stats._curr_connections--; + } }; public: - tcp_server(cache& cache, uint16_t port = 11211) : _cache(cache), _port(port) {} + tcp_server(cache& cache, system_stats& system_stats, uint16_t port = 11211) + : _cache(cache) + , _system_stats(system_stats) + , _port(port) + {} + void start() { listen_options lo; lo.reuse_address = true; _listener = engine.listen(make_ipv4_address({_port}), lo); keep_doing([this] { return _listener->accept().then([this] (connected_socket fd, socket_address addr) mutable { - auto conn = make_shared(std::move(fd), addr, _cache); + auto conn = make_shared(std::move(fd), addr, _cache, _system_stats); do_until([conn] { return conn->_in.eof(); }, [this, conn] { return conn->_proto.handle(conn->_in, conn->_out).then([conn] { return conn->_out.flush(); @@ -431,11 +776,13 @@ public: int main(int ac, char** av) { memcache::cache cache; - memcache::ascii_protocol ascii_protocol(cache); - memcache::udp_server udp_server(ascii_protocol); - memcache::tcp_server tcp_server(cache); + memcache::system_stats system_stats; + memcache::udp_server udp_server(cache, system_stats); + memcache::tcp_server tcp_server(cache, system_stats); memcache::stats_printer stats(cache); + system_stats._start_time = clock_type::now(); + app_template app; app.add_options() ("max-datagram-size", bpo::value()->default_value(memcache::udp_server::default_max_datagram_size), diff --git a/core/sstring.hh b/core/sstring.hh index ee897e2a96..b127e5ae96 100644 --- a/core/sstring.hh +++ b/core/sstring.hh @@ -243,6 +243,12 @@ string_type to_sstring(unsigned long long value, void* = nullptr) { return to_sstring_sprintf(value, "%llu"); } +template +inline +string_type to_sstring(const char* value, void* = nullptr) { + return string_type(value); +} + template inline std::ostream& operator<<(std::ostream& os, const std::vector& v) { diff --git a/test.py b/test.py index 6c8c3452fa..eb693e99cd 100755 --- a/test.py +++ b/test.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import os import sys +import argparse import subprocess all_tests = [ @@ -18,13 +19,17 @@ def print_status(msg): print('\r' + msg, end='') if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Seastar test runner") + parser.add_argument('--fast', action="store_true", help="Run only fast tests") + args = parser.parse_args() + black_hole = open('/dev/null', 'w') test_to_run = [] for mode in ['debug', 'release']: for test in all_tests: test_to_run.append(os.path.join('build', mode, 'tests', test)) - test_to_run.append('tests/memcache/test.py ' + os.path.join('build', mode, 'apps', 'memcache', 'memcache') + ' --smp 1') + test_to_run.append('tests/memcache/test.py --mode ' + mode + (' --fast' if args.fast else '')) all_ok = True diff --git a/tests/memcache/test.py b/tests/memcache/test.py index 0b94f1a68c..dd05603310 100755 --- a/tests/memcache/test.py +++ b/tests/memcache/test.py @@ -1,24 +1,28 @@ #!/usr/bin/env python3 import time -import subprocess import sys +import os +import argparse +import subprocess -if len(sys.argv) < 2: - print('Usage: %s ...' % sys.argv[0]) - -memcache_path = sys.argv[1] - -def run(cmd): - mc = subprocess.Popen([memcache_path] + sys.argv[2:]) +def run(args, cmd): + mc = subprocess.Popen([os.path.join('build', args.mode, 'apps', 'memcache', 'memcache'), '--smp', '1']) print('Memcache started.') try: - time.sleep(0.1) cmdline = ['tests/memcache/test_memcache.py'] + cmd + if args.fast: + cmdline.append('--fast') print('Running: ' + ' '.join(cmdline)) subprocess.check_call(cmdline) finally: print('Killing memcache...') mc.kill() -run([]) -run(['-U']) +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Seastar test runner") + parser.add_argument('--fast', action="store_true", help="Run only fast tests") + parser.add_argument('--mode', action="store", help="Test app in given mode", default='release') + args = parser.parse_args() + + run(args, []) + run(args, ['-U']) diff --git a/tests/memcache/test_memcache.py b/tests/memcache/test_memcache.py index 0b94e12bc6..f3522df6ac 100755 --- a/tests/memcache/test_memcache.py +++ b/tests/memcache/test_memcache.py @@ -5,14 +5,20 @@ import struct import random import argparse import time +import re import unittest server_addr = None call = None +args = None + +class TimeoutError(Exception): + pass @contextmanager -def tcp_connection(): +def tcp_connection(timeout=1): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(timeout) s.connect(server_addr) def call(msg): s.send(msg.encode()) @@ -20,6 +26,13 @@ def tcp_connection(): yield call s.close() +def slow(f): + def wrapper(self): + if args.fast: + raise unittest.SkipTest('Slow') + return f(self) + return wrapper + def recv_all(s): m = b'' while True: @@ -29,8 +42,9 @@ def recv_all(s): m += data return m -def tcp_call(msg): +def tcp_call(msg, timeout=1): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(timeout) s.connect(server_addr) s.send(msg.encode()) s.shutdown(socket.SHUT_WR) @@ -38,8 +52,9 @@ def tcp_call(msg): s.close() return data -def udp_call(msg): +def udp_call(msg, timeout=1): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(timeout) this_req_id = random.randint(-32768, 32767) datagram = struct.pack(">hhhh", this_req_id, 0, 1, 0) + msg.encode() @@ -73,7 +88,43 @@ def udp_call(msg): sock.close() return msg -class TcpSpecificTests(unittest.TestCase): +class MemcacheTest(unittest.TestCase): + def set(self, key, value, flags=0, expiry=0): + self.assertEqual(call('set %s %d %d %d\r\n%s\r\n' % (key, flags, expiry, len(value), value)), b'STORED\r\n') + + def delete(self, key): + self.assertEqual(call('delete %s\r\n' % key), b'DELETED\r\n') + + def assertHasKey(self, key): + resp = call('get %s\r\n' % key) + if not resp.startswith(('VALUE %s' % key).encode()): + self.fail('Key \'%s\' should be present, but got: %s' % (key, resp.decode())) + + def assertNoKey(self, key): + resp = call('get %s\r\n' % key) + if resp != b'END\r\n': + self.fail('Key \'%s\' should not be present, but got: %s' % (key, resp.decode())) + + def setKey(self, key): + self.set(key, 'some value') + + def getItemVersion(self, key): + m = re.match(r'VALUE %s \d+ \d+ (?P\d+)' % key, call('gets %s\r\n' % key).decode()) + return int(m.group('version')) + + def getStat(self, name, call_fn=None): + if not call_fn: call_fn = call + resp = call_fn('stats\r\n').decode() + m = re.search(r'STAT %s (?P.+)' % re.escape(name), resp, re.MULTILINE) + return m.group('value') + + def flush(self): + self.assertEqual(call('flush_all\r\n'), b'OK\r\n') + + def tearDown(self): + self.flush() + +class TcpSpecificTests(MemcacheTest): def test_recovers_from_errors_in_the_stream(self): with tcp_connection() as conn: self.assertEqual(conn('get\r\n'), b'ERROR\r\n') @@ -97,17 +148,57 @@ class TcpSpecificTests(unittest.TestCase): def test_unsuccesful_parsing_does_not_leave_data_behind(self): with tcp_connection() as conn: self.assertEqual(conn('set key 0 0 5\r\nhello\r\n'), b'STORED\r\n') - self.assertEqual(conn('delete a b c\r\n'), b'ERROR\r\n') + self.assertRegexpMatches(conn('delete a b c\r\n'), b'^(CLIENT_)?ERROR.*\r\n$') self.assertEqual(conn('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') self.assertEqual(conn('delete key\r\n'), b'DELETED\r\n') -class TestCommands(unittest.TestCase): - def call_set(self, key, value, flags=0, expiry=0): - self.assertEqual(call('set %s %d %d %d\r\n%s\r\n' % (key, flags, expiry, len(value), value)), b'STORED\r\n') + def test_flush_all_no_reply(self): + self.assertEqual(call('flush_all noreply\r\n'), b'') - def call_delete(self, key): - self.assertEqual(call('delete %s\r\n' % key), b'DELETED\r\n') + def test_set_no_reply(self): + self.assertEqual(call('set key 0 0 5 noreply\r\nhello\r\nget key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') + self.delete('key') + def test_delete_no_reply(self): + self.setKey('key') + self.assertEqual(call('delete key noreply\r\nget key\r\n'), b'END\r\n') + + def test_add_no_reply(self): + self.assertEqual(call('add key 0 0 1 noreply\r\na\r\nget key\r\n'), b'VALUE key 0 1\r\na\r\nEND\r\n') + self.delete('key') + + def test_replace_no_reply(self): + self.assertEqual(call('set key 0 0 1\r\na\r\n'), b'STORED\r\n') + self.assertEqual(call('replace key 0 0 1 noreply\r\nb\r\nget key\r\n'), b'VALUE key 0 1\r\nb\r\nEND\r\n') + self.delete('key') + + def test_cas_noreply(self): + self.assertNoKey('key') + self.assertEqual(call('cas key 0 0 1 1 noreply\r\na\r\n'), b'') + self.assertNoKey('key') + + self.assertEqual(call('add key 0 0 5\r\nhello\r\n'), b'STORED\r\n') + version = self.getItemVersion('key') + + self.assertEqual(call('cas key 1 0 5 %d noreply\r\naloha\r\n' % (version + 1)), b'') + self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') + + self.assertEqual(call('cas key 1 0 5 %d noreply\r\naloha\r\n' % (version)), b'') + self.assertEqual(call('get key\r\n'), b'VALUE key 1 5\r\naloha\r\nEND\r\n') + + self.delete('key') + + def test_connection_statistics(self): + with tcp_connection() as conn: + curr_connections = int(self.getStat('curr_connections', call_fn=conn)) + total_connections = int(self.getStat('total_connections', call_fn=conn)) + with tcp_connection() as conn2: + self.assertEquals(curr_connections + 1, int(self.getStat('curr_connections', call_fn=conn))) + self.assertEquals(total_connections + 1, int(self.getStat('total_connections', call_fn=conn))) + self.assertEquals(curr_connections, int(self.getStat('curr_connections', call_fn=conn))) + self.assertEquals(total_connections + 1, int(self.getStat('total_connections', call_fn=conn))) + +class TestCommands(MemcacheTest): def test_basic_commands(self): self.assertEqual(call('get key\r\n'), b'END\r\n') self.assertEqual(call('set key 0 0 5\r\nhello\r\n'), b'STORED\r\n') @@ -119,12 +210,14 @@ class TestCommands(unittest.TestCase): def test_error_handling(self): self.assertEqual(call('get\r\n'), b'ERROR\r\n') + @slow def test_expiry(self): self.assertEqual(call('set key 0 1 5\r\nhello\r\n'), b'STORED\r\n') self.assertEqual(call('get key\r\n'), b'VALUE key 0 5\r\nhello\r\nEND\r\n') time.sleep(1) self.assertEqual(call('get key\r\n'), b'END\r\n') + @slow def test_expiry_at_epoch_time(self): expiry = int(time.time()) + 1 self.assertEqual(call('set key 0 %d 5\r\nhello\r\n' % expiry), b'STORED\r\n') @@ -136,35 +229,316 @@ class TestCommands(unittest.TestCase): self.assertEqual(call('set key1 0 0 2\r\nv1\r\n'), b'STORED\r\n') self.assertEqual(call('set key 0 0 2\r\nv2\r\n'), b'STORED\r\n') self.assertEqual(call('get key1 key\r\n'), b'VALUE key1 0 2\r\nv1\r\nVALUE key 0 2\r\nv2\r\nEND\r\n') - self.call_delete("key") - self.call_delete("key1") + self.delete("key") + self.delete("key1") + + def test_flush_all(self): + self.set('key', 'value') + self.assertEqual(call('flush_all\r\n'), b'OK\r\n') + self.assertNoKey('key') + + def test_keys_set_after_flush_remain(self): + self.assertEqual(call('flush_all\r\n'), b'OK\r\n') + self.setKey('key') + self.assertHasKey('key') + self.delete('key') + + @slow + def test_flush_all_with_timeout_flushes_all_keys_even_those_set_after_flush(self): + self.setKey('key') + self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n') + self.assertHasKey('key') + self.setKey('key2') + time.sleep(2) + self.assertNoKey('key') + self.assertNoKey('key2') + + @slow + def test_subsequent_flush_is_merged(self): + self.setKey('key') + self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n') # Can flush in anything between 1-2 + self.assertEqual(call('flush_all 4\r\n'), b'OK\r\n') # Can flush in anything between 3-4 + time.sleep(2) + self.assertHasKey('key') + self.setKey('key2') + time.sleep(4) + self.assertNoKey('key') + self.assertNoKey('key2') + + @slow + def test_immediate_flush_cancels_delayed_flush(self): + self.assertEqual(call('flush_all 2\r\n'), b'OK\r\n') + self.assertEqual(call('flush_all\r\n'), b'OK\r\n') + self.setKey('key') + time.sleep(1) + self.assertHasKey('key') + self.delete('key') + + @slow + def test_flushing_in_the_past(self): + self.setKey('key1') + time.sleep(1) + self.setKey('key2') + key2_time = int(time.time()) + self.assertEqual(call('flush_all %d\r\n' % (key2_time - 1)), b'OK\r\n') + self.assertNoKey("key1") + self.assertNoKey("key2") + + @slow + def test_memcache_does_not_crash_when_flushing_with_already_expred_items(self): + self.assertEqual(call('set key1 0 2 5\r\nhello\r\n'), b'STORED\r\n') + time.sleep(1) + self.assertEqual(call('flush_all\r\n'), b'OK\r\n') def test_response_spanning_many_datagrams(self): key1_data = '1' * 1000 key2_data = '2' * 1000 key3_data = '3' * 1000 - self.call_set('key1', key1_data) - self.call_set('key2', key2_data) - self.call_set('key3', key3_data) + self.set('key1', key1_data) + self.set('key2', key2_data) + self.set('key3', key3_data) self.assertEqual(call('get key1 key2 key3\r\n').decode(), 'VALUE key1 0 %d\r\n%s\r\n' \ 'VALUE key2 0 %d\r\n%s\r\n' \ 'VALUE key3 0 %d\r\n%s\r\n' \ 'END\r\n' % (len(key1_data), key1_data, len(key2_data), key2_data, len(key3_data), key3_data)) - self.call_delete('key1') - self.call_delete('key2') - self.call_delete('key3') + self.delete('key1') + self.delete('key2') + self.delete('key3') + + def test_version(self): + self.assertRegexpMatches(call('version\r\n'), b'^VERSION .*\r\n$') + + def test_add(self): + self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'STORED\r\n') + self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'NOT_STORED\r\n') + self.delete('key') + + def test_replace(self): + self.assertEqual(call('add key 0 0 1\r\na\r\n'), b'STORED\r\n') + self.assertEqual(call('replace key 0 0 1\r\na\r\n'), b'STORED\r\n') + self.delete('key') + self.assertEqual(call('replace key 0 0 1\r\na\r\n'), b'NOT_STORED\r\n') + + def test_cas_and_gets(self): + self.assertEqual(call('cas key 0 0 1 1\r\na\r\n'), b'NOT_FOUND\r\n') + self.assertEqual(call('add key 0 0 5\r\nhello\r\n'), b'STORED\r\n') + version = self.getItemVersion('key') + + self.assertEqual(call('set key 1 0 5\r\nhello\r\n'), b'STORED\r\n') + self.assertEqual(call('gets key\r\n').decode(), 'VALUE key 1 5 %d\r\nhello\r\nEND\r\n' % (version + 1)) + + self.assertEqual(call('cas key 0 0 5 %d\r\nhello\r\n' % (version)), b'EXISTS\r\n') + self.assertEqual(call('cas key 0 0 5 %d\r\naloha\r\n' % (version + 1)), b'STORED\r\n') + self.assertEqual(call('gets key\r\n').decode(), 'VALUE key 0 5 %d\r\naloha\r\nEND\r\n' % (version + 2)) + + self.delete('key') + + def test_curr_items_stat(self): + self.assertEquals(0, int(self.getStat('curr_items'))) + self.setKey('key') + self.assertEquals(1, int(self.getStat('curr_items'))) + self.delete('key') + self.assertEquals(0, int(self.getStat('curr_items'))) + + def test_how_stats_change_with_different_commands(self): + get_count = int(self.getStat('cmd_get')) + set_count = int(self.getStat('cmd_set')) + flush_count = int(self.getStat('cmd_flush')) + total_items = int(self.getStat('total_items')) + get_misses = int(self.getStat('get_misses')) + get_hits = int(self.getStat('get_hits')) + cas_hits = int(self.getStat('cas_hits')) + cas_badval = int(self.getStat('cas_badval')) + cas_misses = int(self.getStat('cas_misses')) + delete_misses = int(self.getStat('delete_misses')) + delete_hits = int(self.getStat('delete_hits')) + curr_connections = int(self.getStat('curr_connections')) + incr_hits = int(self.getStat('incr_hits')) + incr_misses = int(self.getStat('incr_misses')) + decr_hits = int(self.getStat('decr_hits')) + decr_misses = int(self.getStat('decr_misses')) + + call('get key\r\n') + get_count += 1 + get_misses += 1 + + call('gets key\r\n') + get_count += 1 + get_misses += 1 + + call('set key1 0 0 1\r\na\r\n') + set_count += 1 + total_items += 1 + + call('get key1\r\n') + get_count += 1 + get_hits += 1 + + call('add key1 0 0 1\r\na\r\n') + set_count += 1 + + call('add key2 0 0 1\r\na\r\n') + set_count += 1 + total_items += 1 + + call('replace key1 0 0 1\r\na\r\n') + set_count += 1 + total_items += 1 + + call('replace key3 0 0 1\r\na\r\n') + set_count += 1 + + call('cas key4 0 0 1 1\r\na\r\n') + set_count += 1 + cas_misses += 1 + + call('cas key1 0 0 1 %d\r\na\r\n' % self.getItemVersion('key1')) + set_count += 1 + get_count += 1 + get_hits += 1 + cas_hits += 1 + total_items += 1 + + call('cas key1 0 0 1 %d\r\na\r\n' % (self.getItemVersion('key1') + 1)) + set_count += 1 + get_count += 1 + get_hits += 1 + cas_badval += 1 + + call('delete key1\r\n') + delete_hits += 1 + + call('delete key1\r\n') + delete_misses += 1 + + call('incr num 1\r\n') + incr_misses += 1 + call('decr num 1\r\n') + decr_misses += 1 + + call('set num 0 0 1\r\n0\r\n') + set_count += 1 + total_items += 1 + + call('incr num 1\r\n') + incr_hits += 1 + call('decr num 1\r\n') + decr_hits += 1 + + self.flush() + flush_count += 1 + + self.assertEquals(get_count, int(self.getStat('cmd_get'))) + self.assertEquals(set_count, int(self.getStat('cmd_set'))) + self.assertEquals(flush_count, int(self.getStat('cmd_flush'))) + self.assertEquals(total_items, int(self.getStat('total_items'))) + self.assertEquals(get_hits, int(self.getStat('get_hits'))) + self.assertEquals(get_misses, int(self.getStat('get_misses'))) + self.assertEquals(cas_misses, int(self.getStat('cas_misses'))) + self.assertEquals(cas_hits, int(self.getStat('cas_hits'))) + self.assertEquals(cas_badval, int(self.getStat('cas_badval'))) + self.assertEquals(delete_misses, int(self.getStat('delete_misses'))) + self.assertEquals(delete_hits, int(self.getStat('delete_hits'))) + self.assertEquals(0, int(self.getStat('curr_items'))) + self.assertEquals(curr_connections, int(self.getStat('curr_connections'))) + self.assertEquals(incr_misses, int(self.getStat('incr_misses'))) + self.assertEquals(incr_hits, int(self.getStat('incr_hits'))) + self.assertEquals(decr_misses, int(self.getStat('decr_misses'))) + self.assertEquals(decr_hits, int(self.getStat('decr_hits'))) + + def test_incr(self): + self.assertEqual(call('incr key 0\r\n'), b'NOT_FOUND\r\n') + + self.assertEqual(call('set key 0 0 1\r\n0\r\n'), b'STORED\r\n') + self.assertEqual(call('incr key 0\r\n'), b'0\r\n') + self.assertEqual(call('get key\r\n'), b'VALUE key 0 1\r\n0\r\nEND\r\n') + + self.assertEqual(call('incr key 1\r\n'), b'1\r\n') + self.assertEqual(call('incr key 2\r\n'), b'3\r\n') + self.assertEqual(call('incr key %d\r\n' % (pow(2, 64) - 1)), b'2\r\n') + self.assertEqual(call('incr key %d\r\n' % (pow(2, 64) - 3)), b'18446744073709551615\r\n') + self.assertRegexpMatches(call('incr key 1\r\n').decode(), r'0(\w*)?\r\n') + + self.assertEqual(call('set key 0 0 2\r\n1 \r\n'), b'STORED\r\n') + self.assertEqual(call('incr key 1\r\n'), b'2\r\n') + + self.assertEqual(call('set key 0 0 2\r\n09\r\n'), b'STORED\r\n') + self.assertEqual(call('incr key 1\r\n'), b'10\r\n') + + def test_decr(self): + self.assertEqual(call('decr key 0\r\n'), b'NOT_FOUND\r\n') + + self.assertEqual(call('set key 0 0 1\r\n7\r\n'), b'STORED\r\n') + self.assertEqual(call('decr key 1\r\n'), b'6\r\n') + self.assertEqual(call('get key\r\n'), b'VALUE key 0 1\r\n6\r\nEND\r\n') + + self.assertEqual(call('decr key 6\r\n'), b'0\r\n') + self.assertEqual(call('decr key 2\r\n'), b'0\r\n') + + self.assertEqual(call('set key 0 0 2\r\n20\r\n'), b'STORED\r\n') + self.assertRegexpMatches(call('decr key 11\r\n').decode(), r'^9( )?\r\n$') + + self.assertEqual(call('set key 0 0 3\r\n100\r\n'), b'STORED\r\n') + self.assertRegexpMatches(call('decr key 91\r\n').decode(), r'^9( )?\r\n$') + + self.assertEqual(call('set key 0 0 2\r\n1 \r\n'), b'STORED\r\n') + self.assertEqual(call('decr key 1\r\n'), b'0\r\n') + + self.assertEqual(call('set key 0 0 2\r\n09\r\n'), b'STORED\r\n') + self.assertEqual(call('decr key 1\r\n'), b'8\r\n') + + def test_incr_and_decr_on_invalid_input(self): + error_msg = b'CLIENT_ERROR cannot increment or decrement non-numeric value\r\n' + for cmd in ['incr', 'decr']: + for value in ['', '-1', 'a', '0x1', '18446744073709551616']: + self.assertEqual(call('set key 0 0 %d\r\n%s\r\n' % (len(value), value)), b'STORED\r\n') + prev = call('get key\r\n') + self.assertEqual(call(cmd + ' key 1\r\n'), error_msg, "cmd=%s, value=%s" % (cmd, value)) + self.assertEqual(call('get key\r\n'), prev) + self.delete('key') + +def wait_for_memcache_tcp(timeout=4): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + timeout_at = time.time() + timeout + while True: + if time.time() >= timeout_at: + raise TimeoutError() + try: + s.connect(server_addr) + s.close() + break + except ConnectionRefusedError: + time.sleep(0.1) + + +def wait_for_memcache_udp(timeout=4): + timeout_at = time.time() + timeout + while True: + if time.time() >= timeout_at: + raise TimeoutError() + try: + udp_call('version\r\n', timeout=0.2) + break + except socket.timeout: + pass if __name__ == '__main__': parser = argparse.ArgumentParser(description="memcache protocol tests") parser.add_argument('--server', '-s', action="store", help="server adddress in : format", default="localhost:11211") parser.add_argument('--udp', '-U', action="store_true", help="Use UDP protocol") + parser.add_argument('--fast', action="store_true", help="Run only fast tests") args = parser.parse_args() host, port = args.server.split(':') server_addr = (host, int(port)) - call = udp_call if args.udp else tcp_call + if args.udp: + call = udp_call + wait_for_memcache_udp() + else: + call = tcp_call + wait_for_memcache_tcp() runner = unittest.TextTestRunner() loader = unittest.TestLoader()