From 002ac8263179c86389e96a2e7d33d8a6dd636a86 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Sun, 6 Oct 2019 07:52:45 -0700 Subject: [PATCH] S3 Select: Add parser support for lists. (#8329) --- pkg/s3select/json/record.go | 10 +- pkg/s3select/select_test.go | 220 ++++++++++++++++++++++++++++++ pkg/s3select/sql/aggregation.go | 21 ++- pkg/s3select/sql/analysis.go | 14 +- pkg/s3select/sql/evaluate.go | 139 ++++++++++++++----- pkg/s3select/sql/jsonpath.go | 35 +++-- pkg/s3select/sql/jsonpath_test.go | 2 +- pkg/s3select/sql/parser.go | 23 +++- pkg/s3select/sql/statement.go | 2 +- pkg/s3select/sql/value.go | 84 ++++++++++++ 10 files changed, 487 insertions(+), 63 deletions(-) diff --git a/pkg/s3select/json/record.go b/pkg/s3select/json/record.go index f327c28f7..473bdeba6 100644 --- a/pkg/s3select/json/record.go +++ b/pkg/s3select/json/record.go @@ -88,6 +88,8 @@ func (r *Record) Set(name string, value *sql.Value) error { v = nil } else if b, ok := value.ToBytes(); ok { v = RawJSON(b) + } else if arr, ok := value.ToArray(); ok { + v = arr } else { return fmt.Errorf("unsupported sql value %v and type %v", value, value.GetTypeString()) } @@ -109,8 +111,14 @@ func (r *Record) WriteCSV(writer io.Writer, fieldDelimiter rune) error { columnValue = "" case RawJSON: columnValue = string([]byte(val)) + case []interface{}: + b, err := json.Marshal(val) + if err != nil { + return err + } + columnValue = string(b) default: - return errors.New("Cannot marshal unhandled type") + return fmt.Errorf("Cannot marshal unhandled type: %T", kv.Value) } csvRecord = append(csvRecord, columnValue) } diff --git a/pkg/s3select/select_test.go b/pkg/s3select/select_test.go index 9d4ff417b..7467c040e 100644 --- a/pkg/s3select/select_test.go +++ b/pkg/s3select/select_test.go @@ -24,7 +24,10 @@ import ( "net/http" "os" "reflect" + "strings" "testing" + + "github.com/minio/minio-go/v6" ) type testResponseWriter struct { @@ -48,6 +51,223 @@ func (w *testResponseWriter) WriteHeader(statusCode int) { func (w *testResponseWriter) Flush() { } +func TestJSONQueries(t *testing.T) { + input := `{"id": 0,"title": "Test Record","desc": "Some text","synonyms": ["foo", "bar", "whatever"]} + {"id": 1,"title": "Second Record","desc": "another text","synonyms": ["some", "synonym", "value"]} + {"id": 2,"title": "Second Record","desc": "another text","numbers": [2, 3.0, 4]} + {"id": 3,"title": "Second Record","desc": "another text","nested": [[2, 3.0, 4], [7, 8.5, 9]]}` + var testTable = []struct { + name string + query string + requestXML []byte + wantResult string + }{ + { + name: "select-in-array-full", + query: `SELECT * from s3object s WHERE 'bar' IN s.synonyms[*]`, + wantResult: `{"id":0,"title":"Test Record","desc":"Some text","synonyms":["foo","bar","whatever"]}`, + }, + { + name: "simple-in-array", + query: `SELECT * from s3object s WHERE s.id IN (1,3)`, + wantResult: `{"id":1,"title":"Second Record","desc":"another text","synonyms":["some","synonym","value"]} +{"id":3,"title":"Second Record","desc":"another text","nested":[[2,3,4],[7,8.5,9]]}`, + }, + { + name: "select-in-array-single", + query: `SELECT synonyms from s3object s WHERE 'bar' IN s.synonyms[*] `, + wantResult: `{"synonyms":["foo","bar","whatever"]}`, + }, + { + name: "donatello-1", + query: `SELECT * from s3object s WHERE 'bar' in s.synonyms`, + wantResult: `{"id":0,"title":"Test Record","desc":"Some text","synonyms":["foo","bar","whatever"]}`, + }, + { + name: "donatello-2", + query: `SELECT * from s3object s WHERE 'bar' in s.synonyms[*]`, + wantResult: `{"id":0,"title":"Test Record","desc":"Some text","synonyms":["foo","bar","whatever"]}`, + }, + { + name: "donatello-3", + query: `SELECT * from s3object s WHERE 'value' IN s.synonyms[*]`, + wantResult: `{"id":1,"title":"Second Record","desc":"another text","synonyms":["some","synonym","value"]}`, + }, + { + name: "select-in-number", + query: `SELECT * from s3object s WHERE 4 in s.numbers[*]`, + wantResult: `{"id":2,"title":"Second Record","desc":"another text","numbers":[2,3,4]}`, + }, + { + name: "select-in-number-float", + query: `SELECT * from s3object s WHERE 3 in s.numbers[*]`, + wantResult: `{"id":2,"title":"Second Record","desc":"another text","numbers":[2,3,4]}`, + }, + { + name: "select-in-number-float-in-sql", + query: `SELECT * from s3object s WHERE 3.0 in s.numbers[*]`, + wantResult: `{"id":2,"title":"Second Record","desc":"another text","numbers":[2,3,4]}`, + }, + { + name: "select-in-list-match", + query: `SELECT * from s3object s WHERE (2,3,4) IN s.nested[*]`, + wantResult: `{"id":3,"title":"Second Record","desc":"another text","nested":[[2,3,4],[7,8.5,9]]}`, + }, + { + name: "select-in-nested-float", + query: `SELECT s.nested from s3object s WHERE 8.5 IN s.nested[*][*]`, + wantResult: `{"nested":[[2,3,4],[7,8.5,9]]}`, + }, + { + name: "select-in-combine-and", + query: `SELECT s.nested from s3object s WHERE (8.5 IN s.nested[*][*]) AND (s.id > 0)`, + wantResult: `{"nested":[[2,3,4],[7,8.5,9]]}`, + }, + { + name: "select-in-combine-and-no", + query: `SELECT s.nested from s3object s WHERE (8.5 IN s.nested[*][*]) AND (s.id = 0)`, + wantResult: ``, + }, + { + name: "select-in-nested-float-no-flat", + query: `SELECT s.nested from s3object s WHERE 8.5 IN s.nested[*]`, + wantResult: ``, + }, + { + name: "select-empty-field-result", + query: `SELECT * from s3object s WHERE s.nested[0][0] = 2`, + wantResult: `{"id":3,"title":"Second Record","desc":"another text","nested":[[2,3,4],[7,8.5,9]]}`, + }, + { + name: "select-arrays-specific", + query: `SELECT * from s3object s WHERE s.nested[1][0] = 7`, + wantResult: `{"id":3,"title":"Second Record","desc":"another text","nested":[[2,3,4],[7,8.5,9]]}`, + }, + { + name: "wrong-index-no-result", + query: `SELECT * from s3object s WHERE s.nested[0][0] = 7`, + wantResult: ``, + }, + { + name: "not-equal-result", + query: `SELECT * from s3object s WHERE s.nested[1][0] != 7`, + wantResult: `{"id":0,"title":"Test Record","desc":"Some text","synonyms":["foo","bar","whatever"]} +{"id":1,"title":"Second Record","desc":"another text","synonyms":["some","synonym","value"]} +{"id":2,"title":"Second Record","desc":"another text","numbers":[2,3,4]}`, + }, + { + name: "indexed-list-match", + query: `SELECT * from s3object s WHERE (7,8.5,9) IN s.nested[1]`, + wantResult: ``, + }, + { + name: "indexed-list-match-equals", + query: `SELECT * from s3object s WHERE (7,8.5,9) = s.nested[1]`, + wantResult: `{"id":3,"title":"Second Record","desc":"another text","nested":[[2,3,4],[7,8.5,9]]}`, + }, + { + name: "indexed-list-match-not-equals", + query: `SELECT * from s3object s WHERE (7,8.5,9) != s.nested[1]`, + wantResult: `{"id":0,"title":"Test Record","desc":"Some text","synonyms":["foo","bar","whatever"]} +{"id":1,"title":"Second Record","desc":"another text","synonyms":["some","synonym","value"]} +{"id":2,"title":"Second Record","desc":"another text","numbers":[2,3,4]}`, + }, + { + name: "index-wildcard-in", + query: `SELECT * from s3object s WHERE (8.5) IN s.nested[1][*]`, + wantResult: `{"id":3,"title":"Second Record","desc":"another text","nested":[[2,3,4],[7,8.5,9]]}`, + }, + { + name: "index-wildcard-in", + query: `SELECT * from s3object s WHERE (8.0+0.5) IN s.nested[1][*]`, + wantResult: `{"id":3,"title":"Second Record","desc":"another text","nested":[[2,3,4],[7,8.5,9]]}`, + }, + { + name: "select-output-field-as-csv", + requestXML: []byte(` + + SELECT s.synonyms from s3object s WHERE 'whatever' IN s.synonyms + SQL + + NONE + + DOCUMENT + + + + + + + + FALSE + +`), + wantResult: `"[""foo"",""bar"",""whatever""]"`, + }, + } + + defRequest := ` + + %s + SQL + + NONE + + DOCUMENT + + + + + + + + FALSE + +` + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + testReq := testCase.requestXML + if len(testReq) == 0 { + testReq = []byte(fmt.Sprintf(defRequest, testCase.query)) + } + s3Select, err := NewS3Select(bytes.NewReader(testReq)) + if err != nil { + t.Fatal(err) + } + + if err = s3Select.Open(func(offset, length int64) (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewBufferString(input)), nil + }); err != nil { + t.Fatal(err) + } + + w := &testResponseWriter{} + s3Select.Evaluate(w) + s3Select.Close() + resp := http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewReader(w.response)), + ContentLength: int64(len(w.response)), + } + res, err := minio.NewSelectResults(&resp, "testbucket") + if err != nil { + t.Error(err) + return + } + got, err := ioutil.ReadAll(res) + if err != nil { + t.Error(err) + return + } + gotS := strings.TrimSpace(string(got)) + if !reflect.DeepEqual(gotS, testCase.wantResult) { + t.Errorf("received response does not match with expected reply. Query: %s\ngot: %s\nwant:%s", testCase.query, gotS, testCase.wantResult) + } + }) + } +} + func TestCSVInput(t *testing.T) { var testTable = []struct { requestXML []byte diff --git a/pkg/s3select/sql/aggregation.go b/pkg/s3select/sql/aggregation.go index 11b36c340..2d8ed02ef 100644 --- a/pkg/s3select/sql/aggregation.go +++ b/pkg/s3select/sql/aggregation.go @@ -163,6 +163,16 @@ func (e *Expression) aggregateRow(r Record) error { return nil } +func (e *ListExpr) aggregateRow(r Record) error { + for _, ex := range e.Elements { + err := ex.aggregateRow(r) + if err != nil { + return err + } + } + return nil +} + func (e *AndCondition) aggregateRow(r Record) error { for _, ex := range e.Condition { err := ex.aggregateRow(r) @@ -200,11 +210,10 @@ func (e *ConditionOperand) aggregateRow(r Record) error { } return e.ConditionRHS.Between.End.aggregateRow(r) case e.ConditionRHS.In != nil: - for _, elt := range e.ConditionRHS.In.Expressions { - err = elt.aggregateRow(r) - if err != nil { - return err - } + elt := e.ConditionRHS.In.ListExpression + err = elt.aggregateRow(r) + if err != nil { + return err } return nil case e.ConditionRHS.Like != nil: @@ -255,6 +264,8 @@ func (e *UnaryTerm) aggregateRow(r Record) error { func (e *PrimaryTerm) aggregateRow(r Record) error { switch { + case e.ListExpr != nil: + return e.ListExpr.aggregateRow(r) case e.SubExpression != nil: return e.SubExpression.aggregateRow(r) case e.FuncCall != nil: diff --git a/pkg/s3select/sql/analysis.go b/pkg/s3select/sql/analysis.go index e0a5e8eba..3e84cad48 100644 --- a/pkg/s3select/sql/analysis.go +++ b/pkg/s3select/sql/analysis.go @@ -107,6 +107,13 @@ func (e *Condition) analyze(s *Select) (result qProp) { return } +func (e *ListExpr) analyze(s *Select) (result qProp) { + for _, ac := range e.Elements { + result.combine(ac.analyze(s)) + } + return +} + func (e *ConditionOperand) analyze(s *Select) (result qProp) { if e.ConditionRHS == nil { result = e.Operand.analyze(s) @@ -125,9 +132,7 @@ func (e *ConditionRHS) analyze(s *Select) (result qProp) { result.combine(e.Between.Start.analyze(s)) result.combine(e.Between.End.analyze(s)) case e.In != nil: - for _, elt := range e.In.Expressions { - result.combine(elt.analyze(s)) - } + result.combine(e.In.ListExpression.analyze(s)) case e.Like != nil: result.combine(e.Like.Pattern.analyze(s)) if e.Like.EscapeChar != nil { @@ -179,6 +184,9 @@ func (e *PrimaryTerm) analyze(s *Select) (result qProp) { } result = qProp{isRowFunc: true} + case e.ListExpr != nil: + result = e.ListExpr.analyze(s) + case e.SubExpression != nil: result = e.SubExpression.analyze(s) diff --git a/pkg/s3select/sql/evaluate.go b/pkg/s3select/sql/evaluate.go index 5b2d927b7..1a846aa38 100644 --- a/pkg/s3select/sql/evaluate.go +++ b/pkg/s3select/sql/evaluate.go @@ -19,6 +19,7 @@ package sql import ( "encoding/json" "errors" + "fmt" "strings" "github.com/bcicen/jstream" @@ -227,27 +228,73 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) { return FromBool(matchResult), nil } -func (e *In) evalInNode(r Record, arg *Value) (*Value, error) { - result := false - for _, elt := range e.Expressions { +func (e *ListExpr) evalNode(r Record) (*Value, error) { + res := make([]Value, len(e.Elements)) + if len(e.Elements) == 1 { + // If length 1, treat as single value. + return e.Elements[0].evalNode(r) + } + for i, elt := range e.Elements { + v, err := elt.evalNode(r) + if err != nil { + return nil, err + } + res[i] = *v + } + return FromArray(res), nil +} + +func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) { + // Compare two values in terms of in-ness. + var cmp func(a, b Value) bool + cmp = func(a, b Value) bool { + if a.Equals(b) { + return true + } + + // If elements, compare each. + aA, aOK := a.ToArray() + bA, bOK := b.ToArray() + if aOK && bOK { + if len(aA) != len(bA) { + return false + } + for i := range aA { + if !cmp(aA[i], bA[i]) { + return false + } + } + return true + } + // Try as numbers + aF, aOK := a.ToFloat() + bF, bOK := b.ToFloat() + + // FIXME: more type inference? + return aOK && bOK && aF == bF + } + + var rhs Value + if elt := e.ListExpression; elt != nil { eltVal, err := elt.evalNode(r) if err != nil { return nil, err } - - // FIXME: type inference? - - // Types must match. - if !arg.SameTypeAs(*eltVal) { - // match failed. - continue - } - if arg.Equals(*eltVal) { - result = true - break - } + rhs = *eltVal } - return FromBool(result), nil + + // If RHS is array compare each element. + if arr, ok := rhs.ToArray(); ok { + for _, element := range arr { + // If we have an array we are on the wrong level. + if cmp(element, *lhs) { + return FromBool(true), nil + } + } + return FromBool(false), nil + } + + return FromBool(cmp(rhs, *lhs)), nil } func (e *Operand) evalNode(r Record) (*Value, error) { @@ -333,42 +380,60 @@ func (e *JSONPath) evalNode(r Record) (*Value, error) { pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}} } - result, err := jsonpathEval(pathExpr, rowVal) + result, _, err := jsonpathEval(pathExpr, rowVal) if err != nil { return nil, err } - switch rval := result.(type) { - case string: - return FromString(rval), nil - case float64: - return FromFloat(rval), nil - case int64: - return FromInt(rval), nil - case bool: - return FromBool(rval), nil - case jstream.KVS, []interface{}: - bs, err := json.Marshal(result) - if err != nil { - return nil, err - } - return FromBytes(bs), nil - case nil: - return FromNull(), nil - default: - return nil, errors.New("Unhandled value type") - } + return jsonToValue(result) default: return r.Get(keypath) } } +// jsonToValue will convert the json value to an internal value. +func jsonToValue(result interface{}) (*Value, error) { + switch rval := result.(type) { + case string: + return FromString(rval), nil + case float64: + return FromFloat(rval), nil + case int64: + return FromInt(rval), nil + case bool: + return FromBool(rval), nil + case jstream.KVS: + bs, err := json.Marshal(result) + if err != nil { + return nil, err + } + return FromBytes(bs), nil + case []interface{}: + dst := make([]Value, len(rval)) + for i := range rval { + v, err := jsonToValue(rval[i]) + if err != nil { + return nil, err + } + dst[i] = *v + } + return FromArray(dst), nil + case []Value: + return FromArray(rval), nil + case nil: + return FromNull(), nil + } + return nil, fmt.Errorf("Unhandled value type: %T", result) +} + func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) { switch { case e.Value != nil: return e.Value.evalNode(r) case e.JPathExpr != nil: return e.JPathExpr.evalNode(r) + case e.ListExpr != nil: + return e.ListExpr.evalNode(r) case e.SubExpression != nil: return e.SubExpression.evalNode(r) case e.FuncCall != nil: diff --git a/pkg/s3select/sql/jsonpath.go b/pkg/s3select/sql/jsonpath.go index 88958576f..a2570f432 100644 --- a/pkg/s3select/sql/jsonpath.go +++ b/pkg/s3select/sql/jsonpath.go @@ -30,10 +30,12 @@ var ( errWilcardObjectUsageInvalid = errors.New("Invalid usage of object wildcard") ) -func jsonpathEval(p []*JSONPathElement, v interface{}) (r interface{}, err error) { +// jsonpathEval evaluates a JSON path and returns the value at the path. +// If the value should be considered flat (from wildcards) any array returned should be considered individual values. +func jsonpathEval(p []*JSONPathElement, v interface{}) (r interface{}, flat bool, err error) { // fmt.Printf("JPATHexpr: %v jsonobj: %v\n\n", p, v) if len(p) == 0 || v == nil { - return v, nil + return v, false, nil } switch { @@ -42,7 +44,7 @@ func jsonpathEval(p []*JSONPathElement, v interface{}) (r interface{}, err error kvs, ok := v.(jstream.KVS) if !ok { - return nil, errKeyLookup + return nil, false, errKeyLookup } for _, kv := range kvs { if kv.Key == key { @@ -50,51 +52,58 @@ func jsonpathEval(p []*JSONPathElement, v interface{}) (r interface{}, err error } } // Key not found - return nil result - return nil, nil + return nil, false, nil case p[0].Index != nil: idx := *p[0].Index arr, ok := v.([]interface{}) if !ok { - return nil, errIndexLookup + return nil, false, errIndexLookup } if idx >= len(arr) { - return nil, nil + return nil, false, nil } return jsonpathEval(p[1:], arr[idx]) case p[0].ObjectWildcard: kvs, ok := v.(jstream.KVS) if !ok { - return nil, errWildcardObjectLookup + return nil, false, errWildcardObjectLookup } if len(p[1:]) > 0 { - return nil, errWilcardObjectUsageInvalid + return nil, false, errWilcardObjectUsageInvalid } - return kvs, nil + return kvs, false, nil case p[0].ArrayWildcard: arr, ok := v.([]interface{}) if !ok { - return nil, errWildcardArrayLookup + return nil, false, errWildcardArrayLookup } // Lookup remainder of path in each array element and // make result array. var result []interface{} for _, a := range arr { - rval, err := jsonpathEval(p[1:], a) + rval, flatten, err := jsonpathEval(p[1:], a) if err != nil { - return nil, err + return nil, false, err } + if flatten { + // Flatten if array. + if arr, ok := rval.([]interface{}); ok { + result = append(result, arr...) + continue + } + } result = append(result, rval) } - return result, nil + return result, true, nil } panic("cannot reach here") } diff --git a/pkg/s3select/sql/jsonpath_test.go b/pkg/s3select/sql/jsonpath_test.go index 3ce571451..53fcd136c 100644 --- a/pkg/s3select/sql/jsonpath_test.go +++ b/pkg/s3select/sql/jsonpath_test.go @@ -83,7 +83,7 @@ func TestJsonpathEval(t *testing.T) { for j, rec := range recs { // fmt.Println(rec) - r, err := jsonpathEval(jp.PathExpr, rec) + r, _, err := jsonpathEval(jp.PathExpr, rec) if err != nil { t.Errorf("Error: %d %d %v", i, j, err) } diff --git a/pkg/s3select/sql/parser.go b/pkg/s3select/sql/parser.go index 041632189..d8d5c21ab 100644 --- a/pkg/s3select/sql/parser.go +++ b/pkg/s3select/sql/parser.go @@ -47,6 +47,19 @@ func (ls *LiteralString) Capture(values []string) error { return nil } +// LiteralList is a type for parsed SQL lists literals +type LiteralList []string + +// Capture interface used by participle +func (ls *LiteralList) Capture(values []string) error { + // Remove enclosing parenthesis. + n := len(values[0]) + r := values[0][1 : n-1] + // Translate doubled quotes + *ls = LiteralList(strings.Split(r, ",")) + return nil +} + // ObjectKey is a type for parsed strings occurring in key paths type ObjectKey struct { Lit *LiteralString `parser:" \"[\" @LitString \"]\""` @@ -134,6 +147,11 @@ type Expression struct { And []*AndCondition `parser:"@@ ( \"OR\" @@ )*"` } +// ListExpr represents a literal list with elements as expressions. +type ListExpr struct { + Elements []*Expression `parser:"\"(\" @@ ( \",\" @@ )* \")\""` +} + // AndCondition represents logical conjunction of clauses type AndCondition struct { Condition []*Condition `parser:"@@ ( \"AND\" @@ )*"` @@ -157,7 +175,7 @@ type ConditionOperand struct { type ConditionRHS struct { Compare *Compare `parser:" @@"` Between *Between `parser:"| @@"` - In *In `parser:"| \"IN\" \"(\" @@ \")\""` + In *In `parser:"| \"IN\" @@"` Like *Like `parser:"| @@"` } @@ -183,7 +201,7 @@ type Between struct { // In represents the RHS of an IN expression type In struct { - Expressions []*Expression `parser:"@@ ( \",\" @@ )*"` + ListExpression *Expression `parser:"@@ "` } // Grammar for Operand: @@ -236,6 +254,7 @@ type NegatedTerm struct { type PrimaryTerm struct { Value *LitValue `parser:" @@"` JPathExpr *JSONPath `parser:"| @@"` + ListExpr *ListExpr `parser:"| @@"` SubExpression *Expression `parser:"| \"(\" @@ \")\""` // Include function expressions here. FuncCall *FuncExpr `parser:"| @@"` diff --git a/pkg/s3select/sql/statement.go b/pkg/s3select/sql/statement.go index 23f2274ef..57814dfd8 100644 --- a/pkg/s3select/sql/statement.go +++ b/pkg/s3select/sql/statement.go @@ -133,7 +133,7 @@ func (e *SelectStatement) EvalFrom(format string, input Record) (Record, error) } jsonRec := rawVal.(jstream.KVS) - txedRec, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], jsonRec) + txedRec, _, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], jsonRec) if err != nil { return nil, err } diff --git a/pkg/s3select/sql/value.go b/pkg/s3select/sql/value.go index 93d4933ca..1b58735c8 100644 --- a/pkg/s3select/sql/value.go +++ b/pkg/s3select/sql/value.go @@ -17,6 +17,7 @@ package sql import ( + "encoding/json" "errors" "fmt" "math" @@ -46,6 +47,14 @@ type Value struct { value interface{} } +// MarshalJSON provides json marshaling of values. +func (v Value) MarshalJSON() ([]byte, error) { + if b, ok := v.ToBytes(); ok { + return b, nil + } + return json.Marshal(v.value) +} + // GetTypeString returns a string representation for vType func (v Value) GetTypeString() string { switch v.value.(type) { @@ -63,6 +72,8 @@ func (v Value) GetTypeString() string { return "TIMESTAMP" case []byte: return "BYTES" + case []Value: + return "ARRAY" } return "--" } @@ -80,6 +91,17 @@ func (v Value) Repr() string { return fmt.Sprintf("\"%s\":%s", x, v.GetTypeString()) case []byte: return fmt.Sprintf("\"%s\":BYTES", string(x)) + case []Value: + var s strings.Builder + s.WriteByte('[') + for i, v := range x { + s.WriteString(v.Repr()) + if i < len(x)-1 { + s.WriteByte(',') + } + } + s.WriteString("]:ARRAY") + return s.String() default: return fmt.Sprintf("%v:INVALID", v.value) } @@ -120,6 +142,11 @@ func FromBytes(b []byte) *Value { return &Value{value: b} } +// FromArray creates a Value from an array of values. +func FromArray(a []Value) *Value { + return &Value{value: a} +} + // ToFloat works for int and float values func (v Value) ToFloat() (val float64, ok bool) { switch x := v.value.(type) { @@ -167,6 +194,8 @@ func (v Value) SameTypeAs(b Value) (ok bool) { _, ok = b.value.(time.Time) case []byte: _, ok = b.value.([]byte) + case []Value: + _, ok = b.value.([]Value) default: ok = reflect.TypeOf(v.value) == reflect.TypeOf(b.value) } @@ -192,6 +221,12 @@ func (v Value) ToBytes() (val []byte, ok bool) { return } +// ToArray returns the value if it is a slice of values. +func (v Value) ToArray() (val []Value, ok bool) { + val, ok = v.value.([]Value) + return +} + // IsNull - checks if value is missing. func (v Value) IsNull() bool { switch v.value.(type) { @@ -201,6 +236,12 @@ func (v Value) IsNull() bool { return false } +// IsArray returns whether the value is an array. +func (v Value) IsArray() (ok bool) { + _, ok = v.value.([]Value) + return ok +} + func (v Value) isNumeric() bool { switch v.value.(type) { case int64, float64: @@ -255,6 +296,10 @@ func (v Value) CSVString() string { return FormatSQLTimestamp(x) case []byte: return string(x) + case []Value: + b, _ := json.Marshal(x) + return string(b) + default: return "CSV serialization not implemented for this type" } @@ -311,6 +356,19 @@ func (v *Value) compareOp(op string, a *Value) (res bool, err error) { return false, err } + // Check if either is nil + if v.IsNull() || a.IsNull() { + // If one is, both must be. + return boolCompare(op, v.IsNull(), a.IsNull()) + } + + // Check array values + aArr, aOK := a.ToArray() + vArr, vOK := v.ToArray() + if aOK && vOK { + return arrayCompare(op, aArr, vArr) + } + isNumeric := v.isNumeric() && a.isNumeric() if isNumeric { intV, ok1i := v.ToInt() @@ -725,6 +783,32 @@ func boolCompare(op string, left, right bool) (bool, error) { } } +func arrayCompare(op string, left, right []Value) (bool, error) { + switch op { + case opEq: + if len(left) != len(right) { + return false, nil + } + for i, l := range left { + eq, err := l.compareOp(op, &right[i]) + if !eq || err != nil { + return eq, err + } + } + return true, nil + case opIneq: + for i, l := range left { + eq, err := l.compareOp(op, &right[i]) + if eq || err != nil { + return eq, err + } + } + return false, nil + default: + return false, errCmpInvalidBoolOperator + } +} + func timestampCompare(op string, left, right time.Time) bool { switch op { case opLt: