Compare commits

...

16 Commits

Author SHA1 Message Date
Andreas Auernhammer
404d2ebe3f set SSE headers in put-part response (#12008)
This commit fixes a bug in the put-part
implementation. The SSE headers should be
set as specified by AWS - See:
https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPart.html

Now, the MinIO server should set SSE-C headers,
like `x-amz-server-side-encryption-customer-algorithm`.

Fixes #11991
2021-04-07 14:50:28 -07:00
Minio Trusted
46964eb764 Update yaml files to latest version RELEASE.2021-04-06T23-11-00Z 2021-04-06 23:35:33 +00:00
Poorna Krishnamoorthy
bfab990c33 Improve error message from SetRemoteTargetHandler (#11909) 2021-04-06 12:42:30 -07:00
Harshavardhana
94018588fe unmarshal both LegalHold and ObjectLockLegalHold XML types (#11921)
Because of silly AWS S3 behavior we to handle both types.

fixes #11920
2021-04-06 12:41:56 -07:00
Anis Elleuch
8b76ba8d5d crawling: Apply lifecycle then decide healing action (#11563)
It is inefficient to decide to heal an object before checking its
lifecycle for expiration or transition. This commit will just reverse
the order of action: evaluate lifecycle and heal only if asked and
lifecycle resulted a NoneAction.
2021-04-06 12:41:51 -07:00
Harshavardhana
7eb7f65e48 add policy conditions support for signatureVersion and authType (#11947)
https://docs.aws.amazon.com/AmazonS3/latest/API/bucket-policy-s3-sigv4-conditions.html

fixes #11944
2021-04-06 12:41:31 -07:00
Harshavardhana
c608c0688a fix: properly close leaking bandwidth monitor channel (#11967)
This PR fixes

- close leaking bandwidth report channel leakage
- remove the closer requirement for bandwidth monitor
  instead if Read() fails remember the error and return
  error for all subsequent reads.
- use locking for usage-cache.bin updates, with inline
  data we cannot afford to have concurrent writes to
  usage-cache.bin corrupting xl.meta
2021-04-06 12:40:42 -07:00
Aditya Manthramurthy
41a9d1d778 Fix S3Select SQL column reference handling (#11957)
This change fixes handling of these types of queries:

- Double quoted column names with special characters:
    SELECT "column.name" FROM s3object
- Double quoted column names with reserved keywords:
    SELECT "CAST" FROM s3object
- Table name as prefix for column names:
    SELECT S3Object."CAST" FROM s3object
2021-04-06 12:40:28 -07:00
Klaus Post
e21e80841e Fix data race when connecting disks (#11983)
Multiple disks from the same set would be writing concurrently.

```
WARNING: DATA RACE
Write at 0x00c002100ce0 by goroutine 166:
  github.com/minio/minio/cmd.(*erasureSets).connectDisks.func1()
      d:/minio/minio/cmd/erasure-sets.go:254 +0x82f

Previous write at 0x00c002100ce0 by goroutine 129:
  github.com/minio/minio/cmd.(*erasureSets).connectDisks.func1()
      d:/minio/minio/cmd/erasure-sets.go:254 +0x82f

Goroutine 166 (running) created at:
  github.com/minio/minio/cmd.(*erasureSets).connectDisks()
      d:/minio/minio/cmd/erasure-sets.go:210 +0x324
  github.com/minio/minio/cmd.(*erasureSets).monitorAndConnectEndpoints()
      d:/minio/minio/cmd/erasure-sets.go:288 +0x244

Goroutine 129 (finished) created at:
  github.com/minio/minio/cmd.(*erasureSets).connectDisks()
      d:/minio/minio/cmd/erasure-sets.go:210 +0x324
  github.com/minio/minio/cmd.(*erasureSets).monitorAndConnectEndpoints()
      d:/minio/minio/cmd/erasure-sets.go:288 +0x244
```
2021-04-06 12:39:59 -07:00
Klaus Post
98c792bbeb Fix disk info race (#11984)
Protect updated members in xlStorage.

```
WARNING: DATA RACE
Write at 0x00c004b4ee78 by goroutine 1491:
  github.com/minio/minio/cmd.(*xlStorage).GetDiskID()
      d:/minio/minio/cmd/xl-storage.go:590 +0x1078
  github.com/minio/minio/cmd.(*xlStorageDiskIDCheck).checkDiskStale()
      d:/minio/minio/cmd/xl-storage-disk-id-check.go:195 +0x84
  github.com/minio/minio/cmd.(*xlStorageDiskIDCheck).StatVol()
      d:/minio/minio/cmd/xl-storage-disk-id-check.go:284 +0x16a
  github.com/minio/minio/cmd.erasureObjects.getBucketInfo.func1()
      d:/minio/minio/cmd/erasure-bucket.go:100 +0x1a5
  github.com/minio/minio/pkg/sync/errgroup.(*Group).Go.func1()
      d:/minio/minio/pkg/sync/errgroup/errgroup.go:122 +0xd7

Previous read at 0x00c004b4ee78 by goroutine 1087:
  github.com/minio/minio/cmd.(*xlStorage).CheckFile.func1()
      d:/minio/minio/cmd/xl-storage.go:1699 +0x384
  github.com/minio/minio/cmd.(*xlStorage).CheckFile()
      d:/minio/minio/cmd/xl-storage.go:1726 +0x13c
  github.com/minio/minio/cmd.(*xlStorageDiskIDCheck).CheckFile()
      d:/minio/minio/cmd/xl-storage-disk-id-check.go:446 +0x23b
  github.com/minio/minio/cmd.erasureObjects.parentDirIsObject.func1()
      d:/minio/minio/cmd/erasure-common.go:173 +0x194
  github.com/minio/minio/pkg/sync/errgroup.(*Group).Go.func1()
      d:/minio/minio/pkg/sync/errgroup/errgroup.go:122 +0xd7
```
2021-04-06 12:39:57 -07:00
Klaus Post
f687ba53bc Fix Access Key requests (#11979)
Fix accessing claims when auth error is unchecked.

Only replaced when unchecked and when clearly without side effects.

Fixes #11959
2021-04-06 11:03:55 -07:00
Harshavardhana
e3da59c923 fix possible crash in bucket bandwidth monitor (#11986) 2021-04-06 11:03:41 -07:00
Harshavardhana
781b9b051c fix: service accounts policy enforcement regression (#11910)
service accounts were not inheriting parent policies
anymore due to refactors in the PolicyDBGet() from
the latest release, fix this behavior properly.
2021-04-06 08:58:05 -07:00
Harshavardhana
438becfde8 fix: delete/delete marker replication versions consistent (#11932)
replication didn't work as expected when deletion of
delete markers was requested in DeleteMultipleObjects
API, this is due to incorrect lookup elements being
used to look for delete markers.
2021-04-06 08:57:36 -07:00
Harshavardhana
16ef338649 fix: notify parent user in notification events (#11934)
fixes #11885
2021-04-06 08:55:37 -07:00
Harshavardhana
3242847ec0 avoid network read errors crashing CreateFile call (#11939)
Thanks to @dvaldivia for reproducing this
2021-04-06 08:55:30 -07:00
37 changed files with 602 additions and 274 deletions

View File

@@ -172,7 +172,12 @@ func (a adminAPIHandlers) SetRemoteTargetHandler(w http.ResponseWriter, r *http.
} }
if err = globalBucketTargetSys.SetTarget(ctx, bucket, &target, update); err != nil { if err = globalBucketTargetSys.SetTarget(ctx, bucket, &target, update); err != nil {
writeErrorResponseJSON(ctx, w, toAPIError(ctx, err), r.URL) switch err.(type) {
case BucketRemoteConnectionErr:
writeErrorResponseJSON(ctx, w, errorCodes.ToAPIErrWithErr(ErrReplicationRemoteConnectionError, err), r.URL)
default:
writeErrorResponseJSON(ctx, w, toAPIError(ctx, err), r.URL)
}
return return
} }
targets, err := globalBucketTargetSys.ListBucketTargets(ctx, bucket) targets, err := globalBucketTargetSys.ListBucketTargets(ctx, bucket)

View File

@@ -24,6 +24,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math/rand"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@@ -1470,30 +1471,33 @@ func (a adminAPIHandlers) BandwidthMonitorHandler(w http.ResponseWriter, r *http
return return
} }
rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
setEventStreamHeaders(w) setEventStreamHeaders(w)
reportCh := make(chan bandwidth.Report, 1) reportCh := make(chan bandwidth.Report)
keepAliveTicker := time.NewTicker(500 * time.Millisecond) keepAliveTicker := time.NewTicker(500 * time.Millisecond)
defer keepAliveTicker.Stop() defer keepAliveTicker.Stop()
bucketsRequestedString := r.URL.Query().Get("buckets") bucketsRequestedString := r.URL.Query().Get("buckets")
bucketsRequested := strings.Split(bucketsRequestedString, ",") bucketsRequested := strings.Split(bucketsRequestedString, ",")
go func() { go func() {
defer close(reportCh)
for { for {
reportCh <- globalNotificationSys.GetBandwidthReports(ctx, bucketsRequested...)
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
default: case reportCh <- globalNotificationSys.GetBandwidthReports(ctx, bucketsRequested...):
time.Sleep(2 * time.Second) time.Sleep(time.Duration(rnd.Float64() * float64(2*time.Second)))
} }
} }
}() }()
for { for {
select { select {
case report := <-reportCh: case report, ok := <-reportCh:
enc := json.NewEncoder(w) if !ok {
err := enc.Encode(report) return
if err != nil { }
writeErrorResponseJSON(ctx, w, errorCodes.ToAPIErr(ErrInternalError), r.URL) if err := json.NewEncoder(w).Encode(report); err != nil {
writeErrorResponseJSON(ctx, w, toAPIError(ctx, err), r.URL)
return return
} }
w.(http.Flusher).Flush() w.(http.Flusher).Flush()

View File

@@ -496,7 +496,7 @@ func (api objectAPIHandlers) DeleteMultipleObjectsHandler(w http.ResponseWriter,
object.PurgeTransitioned = goi.TransitionStatus object.PurgeTransitioned = goi.TransitionStatus
} }
if replicateDeletes { if replicateDeletes {
delMarker, replicate, repsync := checkReplicateDelete(ctx, bucket, ObjectToDelete{ replicate, repsync := checkReplicateDelete(ctx, bucket, ObjectToDelete{
ObjectName: object.ObjectName, ObjectName: object.ObjectName,
VersionID: object.VersionID, VersionID: object.VersionID,
}, goi, gerr) }, goi, gerr)
@@ -511,9 +511,6 @@ func (api objectAPIHandlers) DeleteMultipleObjectsHandler(w http.ResponseWriter,
} }
if object.VersionID != "" { if object.VersionID != "" {
object.VersionPurgeStatus = Pending object.VersionPurgeStatus = Pending
if delMarker {
object.DeleteMarkerVersionID = object.VersionID
}
} else { } else {
object.DeleteMarkerReplicationStatus = string(replication.Pending) object.DeleteMarkerReplicationStatus = string(replication.Pending)
} }
@@ -557,13 +554,18 @@ func (api objectAPIHandlers) DeleteMultipleObjectsHandler(w http.ResponseWriter,
}) })
deletedObjects := make([]DeletedObject, len(deleteObjects.Objects)) deletedObjects := make([]DeletedObject, len(deleteObjects.Objects))
for i := range errs { for i := range errs {
dindex := objectsToDelete[ObjectToDelete{ // DeleteMarkerVersionID is not used specifically to avoid
// lookup errors, since DeleteMarkerVersionID is only
// created during DeleteMarker creation when client didn't
// specify a versionID.
objToDel := ObjectToDelete{
ObjectName: dObjects[i].ObjectName, ObjectName: dObjects[i].ObjectName,
VersionID: dObjects[i].VersionID, VersionID: dObjects[i].VersionID,
VersionPurgeStatus: dObjects[i].VersionPurgeStatus, VersionPurgeStatus: dObjects[i].VersionPurgeStatus,
DeleteMarkerReplicationStatus: dObjects[i].DeleteMarkerReplicationStatus, DeleteMarkerReplicationStatus: dObjects[i].DeleteMarkerReplicationStatus,
PurgeTransitioned: dObjects[i].PurgeTransitioned, PurgeTransitioned: dObjects[i].PurgeTransitioned,
}] }
dindex := objectsToDelete[objToDel]
if errs[i] == nil || isErrObjectNotFound(errs[i]) || isErrVersionNotFound(errs[i]) { if errs[i] == nil || isErrObjectNotFound(errs[i]) || isErrVersionNotFound(errs[i]) {
if replicateDeletes { if replicateDeletes {
dObjects[i].DeleteMarkerReplicationStatus = deleteList[i].DeleteMarkerReplicationStatus dObjects[i].DeleteMarkerReplicationStatus = deleteList[i].DeleteMarkerReplicationStatus
@@ -619,12 +621,12 @@ func (api objectAPIHandlers) DeleteMultipleObjectsHandler(w http.ResponseWriter,
eventName := event.ObjectRemovedDelete eventName := event.ObjectRemovedDelete
objInfo := ObjectInfo{ objInfo := ObjectInfo{
Name: dobj.ObjectName, Name: dobj.ObjectName,
VersionID: dobj.VersionID, VersionID: dobj.VersionID,
DeleteMarker: dobj.DeleteMarker,
} }
if dobj.DeleteMarker { if objInfo.DeleteMarker {
objInfo.DeleteMarker = dobj.DeleteMarker
objInfo.VersionID = dobj.DeleteMarkerVersionID objInfo.VersionID = dobj.DeleteMarkerVersionID
eventName = event.ObjectRemovedDeleteMarkerCreated eventName = event.ObjectRemovedDeleteMarkerCreated
} }

View File

@@ -83,17 +83,38 @@ func getConditionValues(r *http.Request, lc string, username string, claims map[
} }
} }
authType := getRequestAuthType(r)
var signatureVersion string
switch authType {
case authTypeSignedV2, authTypePresignedV2:
signatureVersion = signV2Algorithm
case authTypeSigned, authTypePresigned, authTypeStreamingSigned, authTypePostPolicy:
signatureVersion = signV4Algorithm
}
var authtype string
switch authType {
case authTypePresignedV2, authTypePresigned:
authtype = "REST-QUERY-STRING"
case authTypeSignedV2, authTypeSigned, authTypeStreamingSigned:
authtype = "REST-HEADER"
case authTypePostPolicy:
authtype = "POST"
}
args := map[string][]string{ args := map[string][]string{
"CurrentTime": {currTime.Format(time.RFC3339)}, "CurrentTime": {currTime.Format(time.RFC3339)},
"EpochTime": {strconv.FormatInt(currTime.Unix(), 10)}, "EpochTime": {strconv.FormatInt(currTime.Unix(), 10)},
"SecureTransport": {strconv.FormatBool(r.TLS != nil)}, "SecureTransport": {strconv.FormatBool(r.TLS != nil)},
"SourceIp": {handlers.GetSourceIP(r)}, "SourceIp": {handlers.GetSourceIP(r)},
"UserAgent": {r.UserAgent()}, "UserAgent": {r.UserAgent()},
"Referer": {r.Referer()}, "Referer": {r.Referer()},
"principaltype": {principalType}, "principaltype": {principalType},
"userid": {username}, "userid": {username},
"username": {username}, "username": {username},
"versionid": {vid}, "versionid": {vid},
"signatureversion": {signatureVersion},
"authType": {authtype},
} }
if lc != "" { if lc != "" {

View File

@@ -175,10 +175,10 @@ func isStandardHeader(matchHeaderKey string) bool {
} }
// returns whether object version is a deletemarker and if object qualifies for replication // returns whether object version is a deletemarker and if object qualifies for replication
func checkReplicateDelete(ctx context.Context, bucket string, dobj ObjectToDelete, oi ObjectInfo, gerr error) (dm, replicate, sync bool) { func checkReplicateDelete(ctx context.Context, bucket string, dobj ObjectToDelete, oi ObjectInfo, gerr error) (replicate, sync bool) {
rcfg, err := getReplicationConfig(ctx, bucket) rcfg, err := getReplicationConfig(ctx, bucket)
if err != nil || rcfg == nil { if err != nil || rcfg == nil {
return false, false, sync return false, sync
} }
opts := replication.ObjectOpts{ opts := replication.ObjectOpts{
Name: dobj.ObjectName, Name: dobj.ObjectName,
@@ -198,19 +198,19 @@ func checkReplicateDelete(ctx context.Context, bucket string, dobj ObjectToDelet
validReplStatus = true validReplStatus = true
} }
if oi.DeleteMarker && (validReplStatus || replicate) { if oi.DeleteMarker && (validReplStatus || replicate) {
return oi.DeleteMarker, true, sync return true, sync
} }
// can be the case that other cluster is down and duplicate `mc rm --vid` // can be the case that other cluster is down and duplicate `mc rm --vid`
// is issued - this still needs to be replicated back to the other target // is issued - this still needs to be replicated back to the other target
return oi.DeleteMarker, oi.VersionPurgeStatus == Pending || oi.VersionPurgeStatus == Failed, sync return oi.VersionPurgeStatus == Pending || oi.VersionPurgeStatus == Failed, sync
} }
tgt := globalBucketTargetSys.GetRemoteTargetClient(ctx, rcfg.RoleArn) tgt := globalBucketTargetSys.GetRemoteTargetClient(ctx, rcfg.RoleArn)
// the target online status should not be used here while deciding // the target online status should not be used here while deciding
// whether to replicate deletes as the target could be temporarily down // whether to replicate deletes as the target could be temporarily down
if tgt == nil { if tgt == nil {
return oi.DeleteMarker, false, false return false, false
} }
return oi.DeleteMarker, replicate, tgt.replicateSync return replicate, tgt.replicateSync
} }
// replicate deletes to the designated replication target if replication configuration // replicate deletes to the designated replication target if replication configuration
@@ -697,19 +697,25 @@ func replicateObject(ctx context.Context, objInfo ObjectInfo, objectAPI ObjectLa
if totalNodesCount == 0 { if totalNodesCount == 0 {
totalNodesCount = 1 // For standalone erasure coding totalNodesCount = 1 // For standalone erasure coding
} }
b := target.BandwidthLimit / int64(totalNodesCount)
var headerSize int var headerSize int
for k, v := range putOpts.Header() { for k, v := range putOpts.Header() {
headerSize += len(k) + len(v) headerSize += len(k) + len(v)
} }
// r takes over closing gr. opts := &bandwidth.MonitorReaderOptions{
r := bandwidth.NewMonitoredReader(ctx, globalBucketMonitor, objInfo.Bucket, objInfo.Name, gr, headerSize, b, target.BandwidthLimit) Bucket: objInfo.Bucket,
Object: objInfo.Name,
HeaderSize: headerSize,
BandwidthBytesPerSec: target.BandwidthLimit / int64(totalNodesCount),
ClusterBandwidth: target.BandwidthLimit,
}
r := bandwidth.NewMonitoredReader(ctx, globalBucketMonitor, gr, opts)
if _, err = c.PutObject(ctx, dest.Bucket, object, r, size, "", "", putOpts); err != nil { if _, err = c.PutObject(ctx, dest.Bucket, object, r, size, "", "", putOpts); err != nil {
replicationStatus = replication.Failed replicationStatus = replication.Failed
logger.LogIf(ctx, fmt.Errorf("Unable to replicate for object %s/%s(%s): %w", bucket, objInfo.Name, objInfo.VersionID, err)) logger.LogIf(ctx, fmt.Errorf("Unable to replicate for object %s/%s(%s): %w", bucket, objInfo.Name, objInfo.VersionID, err))
} }
defer r.Close()
} }
objInfo.UserDefined[xhttp.AmzBucketReplicationStatus] = replicationStatus.String() objInfo.UserDefined[xhttp.AmzBucketReplicationStatus] = replicationStatus.String()

View File

@@ -100,7 +100,7 @@ func (sys *BucketTargetSys) SetTarget(ctx context.Context, bucket string, tgt *m
if minio.ToErrorResponse(err).Code == "NoSuchBucket" { if minio.ToErrorResponse(err).Code == "NoSuchBucket" {
return BucketRemoteTargetNotFound{Bucket: tgt.TargetBucket} return BucketRemoteTargetNotFound{Bucket: tgt.TargetBucket}
} }
return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket} return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket, Err: err}
} }
if tgt.Type == madmin.ReplicationService { if tgt.Type == madmin.ReplicationService {
if !globalIsErasure { if !globalIsErasure {
@@ -111,7 +111,7 @@ func (sys *BucketTargetSys) SetTarget(ctx context.Context, bucket string, tgt *m
} }
vcfg, err := clnt.GetBucketVersioning(ctx, tgt.TargetBucket) vcfg, err := clnt.GetBucketVersioning(ctx, tgt.TargetBucket)
if err != nil { if err != nil {
return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket} return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket, Err: err}
} }
if vcfg.Status != string(versioning.Enabled) { if vcfg.Status != string(versioning.Enabled) {
return BucketRemoteTargetNotVersioned{Bucket: tgt.TargetBucket} return BucketRemoteTargetNotVersioned{Bucket: tgt.TargetBucket}
@@ -124,7 +124,7 @@ func (sys *BucketTargetSys) SetTarget(ctx context.Context, bucket string, tgt *m
if minio.ToErrorResponse(err).Code == "NoSuchBucket" { if minio.ToErrorResponse(err).Code == "NoSuchBucket" {
return BucketRemoteTargetNotFound{Bucket: tgt.TargetBucket} return BucketRemoteTargetNotFound{Bucket: tgt.TargetBucket}
} }
return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket} return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket, Err: err}
} }
if vcfg.Status != string(versioning.Enabled) { if vcfg.Status != string(versioning.Enabled) {
return BucketRemoteTargetNotVersioned{Bucket: tgt.TargetBucket} return BucketRemoteTargetNotVersioned{Bucket: tgt.TargetBucket}

View File

@@ -797,42 +797,39 @@ type actionMeta struct {
var applyActionsLogPrefix = color.Green("applyActions:") var applyActionsLogPrefix = color.Green("applyActions:")
// applyActions will apply lifecycle checks on to a scanned item. func (i *scannerItem) applyHealing(ctx context.Context, o ObjectLayer, meta actionMeta) (size int64) {
// The resulting size on disk will always be returned. if i.debug {
// The metadata will be compared to consensus on the object layer before any changes are applied. if meta.oi.VersionID != "" {
// If no metadata is supplied, -1 is returned if no action is taken. console.Debugf(applyActionsLogPrefix+" heal checking: %v/%v v(%s)\n", i.bucket, i.objectPath(), meta.oi.VersionID)
func (i *scannerItem) applyActions(ctx context.Context, o ObjectLayer, meta actionMeta) (size int64) { } else {
console.Debugf(applyActionsLogPrefix+" heal checking: %v/%v\n", i.bucket, i.objectPath())
}
}
healOpts := madmin.HealOpts{Remove: healDeleteDangling}
if meta.bitRotScan {
healOpts.ScanMode = madmin.HealDeepScan
}
res, err := o.HealObject(ctx, i.bucket, i.objectPath(), meta.oi.VersionID, healOpts)
if isErrObjectNotFound(err) || isErrVersionNotFound(err) {
return 0
}
if err != nil && !errors.Is(err, NotImplemented{}) {
logger.LogIf(ctx, err)
return 0
}
return res.ObjectSize
}
func (i *scannerItem) applyLifecycle(ctx context.Context, o ObjectLayer, meta actionMeta) (applied bool, size int64) {
size, err := meta.oi.GetActualSize() size, err := meta.oi.GetActualSize()
if i.debug { if i.debug {
logger.LogIf(ctx, err) logger.LogIf(ctx, err)
} }
if i.heal {
if i.debug {
if meta.oi.VersionID != "" {
console.Debugf(applyActionsLogPrefix+" heal checking: %v/%v v(%s)\n", i.bucket, i.objectPath(), meta.oi.VersionID)
} else {
console.Debugf(applyActionsLogPrefix+" heal checking: %v/%v\n", i.bucket, i.objectPath())
}
}
healOpts := madmin.HealOpts{Remove: healDeleteDangling}
if meta.bitRotScan {
healOpts.ScanMode = madmin.HealDeepScan
}
res, err := o.HealObject(ctx, i.bucket, i.objectPath(), meta.oi.VersionID, healOpts)
if isErrObjectNotFound(err) || isErrVersionNotFound(err) {
return 0
}
if err != nil && !errors.Is(err, NotImplemented{}) {
logger.LogIf(ctx, err)
return 0
}
size = res.ObjectSize
}
if i.lifeCycle == nil { if i.lifeCycle == nil {
if i.debug { if i.debug {
console.Debugf(applyActionsLogPrefix+" no lifecycle rules to apply: %q\n", i.objectPath()) console.Debugf(applyActionsLogPrefix+" no lifecycle rules to apply: %q\n", i.objectPath())
} }
return size return false, size
} }
versionID := meta.oi.VersionID versionID := meta.oi.VersionID
@@ -866,7 +863,7 @@ func (i *scannerItem) applyActions(ctx context.Context, o ObjectLayer, meta acti
if i.debug { if i.debug {
console.Debugf(applyActionsLogPrefix+" object not expirable: %q\n", i.objectPath()) console.Debugf(applyActionsLogPrefix+" object not expirable: %q\n", i.objectPath())
} }
return size return false, size
} }
obj, err := o.GetObjectInfo(ctx, i.bucket, i.objectPath(), ObjectOptions{ obj, err := o.GetObjectInfo(ctx, i.bucket, i.objectPath(), ObjectOptions{
@@ -878,19 +875,18 @@ func (i *scannerItem) applyActions(ctx context.Context, o ObjectLayer, meta acti
if !obj.DeleteMarker { // if this is not a delete marker log and return if !obj.DeleteMarker { // if this is not a delete marker log and return
// Do nothing - heal in the future. // Do nothing - heal in the future.
logger.LogIf(ctx, err) logger.LogIf(ctx, err)
return size return false, size
} }
case ObjectNotFound, VersionNotFound: case ObjectNotFound, VersionNotFound:
// object not found or version not found return 0 // object not found or version not found return 0
return 0 return false, 0
default: default:
// All other errors proceed. // All other errors proceed.
logger.LogIf(ctx, err) logger.LogIf(ctx, err)
return size return false, size
} }
} }
var applied bool
action = evalActionFromLifecycle(ctx, *i.lifeCycle, obj, i.debug) action = evalActionFromLifecycle(ctx, *i.lifeCycle, obj, i.debug)
if action != lifecycle.NoneAction { if action != lifecycle.NoneAction {
applied = applyLifecycleAction(ctx, action, o, obj) applied = applyLifecycleAction(ctx, action, o, obj)
@@ -899,9 +895,26 @@ func (i *scannerItem) applyActions(ctx context.Context, o ObjectLayer, meta acti
if applied { if applied {
switch action { switch action {
case lifecycle.TransitionAction, lifecycle.TransitionVersionAction: case lifecycle.TransitionAction, lifecycle.TransitionVersionAction:
default: // for all lifecycle actions that remove data return true, size
return 0
} }
// For all other lifecycle actions that remove data
return true, 0
}
return false, size
}
// applyActions will apply lifecycle checks on to a scanned item.
// The resulting size on disk will always be returned.
// The metadata will be compared to consensus on the object layer before any changes are applied.
// If no metadata is supplied, -1 is returned if no action is taken.
func (i *scannerItem) applyActions(ctx context.Context, o ObjectLayer, meta actionMeta) int64 {
applied, size := i.applyLifecycle(ctx, o, meta)
// For instance, an applied lifecycle means we remove/transitioned an object
// from the current deployment, which means we don't have to call healing
// routine even if we are asked to do via heal flag.
if !applied && i.heal {
size = i.applyHealing(ctx, o, meta)
} }
return size return size
} }

View File

@@ -522,7 +522,7 @@ func (d *dataUsageCache) save(ctx context.Context, store objectIO, name string)
dataUsageBucket, dataUsageBucket,
name, name,
NewPutObjReader(r), NewPutObjReader(r),
ObjectOptions{NoLock: true}) ObjectOptions{})
if isErrBucketNotFound(err) { if isErrBucketNotFound(err) {
return nil return nil
} }

View File

@@ -250,8 +250,8 @@ func (s *erasureSets) connectDisks() {
} }
disk.SetDiskLoc(s.poolIndex, setIndex, diskIndex) disk.SetDiskLoc(s.poolIndex, setIndex, diskIndex)
s.endpointStrings[setIndex*s.setDriveCount+diskIndex] = disk.String() s.endpointStrings[setIndex*s.setDriveCount+diskIndex] = disk.String()
s.erasureDisksMu.Unlock()
setsJustConnected[setIndex] = true setsJustConnected[setIndex] = true
s.erasureDisksMu.Unlock()
}(endpoint) }(endpoint)
} }

View File

@@ -233,10 +233,15 @@ func extractReqParams(r *http.Request) map[string]string {
region := globalServerRegion region := globalServerRegion
cred := getReqAccessCred(r, region) cred := getReqAccessCred(r, region)
principalID := cred.AccessKey
if cred.ParentUser != "" {
principalID = cred.ParentUser
}
// Success. // Success.
m := map[string]string{ m := map[string]string{
"region": region, "region": region,
"accessKey": cred.AccessKey, "principalId": principalID,
"sourceIPAddress": handlers.GetSourceIP(r), "sourceIPAddress": handlers.GetSourceIP(r),
// Add more fields here. // Add more fields here.
} }

View File

@@ -1704,7 +1704,7 @@ func (sys *IAMSys) PolicyDBGet(name string, isGroup bool, groups ...string) ([]s
// information in IAM (i.e sys.iam*Map) - this info is stored only in the STS // information in IAM (i.e sys.iam*Map) - this info is stored only in the STS
// generated credentials. Thus we skip looking up group memberships, user map, // generated credentials. Thus we skip looking up group memberships, user map,
// and group map and check the appropriate policy maps directly. // and group map and check the appropriate policy maps directly.
func (sys *IAMSys) policyDBGet(name string, isGroup bool) ([]string, error) { func (sys *IAMSys) policyDBGet(name string, isGroup bool) (policies []string, err error) {
if isGroup { if isGroup {
if sys.usersSysType == MinIOUsersSysType { if sys.usersSysType == MinIOUsersSysType {
g, ok := sys.iamGroupsMap[name] g, ok := sys.iamGroupsMap[name]
@@ -1719,8 +1719,7 @@ func (sys *IAMSys) policyDBGet(name string, isGroup bool) ([]string, error) {
} }
} }
mp := sys.iamGroupPolicyMap[name] return sys.iamGroupPolicyMap[name].toSlice(), nil
return mp.toSlice(), nil
} }
var u auth.Credentials var u auth.Credentials
@@ -1738,8 +1737,6 @@ func (sys *IAMSys) policyDBGet(name string, isGroup bool) ([]string, error) {
} }
} }
var policies []string
mp, ok := sys.iamUserPolicyMap[name] mp, ok := sys.iamUserPolicyMap[name]
if !ok { if !ok {
if u.ParentUser != "" { if u.ParentUser != "" {
@@ -1757,8 +1754,7 @@ func (sys *IAMSys) policyDBGet(name string, isGroup bool) ([]string, error) {
continue continue
} }
p := sys.iamGroupPolicyMap[group] policies = append(policies, sys.iamGroupPolicyMap[group].toSlice()...)
policies = append(policies, p.toSlice()...)
} }
return policies, nil return policies, nil
@@ -1788,8 +1784,9 @@ func (sys *IAMSys) IsAllowedServiceAccount(args iampolicy.Args, parent string) b
} }
// Check policy for this service account. // Check policy for this service account.
svcPolicies, err := sys.PolicyDBGet(args.AccountName, false) svcPolicies, err := sys.PolicyDBGet(parent, false, args.Groups...)
if err != nil { if err != nil {
logger.LogIf(GlobalContext, err)
return false return false
} }
@@ -2072,7 +2069,7 @@ func (sys *IAMSys) IsAllowed(args iampolicy.Args) bool {
} }
// Continue with the assumption of a regular user // Continue with the assumption of a regular user
policies, err := sys.PolicyDBGet(args.AccountName, false) policies, err := sys.PolicyDBGet(args.AccountName, false, args.Groups...)
if err != nil { if err != nil {
return false return false
} }

View File

@@ -81,6 +81,15 @@ type MapClaims struct {
jwtgo.MapClaims jwtgo.MapClaims
} }
// GetAccessKey will return the access key.
// If nil an empty string will be returned.
func (c *MapClaims) GetAccessKey() string {
if c == nil {
return ""
}
return c.AccessKey
}
// NewStandardClaims - initializes standard claims // NewStandardClaims - initializes standard claims
func NewStandardClaims() *StandardClaims { func NewStandardClaims() *StandardClaims {
return &StandardClaims{} return &StandardClaims{}

View File

@@ -1368,7 +1368,7 @@ func (args eventArgs) ToEvent(escape bool) event.Event {
AwsRegion: args.ReqParams["region"], AwsRegion: args.ReqParams["region"],
EventTime: eventTime.Format(event.AMZTimeFormat), EventTime: eventTime.Format(event.AMZTimeFormat),
EventName: args.EventName, EventName: args.EventName,
UserIdentity: event.Identity{PrincipalID: args.ReqParams["accessKey"]}, UserIdentity: event.Identity{PrincipalID: args.ReqParams["principalId"]},
RequestParameters: args.ReqParams, RequestParameters: args.ReqParams,
ResponseElements: respElements, ResponseElements: respElements,
S3: event.Metadata{ S3: event.Metadata{
@@ -1376,7 +1376,7 @@ func (args eventArgs) ToEvent(escape bool) event.Event {
ConfigurationID: "Config", ConfigurationID: "Config",
Bucket: event.Bucket{ Bucket: event.Bucket{
Name: args.BucketName, Name: args.BucketName,
OwnerIdentity: event.Identity{PrincipalID: args.ReqParams["accessKey"]}, OwnerIdentity: event.Identity{PrincipalID: args.ReqParams["principalId"]},
ARN: policy.ResourceARNPrefix + args.BucketName, ARN: policy.ResourceARNPrefix + args.BucketName,
}, },
Object: event.Object{ Object: event.Object{

View File

@@ -426,7 +426,7 @@ func (e BucketRemoteTargetNotFound) Error() string {
type BucketRemoteConnectionErr GenericError type BucketRemoteConnectionErr GenericError
func (e BucketRemoteConnectionErr) Error() string { func (e BucketRemoteConnectionErr) Error() string {
return "Remote service endpoint or target bucket not available: " + e.Bucket return fmt.Sprintf("Remote service endpoint or target bucket not available: %s \n\t%s", e.Bucket, e.Err.Error())
} }
// BucketRemoteAlreadyExists remote already exists for this target type. // BucketRemoteAlreadyExists remote already exists for this target type.

View File

@@ -2371,8 +2371,20 @@ func (api objectAPIHandlers) PutObjectPartHandler(w http.ResponseWriter, r *http
} }
etag := partInfo.ETag etag := partInfo.ETag
if isEncrypted { switch kind, encrypted := crypto.IsEncrypted(mi.UserDefined); {
etag = tryDecryptETag(objectEncryptionKey[:], partInfo.ETag, crypto.SSEC.IsRequested(r.Header)) case encrypted:
switch kind {
case crypto.S3:
w.Header().Set(xhttp.AmzServerSideEncryption, xhttp.AmzEncryptionAES)
etag = tryDecryptETag(objectEncryptionKey[:], etag, false)
case crypto.SSEC:
w.Header().Set(xhttp.AmzServerSideEncryptionCustomerAlgorithm, r.Header.Get(xhttp.AmzServerSideEncryptionCustomerAlgorithm))
w.Header().Set(xhttp.AmzServerSideEncryptionCustomerKeyMD5, r.Header.Get(xhttp.AmzServerSideEncryptionCustomerKeyMD5))
if len(etag) >= 32 && strings.Count(etag, "-") != 1 {
etag = etag[len(etag)-32:]
}
}
} }
// We must not use the http.Header().Set method here because some (broken) // We must not use the http.Header().Set method here because some (broken)
@@ -2817,7 +2829,8 @@ func (api objectAPIHandlers) DeleteObjectHandler(w http.ResponseWriter, r *http.
VersionID: opts.VersionID, VersionID: opts.VersionID,
}) })
} }
_, replicateDel, replicateSync := checkReplicateDelete(ctx, bucket, ObjectToDelete{ObjectName: object, VersionID: opts.VersionID}, goi, gerr)
replicateDel, replicateSync := checkReplicateDelete(ctx, bucket, ObjectToDelete{ObjectName: object, VersionID: opts.VersionID}, goi, gerr)
if replicateDel { if replicateDel {
if opts.VersionID != "" { if opts.VersionID != "" {
opts.VersionPurgeStatus = Pending opts.VersionPurgeStatus = Pending
@@ -2825,6 +2838,7 @@ func (api objectAPIHandlers) DeleteObjectHandler(w http.ResponseWriter, r *http.
opts.DeleteMarkerReplicationStatus = string(replication.Pending) opts.DeleteMarkerReplicationStatus = string(replication.Pending)
} }
} }
vID := opts.VersionID vID := opts.VersionID
if r.Header.Get(xhttp.AmzBucketReplicationStatus) == replication.Replica.String() { if r.Header.Get(xhttp.AmzBucketReplicationStatus) == replication.Replica.String() {
// check if replica has permission to be deleted. // check if replica has permission to be deleted.

View File

@@ -340,9 +340,8 @@ func (client *storageRESTClient) CreateFile(ctx context.Context, volume, path st
if err != nil { if err != nil {
return err return err
} }
waitReader, err := waitForHTTPResponse(respBody) _, err = waitForHTTPResponse(respBody)
defer http.DrainBody(ioutil.NopCloser(waitReader)) defer http.DrainBody(respBody)
defer respBody.Close()
return err return err
} }

View File

@@ -226,7 +226,7 @@ func (web *webAPIHandlers) MakeBucket(r *http.Request, args *MakeBucketArgs, rep
reply.UIVersion = Version reply.UIVersion = Version
reqParams := extractReqParams(r) reqParams := extractReqParams(r)
reqParams["accessKey"] = claims.AccessKey reqParams["accessKey"] = claims.GetAccessKey()
sendEvent(eventArgs{ sendEvent(eventArgs{
EventName: event.BucketCreated, EventName: event.BucketCreated,
@@ -723,7 +723,7 @@ func (web *webAPIHandlers) RemoveObject(r *http.Request, args *RemoveObjectArgs,
) )
reqParams := extractReqParams(r) reqParams := extractReqParams(r)
reqParams["accessKey"] = claims.AccessKey reqParams["accessKey"] = claims.GetAccessKey()
sourceIP := handlers.GetSourceIP(r) sourceIP := handlers.GetSourceIP(r)
next: next:
@@ -767,7 +767,7 @@ next:
} }
if hasReplicationRules(ctx, args.BucketName, []ObjectToDelete{{ObjectName: objectName}}) || hasLifecycleConfig { if hasReplicationRules(ctx, args.BucketName, []ObjectToDelete{{ObjectName: objectName}}) || hasLifecycleConfig {
goi, gerr = getObjectInfoFn(ctx, args.BucketName, objectName, opts) goi, gerr = getObjectInfoFn(ctx, args.BucketName, objectName, opts)
if _, replicateDel, replicateSync = checkReplicateDelete(ctx, args.BucketName, ObjectToDelete{ if replicateDel, replicateSync = checkReplicateDelete(ctx, args.BucketName, ObjectToDelete{
ObjectName: objectName, ObjectName: objectName,
VersionID: goi.VersionID, VersionID: goi.VersionID,
}, goi, gerr); replicateDel { }, goi, gerr); replicateDel {
@@ -903,7 +903,7 @@ next:
} }
} }
} }
_, replicateDel, _ := checkReplicateDelete(ctx, args.BucketName, ObjectToDelete{ObjectName: obj.Name, VersionID: obj.VersionID}, obj, nil) replicateDel, _ := checkReplicateDelete(ctx, args.BucketName, ObjectToDelete{ObjectName: obj.Name, VersionID: obj.VersionID}, obj, nil)
// since versioned delete is not available on web browser, yet - this is a simple DeleteMarker replication // since versioned delete is not available on web browser, yet - this is a simple DeleteMarker replication
objToDel := ObjectToDelete{ObjectName: obj.Name} objToDel := ObjectToDelete{ObjectName: obj.Name}
if replicateDel { if replicateDel {
@@ -1340,7 +1340,7 @@ func (web *webAPIHandlers) Upload(w http.ResponseWriter, r *http.Request) {
} }
reqParams := extractReqParams(r) reqParams := extractReqParams(r)
reqParams["accessKey"] = claims.AccessKey reqParams["accessKey"] = claims.GetAccessKey()
// Notify object created event. // Notify object created event.
sendEvent(eventArgs{ sendEvent(eventArgs{
@@ -1529,7 +1529,7 @@ func (web *webAPIHandlers) Download(w http.ResponseWriter, r *http.Request) {
} }
reqParams := extractReqParams(r) reqParams := extractReqParams(r)
reqParams["accessKey"] = claims.AccessKey reqParams["accessKey"] = claims.GetAccessKey()
// Notify object accessed via a GET request. // Notify object accessed via a GET request.
sendEvent(eventArgs{ sendEvent(eventArgs{
@@ -1684,7 +1684,7 @@ func (web *webAPIHandlers) DownloadZip(w http.ResponseWriter, r *http.Request) {
defer archive.Close() defer archive.Close()
reqParams := extractReqParams(r) reqParams := extractReqParams(r)
reqParams["accessKey"] = claims.AccessKey reqParams["accessKey"] = claims.GetAccessKey()
respElements := extractRespElements(w) respElements := extractRespElements(w)
for i, object := range args.Objects { for i, object := range args.Objects {

View File

@@ -347,6 +347,8 @@ func (s *xlStorage) IsLocal() bool {
// Retrieve location indexes. // Retrieve location indexes.
func (s *xlStorage) GetDiskLoc() (poolIdx, setIdx, diskIdx int) { func (s *xlStorage) GetDiskLoc() (poolIdx, setIdx, diskIdx int) {
s.RLock()
defer s.RUnlock()
// If unset, see if we can locate it. // If unset, see if we can locate it.
if s.poolIndex < 0 || s.setIndex < 0 || s.diskIndex < 0 { if s.poolIndex < 0 || s.setIndex < 0 || s.diskIndex < 0 {
return getXLDiskLoc(s.diskID) return getXLDiskLoc(s.diskID)
@@ -1615,6 +1617,9 @@ func (s *xlStorage) CheckFile(ctx context.Context, volume string, path string) e
if err != nil { if err != nil {
return err return err
} }
s.RLock()
formatLegacy := s.formatLegacy
s.RUnlock()
var checkFile func(p string) error var checkFile func(p string) error
checkFile = func(p string) error { checkFile = func(p string) error {
@@ -1626,10 +1631,10 @@ func (s *xlStorage) CheckFile(ctx context.Context, volume string, path string) e
if err := checkPathLength(filePath); err != nil { if err := checkPathLength(filePath); err != nil {
return err return err
} }
st, _ := Lstat(filePath) st, _ := Lstat(filePath)
if st == nil { if st == nil {
if !s.formatLegacy {
if !formatLegacy {
return errPathNotFound return errPathNotFound
} }
@@ -1880,10 +1885,13 @@ func (s *xlStorage) RenameData(ctx context.Context, srcVolume, srcPath, dataDir,
legacyPreserved = true legacyPreserved = true
} }
} else { } else {
s.RLock()
formatLegacy := s.formatLegacy
s.RUnlock()
// It is possible that some drives may not have `xl.meta` file // It is possible that some drives may not have `xl.meta` file
// in such scenarios verify if atleast `part.1` files exist // in such scenarios verify if atleast `part.1` files exist
// to verify for legacy version. // to verify for legacy version.
if s.formatLegacy { if formatLegacy {
// We only need this code if we are moving // We only need this code if we are moving
// from `xl.json` to `xl.meta`, we can avoid // from `xl.json` to `xl.meta`, we can avoid
// one extra readdir operation here for all // one extra readdir operation here for all

View File

@@ -5,7 +5,7 @@ version: '3.7'
# it through port 9000. # it through port 9000.
services: services:
minio1: minio1:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
volumes: volumes:
- data1-1:/data1 - data1-1:/data1
- data1-2:/data2 - data1-2:/data2
@@ -22,7 +22,7 @@ services:
retries: 3 retries: 3
minio2: minio2:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
volumes: volumes:
- data2-1:/data1 - data2-1:/data1
- data2-2:/data2 - data2-2:/data2
@@ -39,7 +39,7 @@ services:
retries: 3 retries: 3
minio3: minio3:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
volumes: volumes:
- data3-1:/data1 - data3-1:/data1
- data3-2:/data2 - data3-2:/data2
@@ -56,7 +56,7 @@ services:
retries: 3 retries: 3
minio4: minio4:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
volumes: volumes:
- data4-1:/data1 - data4-1:/data1
- data4-2:/data2 - data4-2:/data2

View File

@@ -2,7 +2,7 @@ version: '3.7'
services: services:
minio1: minio1:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio1 hostname: minio1
volumes: volumes:
- minio1-data:/export - minio1-data:/export
@@ -29,7 +29,7 @@ services:
retries: 3 retries: 3
minio2: minio2:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio2 hostname: minio2
volumes: volumes:
- minio2-data:/export - minio2-data:/export
@@ -56,7 +56,7 @@ services:
retries: 3 retries: 3
minio3: minio3:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio3 hostname: minio3
volumes: volumes:
- minio3-data:/export - minio3-data:/export
@@ -83,7 +83,7 @@ services:
retries: 3 retries: 3
minio4: minio4:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio4 hostname: minio4
volumes: volumes:
- minio4-data:/export - minio4-data:/export

View File

@@ -2,7 +2,7 @@ version: '3.7'
services: services:
minio1: minio1:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio1 hostname: minio1
volumes: volumes:
- minio1-data:/export - minio1-data:/export
@@ -33,7 +33,7 @@ services:
retries: 3 retries: 3
minio2: minio2:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio2 hostname: minio2
volumes: volumes:
- minio2-data:/export - minio2-data:/export
@@ -64,7 +64,7 @@ services:
retries: 3 retries: 3
minio3: minio3:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio3 hostname: minio3
volumes: volumes:
- minio3-data:/export - minio3-data:/export
@@ -95,7 +95,7 @@ services:
retries: 3 retries: 3
minio4: minio4:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio4 hostname: minio4
volumes: volumes:
- minio4-data:/export - minio4-data:/export

1
go.mod
View File

@@ -77,6 +77,7 @@ require (
github.com/tidwall/gjson v1.6.8 github.com/tidwall/gjson v1.6.8
github.com/tidwall/sjson v1.0.4 github.com/tidwall/sjson v1.0.4
github.com/tinylib/msgp v1.1.3 github.com/tinylib/msgp v1.1.3
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31 // indirect
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a
github.com/willf/bitset v1.1.11 // indirect github.com/willf/bitset v1.1.11 // indirect
github.com/willf/bloom v2.0.3+incompatible github.com/willf/bloom v2.0.3+incompatible

2
go.sum
View File

@@ -596,6 +596,8 @@ github.com/tinylib/msgp v1.1.3 h1:3giwAkmtaEDLSV0MdO1lDLuPgklgPzmk8H9+So2BVfA=
github.com/tinylib/msgp v1.1.3/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= github.com/tinylib/msgp v1.1.3/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8 h1:ndzgwNDnKIqyCvHTXaCqh9KlOWKvBry6nuXMJmonVsE= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8 h1:ndzgwNDnKIqyCvHTXaCqh9KlOWKvBry6nuXMJmonVsE=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31 h1:OXcKh35JaYsGMRzpvFkLv/MEyPuL49CThT1pZ8aSml4=
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31/go.mod h1:onvgF043R+lC5RZ8IT9rBXDaEDnpnw/Cl+HFiw+v/7Q=
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=

View File

@@ -123,8 +123,12 @@ func (m *Monitor) getReport(selectBucket SelectionFunction) *bandwidth.Report {
if !selectBucket(bucket) { if !selectBucket(bucket) {
continue continue
} }
bucketThrottle, ok := m.bucketThrottle[bucket]
if !ok {
continue
}
report.BucketStats[bucket] = bandwidth.Details{ report.BucketStats[bucket] = bandwidth.Details{
LimitInBytesPerSecond: m.bucketThrottle[bucket].clusterBandwidth, LimitInBytesPerSecond: bucketThrottle.clusterBandwidth,
CurrentBandwidthInBytesPerSecond: bucketMeasurement.getExpMovingAvgBytesPerSecond(), CurrentBandwidthInBytesPerSecond: bucketMeasurement.getExpMovingAvgBytesPerSecond(),
} }
} }

View File

@@ -25,62 +25,61 @@ import (
// MonitoredReader monitors the bandwidth // MonitoredReader monitors the bandwidth
type MonitoredReader struct { type MonitoredReader struct {
bucket string // Token to track bucket opts *MonitorReaderOptions
bucketMeasurement *bucketMeasurement // bucket measurement object bucketMeasurement *bucketMeasurement // bucket measurement object
object string // Token to track object reader io.Reader // Reader to wrap
reader io.ReadCloser // Reader to wrap
lastStop time.Time // Last timestamp for a measurement lastStop time.Time // Last timestamp for a measurement
headerSize int // Size of the header not captured by reader
throttle *throttle // throttle the rate at which replication occur throttle *throttle // throttle the rate at which replication occur
monitor *Monitor // Monitor reference monitor *Monitor // Monitor reference
closed bool // Reader is closed lastErr error // last error reported, if this non-nil all reads will fail.
} }
// NewMonitoredReader returns a io.ReadCloser that reports bandwidth details. // MonitorReaderOptions provides configurable options for monitor reader implementation.
// The supplied reader will be closed. type MonitorReaderOptions struct {
func NewMonitoredReader(ctx context.Context, monitor *Monitor, bucket string, object string, reader io.ReadCloser, headerSize int, bandwidthBytesPerSecond int64, clusterBandwidth int64) *MonitoredReader { Bucket string
Object string
HeaderSize int
BandwidthBytesPerSec int64
ClusterBandwidth int64
}
// NewMonitoredReader returns a io.Reader that reports bandwidth details.
func NewMonitoredReader(ctx context.Context, monitor *Monitor, reader io.Reader, opts *MonitorReaderOptions) *MonitoredReader {
timeNow := time.Now() timeNow := time.Now()
b := monitor.track(bucket, object, timeNow) b := monitor.track(opts.Bucket, opts.Object, timeNow)
return &MonitoredReader{ return &MonitoredReader{
bucket: bucket, opts: opts,
object: object,
bucketMeasurement: b, bucketMeasurement: b,
reader: reader, reader: reader,
lastStop: timeNow, lastStop: timeNow,
headerSize: headerSize, throttle: monitor.throttleBandwidth(ctx, opts.Bucket, opts.BandwidthBytesPerSec, opts.ClusterBandwidth),
throttle: monitor.throttleBandwidth(ctx, bucket, bandwidthBytesPerSecond, clusterBandwidth),
monitor: monitor, monitor: monitor,
} }
} }
// Read wraps the read reader // Read wraps the read reader
func (m *MonitoredReader) Read(p []byte) (n int, err error) { func (m *MonitoredReader) Read(p []byte) (n int, err error) {
if m.closed { if m.lastErr != nil {
err = io.ErrClosedPipe err = m.lastErr
return return
} }
p = p[:m.throttle.GetLimitForBytes(int64(len(p)))] p = p[:m.throttle.GetLimitForBytes(int64(len(p)))]
n, err = m.reader.Read(p) n, err = m.reader.Read(p)
stop := time.Now() stop := time.Now()
update := uint64(n + m.headerSize) update := uint64(n + m.opts.HeaderSize)
m.bucketMeasurement.incrementBytes(update) m.bucketMeasurement.incrementBytes(update)
m.lastStop = stop m.lastStop = stop
unused := len(p) - (n + m.headerSize) unused := len(p) - (n + m.opts.HeaderSize)
m.headerSize = 0 // Set to 0 post first read m.opts.HeaderSize = 0 // Set to 0 post first read
if unused > 0 { if unused > 0 {
m.throttle.ReleaseUnusedBandwidth(int64(unused)) m.throttle.ReleaseUnusedBandwidth(int64(unused))
} }
if err != nil {
m.lastErr = err
}
return return
} }
// Close stops tracking the io
func (m *MonitoredReader) Close() error {
if m.closed {
return nil
}
m.closed = true
return m.reader.Close()
}

View File

@@ -18,6 +18,7 @@ package lifecycle
import ( import (
"encoding/xml" "encoding/xml"
"fmt"
"io" "io"
"strings" "strings"
"time" "time"
@@ -71,7 +72,8 @@ func (lc *Lifecycle) UnmarshalXML(d *xml.Decoder, start xml.StartElement) (err e
switch start.Name.Local { switch start.Name.Local {
case "LifecycleConfiguration", "BucketLifecycleConfiguration": case "LifecycleConfiguration", "BucketLifecycleConfiguration":
default: default:
return errUnknownXMLTag return xml.UnmarshalError(fmt.Sprintf("expected element type <LifecycleConfiguration>/<BucketLifecycleConfiguration> but have <%s>",
start.Name.Local))
} }
for { for {
// Read tokens from the XML document in a stream. // Read tokens from the XML document in a stream.
@@ -93,7 +95,7 @@ func (lc *Lifecycle) UnmarshalXML(d *xml.Decoder, start xml.StartElement) (err e
} }
lc.Rules = append(lc.Rules, r) lc.Rules = append(lc.Rules, r)
default: default:
return errUnknownXMLTag return xml.UnmarshalError(fmt.Sprintf("expected element type <Rule> but have <%s>", se.Name.Local))
} }
} }
} }

View File

@@ -489,6 +489,41 @@ type ObjectLegalHold struct {
Status LegalHoldStatus `xml:"Status,omitempty"` Status LegalHoldStatus `xml:"Status,omitempty"`
} }
// UnmarshalXML - decodes XML data.
func (l *ObjectLegalHold) UnmarshalXML(d *xml.Decoder, start xml.StartElement) (err error) {
switch start.Name.Local {
case "LegalHold", "ObjectLockLegalHold":
default:
return xml.UnmarshalError(fmt.Sprintf("expected element type <LegalHold>/<ObjectLockLegalHold> but have <%s>",
start.Name.Local))
}
for {
// Read tokens from the XML document in a stream.
t, err := d.Token()
if err != nil {
if err == io.EOF {
break
}
return err
}
switch se := t.(type) {
case xml.StartElement:
switch se.Name.Local {
case "Status":
var st LegalHoldStatus
if err = d.DecodeElement(&st, &se); err != nil {
return err
}
l.Status = st
default:
return xml.UnmarshalError(fmt.Sprintf("expected element type <Status> but have <%s>", se.Name.Local))
}
}
}
return nil
}
// IsEmpty returns true if struct is empty // IsEmpty returns true if struct is empty
func (l *ObjectLegalHold) IsEmpty() bool { func (l *ObjectLegalHold) IsEmpty() bool {
return !l.Status.Valid() return !l.Status.Valid()

View File

@@ -18,6 +18,7 @@ package lock
import ( import (
"encoding/xml" "encoding/xml"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
@@ -467,6 +468,23 @@ func TestParseObjectLegalHold(t *testing.T) {
expectedErr: nil, expectedErr: nil,
expectErr: false, expectErr: false,
}, },
{
value: `<?xml version="1.0" encoding="UTF-8"?><ObjectLockLegalHold xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Status>ON</Status></ObjectLockLegalHold>`,
expectedErr: nil,
expectErr: false,
},
// invalid Status key
{
value: `<?xml version="1.0" encoding="UTF-8"?><ObjectLockLegalHold xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><MyStatus>ON</MyStatus></ObjectLockLegalHold>`,
expectedErr: errors.New("expected element type <Status> but have <MyStatus>"),
expectErr: true,
},
// invalid XML attr
{
value: `<?xml version="1.0" encoding="UTF-8"?><UnknownLegalHold xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Status>ON</Status></UnknownLegalHold>`,
expectedErr: errors.New("expected element type <LegalHold>/<ObjectLockLegalHold> but have <UnknownLegalHold>"),
expectErr: true,
},
{ {
value: `<?xml version="1.0" encoding="UTF-8"?><LegalHold xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Status>On</Status></LegalHold>`, value: `<?xml version="1.0" encoding="UTF-8"?><LegalHold xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Status>On</Status></LegalHold>`,
expectedErr: ErrMalformedXML, expectedErr: ErrMalformedXML,

View File

@@ -110,10 +110,18 @@ const (
// AWSUsername - user friendly name, in MinIO this value is same as your user Access Key. // AWSUsername - user friendly name, in MinIO this value is same as your user Access Key.
AWSUsername Key = "aws:username" AWSUsername Key = "aws:username"
// S3SignatureVersion - identifies the version of AWS Signature that you want to support for authenticated requests.
S3SignatureVersion = "s3:signatureversion"
// S3AuthType - optionally use this condition key to restrict incoming requests to use a specific authentication method.
S3AuthType = "s3:authType"
) )
// AllSupportedKeys - is list of all all supported keys. // AllSupportedKeys - is list of all all supported keys.
var AllSupportedKeys = append([]Key{ var AllSupportedKeys = append([]Key{
S3SignatureVersion,
S3AuthType,
S3XAmzCopySource, S3XAmzCopySource,
S3XAmzServerSideEncryption, S3XAmzServerSideEncryption,
S3XAmzServerSideEncryptionCustomerAlgorithm, S3XAmzServerSideEncryptionCustomerAlgorithm,
@@ -144,6 +152,8 @@ var AllSupportedKeys = append([]Key{
// CommonKeys - is list of all common condition keys. // CommonKeys - is list of all common condition keys.
var CommonKeys = append([]Key{ var CommonKeys = append([]Key{
S3SignatureVersion,
S3AuthType,
S3XAmzContentSha256, S3XAmzContentSha256,
S3LocationConstraint, S3LocationConstraint,
AWSReferer, AWSReferer,

View File

@@ -739,6 +739,152 @@ func TestCSVQueries2(t *testing.T) {
} }
} }
func TestCSVQueries3(t *testing.T) {
input := `na.me,qty,CAST
apple,1,true
mango,3,false
`
var testTable = []struct {
name string
query string
requestXML []byte // override request XML
wantResult string
}{
{
name: "Select a column containing dot",
query: `select "na.me" from S3Object s`,
wantResult: `apple
mango`,
},
{
name: "Select column containing dot with table name prefix",
query: `select count(S3Object."na.me") from S3Object`,
wantResult: `2`,
},
{
name: "Select column containing dot with table alias prefix",
query: `select s."na.me" from S3Object as s`,
wantResult: `apple
mango`,
},
{
name: "Select column simplest",
query: `select qty from S3Object`,
wantResult: `1
3`,
},
{
name: "Select column with table name prefix",
query: `select S3Object.qty from S3Object`,
wantResult: `1
3`,
},
{
name: "Select column without table alias",
query: `select qty from S3Object s`,
wantResult: `1
3`,
},
{
name: "Select column with table alias",
query: `select s.qty from S3Object s`,
wantResult: `1
3`,
},
{
name: "Select reserved word column",
query: `select "CAST" from s3object`,
wantResult: `true
false`,
},
{
name: "Select reserved word column with table alias",
query: `select S3Object."CAST" from s3object`,
wantResult: `true
false`,
},
{
name: "Select reserved word column with unused table alias",
query: `select "CAST" from s3object s`,
wantResult: `true
false`,
},
{
name: "Select reserved word column with table alias",
query: `select s."CAST" from s3object s`,
wantResult: `true
false`,
},
{
name: "Select reserved word column with table alias",
query: `select NOT CAST(s."CAST" AS Bool) from s3object s`,
wantResult: `false
true`,
},
}
defRequest := `<?xml version="1.0" encoding="UTF-8"?>
<SelectObjectContentRequest>
<Expression>%s</Expression>
<ExpressionType>SQL</ExpressionType>
<InputSerialization>
<CompressionType>NONE</CompressionType>
<CSV>
<FileHeaderInfo>USE</FileHeaderInfo>
<QuoteCharacter>"</QuoteCharacter>
</CSV>
</InputSerialization>
<OutputSerialization>
<CSV/>
</OutputSerialization>
<RequestProgress>
<Enabled>FALSE</Enabled>
</RequestProgress>
</SelectObjectContentRequest>`
for _, testCase := range testTable {
t.Run(testCase.name, func(t *testing.T) {
testReq := testCase.requestXML
if len(testReq) == 0 {
testReq = []byte(fmt.Sprintf(defRequest, testCase.query))
}
s3Select, err := NewS3Select(bytes.NewReader(testReq))
if err != nil {
t.Fatal(err)
}
if err = s3Select.Open(func(offset, length int64) (io.ReadCloser, error) {
return ioutil.NopCloser(bytes.NewBufferString(input)), nil
}); err != nil {
t.Fatal(err)
}
w := &testResponseWriter{}
s3Select.Evaluate(w)
s3Select.Close()
resp := http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewReader(w.response)),
ContentLength: int64(len(w.response)),
}
res, err := minio.NewSelectResults(&resp, "testbucket")
if err != nil {
t.Error(err)
return
}
got, err := ioutil.ReadAll(res)
if err != nil {
t.Error(err)
return
}
gotS := strings.TrimSpace(string(got))
if gotS != testCase.wantResult {
t.Errorf("received response does not match with expected reply.\nQuery: %s\n=====\ngot: %s\n=====\nwant: %s\n=====\n", testCase.query, gotS, testCase.wantResult)
}
})
}
}
func TestCSVInput(t *testing.T) { func TestCSVInput(t *testing.T) {
var testTable = []struct { var testTable = []struct {
requestXML []byte requestXML []byte

View File

@@ -63,7 +63,7 @@ func newAggVal(fn FuncName) *aggVal {
// current row and stores the result. // current row and stores the result.
// //
// On success, it returns (nil, nil). // On success, it returns (nil, nil).
func (e *FuncExpr) evalAggregationNode(r Record) error { func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error {
// It is assumed that this function is called only when // It is assumed that this function is called only when
// `e` is an aggregation function. // `e` is an aggregation function.
@@ -77,13 +77,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error {
return nil return nil
} }
val, err = e.Count.ExprArg.evalNode(r) val, err = e.Count.ExprArg.evalNode(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
} else { } else {
// Evaluate the (only) argument // Evaluate the (only) argument
val, err = e.SFunc.ArgsList[0].evalNode(r) val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
@@ -149,13 +149,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error {
return err return err
} }
func (e *AliasedExpression) aggregateRow(r Record) error { func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error {
return e.Expression.aggregateRow(r) return e.Expression.aggregateRow(r, tableAlias)
} }
func (e *Expression) aggregateRow(r Record) error { func (e *Expression) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.And { for _, ex := range e.And {
err := ex.aggregateRow(r) err := ex.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
@@ -163,9 +163,9 @@ func (e *Expression) aggregateRow(r Record) error {
return nil return nil
} }
func (e *ListExpr) aggregateRow(r Record) error { func (e *ListExpr) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.Elements { for _, ex := range e.Elements {
err := ex.aggregateRow(r) err := ex.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
@@ -173,9 +173,9 @@ func (e *ListExpr) aggregateRow(r Record) error {
return nil return nil
} }
func (e *AndCondition) aggregateRow(r Record) error { func (e *AndCondition) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.Condition { for _, ex := range e.Condition {
err := ex.aggregateRow(r) err := ex.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
@@ -183,15 +183,15 @@ func (e *AndCondition) aggregateRow(r Record) error {
return nil return nil
} }
func (e *Condition) aggregateRow(r Record) error { func (e *Condition) aggregateRow(r Record, tableAlias string) error {
if e.Operand != nil { if e.Operand != nil {
return e.Operand.aggregateRow(r) return e.Operand.aggregateRow(r, tableAlias)
} }
return e.Not.aggregateRow(r) return e.Not.aggregateRow(r, tableAlias)
} }
func (e *ConditionOperand) aggregateRow(r Record) error { func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error {
err := e.Operand.aggregateRow(r) err := e.Operand.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
@@ -202,38 +202,38 @@ func (e *ConditionOperand) aggregateRow(r Record) error {
switch { switch {
case e.ConditionRHS.Compare != nil: case e.ConditionRHS.Compare != nil:
return e.ConditionRHS.Compare.Operand.aggregateRow(r) return e.ConditionRHS.Compare.Operand.aggregateRow(r, tableAlias)
case e.ConditionRHS.Between != nil: case e.ConditionRHS.Between != nil:
err = e.ConditionRHS.Between.Start.aggregateRow(r) err = e.ConditionRHS.Between.Start.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
return e.ConditionRHS.Between.End.aggregateRow(r) return e.ConditionRHS.Between.End.aggregateRow(r, tableAlias)
case e.ConditionRHS.In != nil: case e.ConditionRHS.In != nil:
elt := e.ConditionRHS.In.ListExpression elt := e.ConditionRHS.In.ListExpression
err = elt.aggregateRow(r) err = elt.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
case e.ConditionRHS.Like != nil: case e.ConditionRHS.Like != nil:
err = e.ConditionRHS.Like.Pattern.aggregateRow(r) err = e.ConditionRHS.Like.Pattern.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r) return e.ConditionRHS.Like.EscapeChar.aggregateRow(r, tableAlias)
default: default:
return errInvalidASTNode return errInvalidASTNode
} }
} }
func (e *Operand) aggregateRow(r Record) error { func (e *Operand) aggregateRow(r Record, tableAlias string) error {
err := e.Left.aggregateRow(r) err := e.Left.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
for _, rt := range e.Right { for _, rt := range e.Right {
err = rt.Right.aggregateRow(r) err = rt.Right.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
@@ -241,13 +241,13 @@ func (e *Operand) aggregateRow(r Record) error {
return nil return nil
} }
func (e *MultOp) aggregateRow(r Record) error { func (e *MultOp) aggregateRow(r Record, tableAlias string) error {
err := e.Left.aggregateRow(r) err := e.Left.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
for _, rt := range e.Right { for _, rt := range e.Right {
err = rt.Right.aggregateRow(r) err = rt.Right.aggregateRow(r, tableAlias)
if err != nil { if err != nil {
return err return err
} }
@@ -255,29 +255,29 @@ func (e *MultOp) aggregateRow(r Record) error {
return nil return nil
} }
func (e *UnaryTerm) aggregateRow(r Record) error { func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error {
if e.Negated != nil { if e.Negated != nil {
return e.Negated.Term.aggregateRow(r) return e.Negated.Term.aggregateRow(r, tableAlias)
} }
return e.Primary.aggregateRow(r) return e.Primary.aggregateRow(r, tableAlias)
} }
func (e *PrimaryTerm) aggregateRow(r Record) error { func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error {
switch { switch {
case e.ListExpr != nil: case e.ListExpr != nil:
return e.ListExpr.aggregateRow(r) return e.ListExpr.aggregateRow(r, tableAlias)
case e.SubExpression != nil: case e.SubExpression != nil:
return e.SubExpression.aggregateRow(r) return e.SubExpression.aggregateRow(r, tableAlias)
case e.FuncCall != nil: case e.FuncCall != nil:
return e.FuncCall.aggregateRow(r) return e.FuncCall.aggregateRow(r, tableAlias)
} }
return nil return nil
} }
func (e *FuncExpr) aggregateRow(r Record) error { func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error {
switch e.getFunctionName() { switch e.getFunctionName() {
case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount: case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
return e.evalAggregationNode(r) return e.evalAggregationNode(r, tableAlias)
default: default:
// TODO: traverse arguments and call aggregateRow on // TODO: traverse arguments and call aggregateRow on
// them if they could be an ancestor of an // them if they could be an ancestor of an

View File

@@ -19,6 +19,7 @@ package sql
import ( import (
"errors" "errors"
"fmt" "fmt"
"strings"
) )
// Query analysis - The query is analyzed to determine if it involves // Query analysis - The query is analyzed to determine if it involves
@@ -177,7 +178,7 @@ func (e *PrimaryTerm) analyze(s *Select) (result qProp) {
case e.JPathExpr != nil: case e.JPathExpr != nil:
// Check if the path expression is valid // Check if the path expression is valid
if len(e.JPathExpr.PathExpr) > 0 { if len(e.JPathExpr.PathExpr) > 0 {
if e.JPathExpr.BaseKey.String() != s.From.As { if e.JPathExpr.BaseKey.String() != s.From.As && strings.ToLower(e.JPathExpr.BaseKey.String()) != baseTableName {
result = qProp{err: errInvalidKeypath} result = qProp{err: errInvalidKeypath}
return return
} }

View File

@@ -21,7 +21,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"math" "math"
"strings"
"github.com/bcicen/jstream" "github.com/bcicen/jstream"
"github.com/minio/simdjson-go" "github.com/minio/simdjson-go"
@@ -47,21 +46,21 @@ var (
// of child nodes. The final result row is returned after all rows are // of child nodes. The final result row is returned after all rows are
// processed, and the `getAggregate` function is called. // processed, and the `getAggregate` function is called.
func (e *AliasedExpression) evalNode(r Record) (*Value, error) { func (e *AliasedExpression) evalNode(r Record, tableAlias string) (*Value, error) {
return e.Expression.evalNode(r) return e.Expression.evalNode(r, tableAlias)
} }
func (e *Expression) evalNode(r Record) (*Value, error) { func (e *Expression) evalNode(r Record, tableAlias string) (*Value, error) {
if len(e.And) == 1 { if len(e.And) == 1 {
// In this case, result is not required to be boolean // In this case, result is not required to be boolean
// type. // type.
return e.And[0].evalNode(r) return e.And[0].evalNode(r, tableAlias)
} }
// Compute OR of conditions // Compute OR of conditions
result := false result := false
for _, ex := range e.And { for _, ex := range e.And {
res, err := ex.evalNode(r) res, err := ex.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -74,16 +73,16 @@ func (e *Expression) evalNode(r Record) (*Value, error) {
return FromBool(result), nil return FromBool(result), nil
} }
func (e *AndCondition) evalNode(r Record) (*Value, error) { func (e *AndCondition) evalNode(r Record, tableAlias string) (*Value, error) {
if len(e.Condition) == 1 { if len(e.Condition) == 1 {
// In this case, result does not have to be boolean // In this case, result does not have to be boolean
return e.Condition[0].evalNode(r) return e.Condition[0].evalNode(r, tableAlias)
} }
// Compute AND of conditions // Compute AND of conditions
result := true result := true
for _, ex := range e.Condition { for _, ex := range e.Condition {
res, err := ex.evalNode(r) res, err := ex.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -96,14 +95,14 @@ func (e *AndCondition) evalNode(r Record) (*Value, error) {
return FromBool(result), nil return FromBool(result), nil
} }
func (e *Condition) evalNode(r Record) (*Value, error) { func (e *Condition) evalNode(r Record, tableAlias string) (*Value, error) {
if e.Operand != nil { if e.Operand != nil {
// In this case, result does not have to be boolean // In this case, result does not have to be boolean
return e.Operand.evalNode(r) return e.Operand.evalNode(r, tableAlias)
} }
// Compute NOT of condition // Compute NOT of condition
res, err := e.Not.evalNode(r) res, err := e.Not.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -114,8 +113,8 @@ func (e *Condition) evalNode(r Record) (*Value, error) {
return FromBool(!b), nil return FromBool(!b), nil
} }
func (e *ConditionOperand) evalNode(r Record) (*Value, error) { func (e *ConditionOperand) evalNode(r Record, tableAlias string) (*Value, error) {
opVal, opErr := e.Operand.evalNode(r) opVal, opErr := e.Operand.evalNode(r, tableAlias)
if opErr != nil || e.ConditionRHS == nil { if opErr != nil || e.ConditionRHS == nil {
return opVal, opErr return opVal, opErr
} }
@@ -123,7 +122,7 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
// Need to evaluate the ConditionRHS // Need to evaluate the ConditionRHS
switch { switch {
case e.ConditionRHS.Compare != nil: case e.ConditionRHS.Compare != nil:
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r) cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r, tableAlias)
if cmpRErr != nil { if cmpRErr != nil {
return nil, cmpRErr return nil, cmpRErr
} }
@@ -132,26 +131,26 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
return FromBool(b), err return FromBool(b), err
case e.ConditionRHS.Between != nil: case e.ConditionRHS.Between != nil:
return e.ConditionRHS.Between.evalBetweenNode(r, opVal) return e.ConditionRHS.Between.evalBetweenNode(r, opVal, tableAlias)
case e.ConditionRHS.Like != nil: case e.ConditionRHS.Like != nil:
return e.ConditionRHS.Like.evalLikeNode(r, opVal) return e.ConditionRHS.Like.evalLikeNode(r, opVal, tableAlias)
case e.ConditionRHS.In != nil: case e.ConditionRHS.In != nil:
return e.ConditionRHS.In.evalInNode(r, opVal) return e.ConditionRHS.In.evalInNode(r, opVal, tableAlias)
default: default:
return nil, errInvalidASTNode return nil, errInvalidASTNode
} }
} }
func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) { func (e *Between) evalBetweenNode(r Record, arg *Value, tableAlias string) (*Value, error) {
stVal, stErr := e.Start.evalNode(r) stVal, stErr := e.Start.evalNode(r, tableAlias)
if stErr != nil { if stErr != nil {
return nil, stErr return nil, stErr
} }
endVal, endErr := e.End.evalNode(r) endVal, endErr := e.End.evalNode(r, tableAlias)
if endErr != nil { if endErr != nil {
return nil, endErr return nil, endErr
} }
@@ -174,7 +173,7 @@ func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
return FromBool(result), nil return FromBool(result), nil
} }
func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) { func (e *Like) evalLikeNode(r Record, arg *Value, tableAlias string) (*Value, error) {
inferTypeAsString(arg) inferTypeAsString(arg)
s, ok := arg.ToString() s, ok := arg.ToString()
@@ -183,7 +182,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
return nil, errLikeInvalidInputs(err) return nil, errLikeInvalidInputs(err)
} }
pattern, err1 := e.Pattern.evalNode(r) pattern, err1 := e.Pattern.evalNode(r, tableAlias)
if err1 != nil { if err1 != nil {
return nil, err1 return nil, err1
} }
@@ -199,7 +198,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
escape := runeZero escape := runeZero
if e.EscapeChar != nil { if e.EscapeChar != nil {
escapeVal, err2 := e.EscapeChar.evalNode(r) escapeVal, err2 := e.EscapeChar.evalNode(r, tableAlias)
if err2 != nil { if err2 != nil {
return nil, err2 return nil, err2
} }
@@ -230,14 +229,14 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
return FromBool(matchResult), nil return FromBool(matchResult), nil
} }
func (e *ListExpr) evalNode(r Record) (*Value, error) { func (e *ListExpr) evalNode(r Record, tableAlias string) (*Value, error) {
res := make([]Value, len(e.Elements)) res := make([]Value, len(e.Elements))
if len(e.Elements) == 1 { if len(e.Elements) == 1 {
// If length 1, treat as single value. // If length 1, treat as single value.
return e.Elements[0].evalNode(r) return e.Elements[0].evalNode(r, tableAlias)
} }
for i, elt := range e.Elements { for i, elt := range e.Elements {
v, err := elt.evalNode(r) v, err := elt.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -248,7 +247,7 @@ func (e *ListExpr) evalNode(r Record) (*Value, error) {
const floatCmpTolerance = 0.000001 const floatCmpTolerance = 0.000001
func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) { func (e *In) evalInNode(r Record, lhs *Value, tableAlias string) (*Value, error) {
// Compare two values in terms of in-ness. // Compare two values in terms of in-ness.
var cmp func(a, b Value) bool var cmp func(a, b Value) bool
cmp = func(a, b Value) bool { cmp = func(a, b Value) bool {
@@ -283,7 +282,7 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
var rhs Value var rhs Value
if elt := e.ListExpression; elt != nil { if elt := e.ListExpression; elt != nil {
eltVal, err := elt.evalNode(r) eltVal, err := elt.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -304,8 +303,8 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
return FromBool(cmp(rhs, *lhs)), nil return FromBool(cmp(rhs, *lhs)), nil
} }
func (e *Operand) evalNode(r Record) (*Value, error) { func (e *Operand) evalNode(r Record, tableAlias string) (*Value, error) {
lval, lerr := e.Left.evalNode(r) lval, lerr := e.Left.evalNode(r, tableAlias)
if lerr != nil || len(e.Right) == 0 { if lerr != nil || len(e.Right) == 0 {
return lval, lerr return lval, lerr
} }
@@ -315,7 +314,7 @@ func (e *Operand) evalNode(r Record) (*Value, error) {
// symbols. // symbols.
for _, rightTerm := range e.Right { for _, rightTerm := range e.Right {
op := rightTerm.Op op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r) rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
if rerr != nil { if rerr != nil {
return nil, rerr return nil, rerr
} }
@@ -327,8 +326,8 @@ func (e *Operand) evalNode(r Record) (*Value, error) {
return lval, nil return lval, nil
} }
func (e *MultOp) evalNode(r Record) (*Value, error) { func (e *MultOp) evalNode(r Record, tableAlias string) (*Value, error) {
lval, lerr := e.Left.evalNode(r) lval, lerr := e.Left.evalNode(r, tableAlias)
if lerr != nil || len(e.Right) == 0 { if lerr != nil || len(e.Right) == 0 {
return lval, lerr return lval, lerr
} }
@@ -337,7 +336,7 @@ func (e *MultOp) evalNode(r Record) (*Value, error) {
// AST node is for terms separated by *, / or % symbols. // AST node is for terms separated by *, / or % symbols.
for _, rightTerm := range e.Right { for _, rightTerm := range e.Right {
op := rightTerm.Op op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r) rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
if rerr != nil { if rerr != nil {
return nil, rerr return nil, rerr
} }
@@ -350,12 +349,12 @@ func (e *MultOp) evalNode(r Record) (*Value, error) {
return lval, nil return lval, nil
} }
func (e *UnaryTerm) evalNode(r Record) (*Value, error) { func (e *UnaryTerm) evalNode(r Record, tableAlias string) (*Value, error) {
if e.Negated == nil { if e.Negated == nil {
return e.Primary.evalNode(r) return e.Primary.evalNode(r, tableAlias)
} }
v, err := e.Negated.Term.evalNode(r) v, err := e.Negated.Term.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -368,19 +367,15 @@ func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
return nil, errArithMismatchedTypes return nil, errArithMismatchedTypes
} }
func (e *JSONPath) evalNode(r Record) (*Value, error) { func (e *JSONPath) evalNode(r Record, tableAlias string) (*Value, error) {
// Strip the table name from the keypath. alias := tableAlias
keypath := e.String() if tableAlias == "" {
if strings.Contains(keypath, ".") { alias = baseTableName
ps := strings.SplitN(keypath, ".", 2)
if len(ps) == 2 {
keypath = ps[1]
}
} }
pathExpr := e.StripTableAlias(alias)
_, rawVal := r.Raw() _, rawVal := r.Raw()
switch rowVal := rawVal.(type) { switch rowVal := rawVal.(type) {
case jstream.KVS, simdjson.Object: case jstream.KVS, simdjson.Object:
pathExpr := e.PathExpr
if len(pathExpr) == 0 { if len(pathExpr) == 0 {
pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}} pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}}
} }
@@ -392,7 +387,10 @@ func (e *JSONPath) evalNode(r Record) (*Value, error) {
return jsonToValue(result) return jsonToValue(result)
default: default:
return r.Get(keypath) if pathExpr[len(pathExpr)-1].Key == nil {
return nil, errInvalidKeypath
}
return r.Get(pathExpr[len(pathExpr)-1].Key.keyString())
} }
} }
@@ -447,28 +445,28 @@ func jsonToValue(result interface{}) (*Value, error) {
return nil, fmt.Errorf("Unhandled value type: %T", result) return nil, fmt.Errorf("Unhandled value type: %T", result)
} }
func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) { func (e *PrimaryTerm) evalNode(r Record, tableAlias string) (res *Value, err error) {
switch { switch {
case e.Value != nil: case e.Value != nil:
return e.Value.evalNode(r) return e.Value.evalNode(r)
case e.JPathExpr != nil: case e.JPathExpr != nil:
return e.JPathExpr.evalNode(r) return e.JPathExpr.evalNode(r, tableAlias)
case e.ListExpr != nil: case e.ListExpr != nil:
return e.ListExpr.evalNode(r) return e.ListExpr.evalNode(r, tableAlias)
case e.SubExpression != nil: case e.SubExpression != nil:
return e.SubExpression.evalNode(r) return e.SubExpression.evalNode(r, tableAlias)
case e.FuncCall != nil: case e.FuncCall != nil:
return e.FuncCall.evalNode(r) return e.FuncCall.evalNode(r, tableAlias)
} }
return nil, errInvalidASTNode return nil, errInvalidASTNode
} }
func (e *FuncExpr) evalNode(r Record) (res *Value, err error) { func (e *FuncExpr) evalNode(r Record, tableAlias string) (res *Value, err error) {
switch e.getFunctionName() { switch e.getFunctionName() {
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum: case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
return e.getAggregate() return e.getAggregate()
default: default:
return e.evalSQLFnNode(r) return e.evalSQLFnNode(r, tableAlias)
} }
} }

View File

@@ -84,35 +84,35 @@ func (e *FuncExpr) getFunctionName() FuncName {
// evalSQLFnNode assumes that the FuncExpr is not an aggregation // evalSQLFnNode assumes that the FuncExpr is not an aggregation
// function. // function.
func (e *FuncExpr) evalSQLFnNode(r Record) (res *Value, err error) { func (e *FuncExpr) evalSQLFnNode(r Record, tableAlias string) (res *Value, err error) {
// Handle functions that have phrase arguments // Handle functions that have phrase arguments
switch e.getFunctionName() { switch e.getFunctionName() {
case sqlFnCast: case sqlFnCast:
expr := e.Cast.Expr expr := e.Cast.Expr
res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType)) res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType), tableAlias)
return return
case sqlFnSubstring: case sqlFnSubstring:
return handleSQLSubstring(r, e.Substring) return handleSQLSubstring(r, e.Substring, tableAlias)
case sqlFnExtract: case sqlFnExtract:
return handleSQLExtract(r, e.Extract) return handleSQLExtract(r, e.Extract, tableAlias)
case sqlFnTrim: case sqlFnTrim:
return handleSQLTrim(r, e.Trim) return handleSQLTrim(r, e.Trim, tableAlias)
case sqlFnDateAdd: case sqlFnDateAdd:
return handleDateAdd(r, e.DateAdd) return handleDateAdd(r, e.DateAdd, tableAlias)
case sqlFnDateDiff: case sqlFnDateDiff:
return handleDateDiff(r, e.DateDiff) return handleDateDiff(r, e.DateDiff, tableAlias)
} }
// For all simple argument functions, we evaluate the arguments here // For all simple argument functions, we evaluate the arguments here
argVals := make([]*Value, len(e.SFunc.ArgsList)) argVals := make([]*Value, len(e.SFunc.ArgsList))
for i, arg := range e.SFunc.ArgsList { for i, arg := range e.SFunc.ArgsList {
argVals[i], err = arg.evalNode(r) argVals[i], err = arg.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -219,8 +219,8 @@ func upperCase(v *Value) (*Value, error) {
return FromString(strings.ToUpper(s)), nil return FromString(strings.ToUpper(s)), nil
} }
func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) { func handleDateAdd(r Record, d *DateAddFunc, tableAlias string) (*Value, error) {
q, err := d.Quantity.evalNode(r) q, err := d.Quantity.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -230,7 +230,7 @@ func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) {
return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd) return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd)
} }
ts, err := d.Timestamp.evalNode(r) ts, err := d.Timestamp.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -245,8 +245,8 @@ func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) {
return dateAdd(strings.ToUpper(d.DatePart), qty, t) return dateAdd(strings.ToUpper(d.DatePart), qty, t)
} }
func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) { func handleDateDiff(r Record, d *DateDiffFunc, tableAlias string) (*Value, error) {
tval1, err := d.Timestamp1.evalNode(r) tval1, err := d.Timestamp1.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -258,7 +258,7 @@ func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) {
return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff) return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
} }
tval2, err := d.Timestamp2.evalNode(r) tval2, err := d.Timestamp2.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -277,12 +277,12 @@ func handleUTCNow() (*Value, error) {
return FromTimestamp(time.Now().UTC()), nil return FromTimestamp(time.Now().UTC()), nil
} }
func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) { func handleSQLSubstring(r Record, e *SubstringFunc, tableAlias string) (val *Value, err error) {
// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and // Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
// SUBSTRING('abc', 2, 1) are supported. // SUBSTRING('abc', 2, 1) are supported.
// Evaluate the string argument // Evaluate the string argument
v1, err := e.Expr.evalNode(r) v1, err := e.Expr.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -301,7 +301,7 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
} }
// Evaluate the FROM argument // Evaluate the FROM argument
v2, err := arg2.evalNode(r) v2, err := arg2.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -315,7 +315,7 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
length := -1 length := -1
// Evaluate the optional FOR argument // Evaluate the optional FOR argument
if arg3 != nil { if arg3 != nil {
v3, err := arg3.evalNode(r) v3, err := arg3.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -336,11 +336,11 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
return FromString(res), err return FromString(res), err
} }
func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) { func handleSQLTrim(r Record, e *TrimFunc, tableAlias string) (res *Value, err error) {
chars := "" chars := ""
ok := false ok := false
if e.TrimChars != nil { if e.TrimChars != nil {
charsV, cerr := e.TrimChars.evalNode(r) charsV, cerr := e.TrimChars.evalNode(r, tableAlias)
if cerr != nil { if cerr != nil {
return nil, cerr return nil, cerr
} }
@@ -351,7 +351,7 @@ func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
} }
} }
fromV, ferr := e.TrimFrom.evalNode(r) fromV, ferr := e.TrimFrom.evalNode(r, tableAlias)
if ferr != nil { if ferr != nil {
return nil, ferr return nil, ferr
} }
@@ -368,8 +368,8 @@ func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
return FromString(result), nil return FromString(result), nil
} }
func handleSQLExtract(r Record, e *ExtractFunc) (res *Value, err error) { func handleSQLExtract(r Record, e *ExtractFunc, tableAlias string) (res *Value, err error) {
timeVal, verr := e.From.evalNode(r) timeVal, verr := e.From.evalNode(r, tableAlias)
if verr != nil { if verr != nil {
return nil, verr return nil, verr
} }
@@ -406,8 +406,8 @@ const (
castTimestamp = "TIMESTAMP" castTimestamp = "TIMESTAMP"
) )
func (e *Expression) castTo(r Record, castType string) (res *Value, err error) { func (e *Expression) castTo(r Record, castType string, tableAlias string) (res *Value, err error) {
v, err := e.evalNode(r) v, err := e.evalNode(r, tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -119,7 +119,9 @@ type JSONPath struct {
PathExpr []*JSONPathElement `parser:"(@@)*"` PathExpr []*JSONPathElement `parser:"(@@)*"`
// Cached values: // Cached values:
pathString string pathString string
strippedTableAlias string
strippedPathExpr []*JSONPathElement
} }
// AliasedExpression is an expression that can be optionally named // AliasedExpression is an expression that can be optionally named

View File

@@ -46,6 +46,9 @@ type SelectStatement struct {
// Count of rows that have been output. // Count of rows that have been output.
outputCount int64 outputCount int64
// Table alias
tableAlias string
} }
// ParseSelectStatement - parses a select query from the given string // ParseSelectStatement - parses a select query from the given string
@@ -107,6 +110,9 @@ func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
if err != nil { if err != nil {
err = errQueryAnalysisFailure(err) err = errQueryAnalysisFailure(err)
} }
// Set table alias
stmt.tableAlias = selectAST.From.As
return return
} }
@@ -226,7 +232,7 @@ func (e *SelectStatement) IsAggregated() bool {
// records have been processed. Applies only to aggregation queries. // records have been processed. Applies only to aggregation queries.
func (e *SelectStatement) AggregateResult(output Record) error { func (e *SelectStatement) AggregateResult(output Record) error {
for i, expr := range e.selectAST.Expression.Expressions { for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(nil) v, err := expr.evalNode(nil, e.tableAlias)
if err != nil { if err != nil {
return err return err
} }
@@ -246,7 +252,7 @@ func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) {
if e.selectAST.Where == nil { if e.selectAST.Where == nil {
return true, nil return true, nil
} }
value, err := e.selectAST.Where.evalNode(input) value, err := e.selectAST.Where.evalNode(input, e.tableAlias)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -272,7 +278,7 @@ func (e *SelectStatement) AggregateRow(input Record) error {
} }
for _, expr := range e.selectAST.Expression.Expressions { for _, expr := range e.selectAST.Expression.Expressions {
err := expr.aggregateRow(input) err := expr.aggregateRow(input, e.tableAlias)
if err != nil { if err != nil {
return err return err
} }
@@ -302,7 +308,7 @@ func (e *SelectStatement) Eval(input, output Record) (Record, error) {
} }
for i, expr := range e.selectAST.Expression.Expressions { for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(input) v, err := expr.evalNode(input, e.tableAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -36,6 +36,27 @@ func (e *JSONPath) String() string {
return e.pathString return e.pathString
} }
// StripTableAlias removes a table alias from the path. The result is also
// cached for repeated lookups during SQL query evaluation.
func (e *JSONPath) StripTableAlias(tableAlias string) []*JSONPathElement {
if e.strippedTableAlias == tableAlias {
return e.strippedPathExpr
}
hasTableAlias := e.BaseKey.String() == tableAlias || strings.ToLower(e.BaseKey.String()) == baseTableName
var pathExpr []*JSONPathElement
if hasTableAlias {
pathExpr = e.PathExpr
} else {
pathExpr = make([]*JSONPathElement, len(e.PathExpr)+1)
pathExpr[0] = &JSONPathElement{Key: &ObjectKey{ID: e.BaseKey}}
copy(pathExpr[1:], e.PathExpr)
}
e.strippedTableAlias = tableAlias
e.strippedPathExpr = pathExpr
return e.strippedPathExpr
}
func (e *JSONPathElement) String() string { func (e *JSONPathElement) String() string {
switch { switch {
case e.Key != nil: case e.Key != nil: