diff --git a/alternator-test/test_expected.py b/alternator-test/test_expected.py index acdc711301..29d40fea41 100644 --- a/alternator-test/test_expected.py +++ b/alternator-test/test_expected.py @@ -524,7 +524,6 @@ def test_update_expected_1_null(test_table_s): ) # Tests for Expected with ComparisonOperator = "CONTAINS": -@pytest.mark.xfail(reason="ComparisonOperator=CONTAINS in Expected not yet implemented") def test_update_expected_1_contains(test_table_s): # true cases. CONTAINS can be used for two unrelated things: check substrings # (in string or binary) and membership (in set or list). @@ -607,7 +606,6 @@ def test_update_expected_1_contains(test_table_s): ) # Tests for Expected with ComparisonOperator = "NOT_CONTAINS": -@pytest.mark.xfail(reason="ComparisonOperator=NOT_CONTAINS in Expected not yet implemented") def test_update_expected_1_not_contains(test_table_s): # true cases. NOT_CONTAINS can be used for two unrelated things: check substrings # (in string or binary) and membership (in set or list). diff --git a/alternator/conditions.cc b/alternator/conditions.cc index c8a5e530cd..6f4fcc1995 100644 --- a/alternator/conditions.cc +++ b/alternator/conditions.cc @@ -29,6 +29,7 @@ #include "rjson.hh" #include "serialization.hh" #include "base64.hh" +#include namespace alternator { @@ -47,7 +48,9 @@ comparison_operator_type get_comparison_operator(const rjson::value& comparison_ {"NOT_NULL", comparison_operator_type::NOT_NULL}, {"BETWEEN", comparison_operator_type::BETWEEN}, {"BEGINS_WITH", comparison_operator_type::BEGINS_WITH}, - }; //TODO: CONTAINS + {"CONTAINS", comparison_operator_type::CONTAINS}, + {"NOT_CONTAINS", comparison_operator_type::NOT_CONTAINS}, + }; if (!comparison_operator.IsString()) { throw api_error("ValidationException", format("Invalid comparison operator definition {}", rjson::print(comparison_operator))); } @@ -179,6 +182,59 @@ static bool check_BEGINS_WITH(const rjson::value* v1, const rjson::value& v2) { return val1.substr(0, val2.size()) == val2; } +static std::string_view to_string_view(const rjson::value& v) { + return std::string_view(v.GetString(), v.GetStringLength()); +} + +static bool is_set_of(const rjson::value& type1, const rjson::value& type2) { + return (type2 == "S" && type1 == "SS") || (type2 == "N" && type1 == "NS") || (type2 == "B" && type1 == "BS"); +} + +// Check if two JSON-encoded values match with the CONTAINS relation +static bool check_CONTAINS(const rjson::value* v1, const rjson::value& v2) { + if (!v1) { + return false; + } + const auto& kv1 = *v1->MemberBegin(); + const auto& kv2 = *v2.MemberBegin(); + if (kv2.name != "S" && kv2.name != "N" && kv2.name != "B") { + throw api_error("ValidationException", + format("CONTAINS operator requires a single AttributeValue of type String, Number, or Binary, " + "got {} instead", kv2.name)); + } + if (kv1.name == "S" && kv2.name == "S") { + return to_string_view(kv1.value).find(to_string_view(kv2.value)) != std::string_view::npos; + } else if (kv1.name == "B" && kv2.name == "B") { + return base64_decode(kv1.value).find(base64_decode(kv2.value)) != bytes::npos; + } else if (is_set_of(kv1.name, kv2.name)) { + for (auto i = kv1.value.Begin(); i != kv1.value.End(); ++i) { + if (*i == kv2.value) { + return true; + } + } + } else if (kv1.name == "L") { + for (auto i = kv1.value.Begin(); i != kv1.value.End(); ++i) { + if (!i->IsObject() || i->MemberCount() != 1) { + clogger.error("check_CONTAINS received a list whose element is malformed"); + return false; + } + const auto& el = *i->MemberBegin(); + if (el.name == kv2.name && el.value == kv2.value) { + return true; + } + } + } + return false; +} + +// Check if two JSON-encoded values match with the NOT_CONTAINS relation +static bool check_NOT_CONTAINS(const rjson::value* v1, const rjson::value& v2) { + if (!v1) { + return false; + } + return !check_CONTAINS(v1, v2); +} + // Check if a JSON-encoded value equals any element of an array, which must have at least one element. static bool check_IN(const rjson::value* val, const rjson::value& array) { if (!array[0].IsObject() || array[0].MemberCount() != 1) { @@ -393,10 +449,14 @@ static bool verify_expected_one(const rjson::value& condition, const rjson::valu case comparison_operator_type::BETWEEN: verify_operand_count(attribute_value_list, exact_size(2), *comparison_operator); return check_BETWEEN(got, (*attribute_value_list)[0], (*attribute_value_list)[1]); - default: - // FIXME: implement all the missing types, so there will be no default here. - throw api_error("ValidationException", format("ComparisonOperator {} is not yet supported", *comparison_operator)); + case comparison_operator_type::CONTAINS: + verify_operand_count(attribute_value_list, exact_size(1), *comparison_operator); + return check_CONTAINS(got, (*attribute_value_list)[0]); + case comparison_operator_type::NOT_CONTAINS: + verify_operand_count(attribute_value_list, exact_size(1), *comparison_operator); + return check_NOT_CONTAINS(got, (*attribute_value_list)[0]); } + throw std::logic_error(format("Internal error: corrupted operator enum: {}", int(op))); } } diff --git a/alternator/conditions.hh b/alternator/conditions.hh index def0320de0..2aa845b5f6 100644 --- a/alternator/conditions.hh +++ b/alternator/conditions.hh @@ -37,7 +37,7 @@ namespace alternator { enum class comparison_operator_type { - EQ, NE, LE, LT, GE, GT, IN, BETWEEN, CONTAINS, IS_NULL, NOT_NULL, BEGINS_WITH + EQ, NE, LE, LT, GE, GT, IN, BETWEEN, CONTAINS, NOT_CONTAINS, IS_NULL, NOT_NULL, BEGINS_WITH }; comparison_operator_type get_comparison_operator(const rjson::value& comparison_operator);