From 62c117421ab042a686fd72e419bf272c9265b814 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Wed, 3 Feb 2021 08:19:34 -0500 Subject: [PATCH] internal/kubeclient: fix not found test and request body closing bug - I realized that the hardcoded fakekubeapi 404 not found response was invalid, so we were getting a default error message. I fixed it so the tests follow a higher fidelity code path. - I caved and added a test for making sure the request body was always closed, and believe it or not, we were double closing a body. I don't *think* this will matter in production, since client-go will pass us ioutil.NopReader()'s, but at least we know now. Signed-off-by: Andrew Keesler --- internal/kubeclient/kubeclient.go | 8 +-- internal/kubeclient/kubeclient_test.go | 70 +++++++++++++++----- internal/kubeclient/option.go | 19 +++++- internal/kubeclient/roundtrip.go | 13 ++-- internal/testutil/fakekubeapi/fakekubeapi.go | 56 +++++++--------- 5 files changed, 104 insertions(+), 62 deletions(-) diff --git a/internal/kubeclient/kubeclient.go b/internal/kubeclient/kubeclient.go index 5a04e1f82..ebf771ab3 100644 --- a/internal/kubeclient/kubeclient.go +++ b/internal/kubeclient/kubeclient.go @@ -51,13 +51,13 @@ func New(opts ...Option) (*Client, error) { protoKubeConfig := createProtoKubeConfig(c.config) // Connect to the core Kubernetes API. - k8sClient, err := kubernetes.NewForConfig(configWithWrapper(protoKubeConfig, kubescheme.Scheme, kubescheme.Codecs, c.middlewares)) + k8sClient, err := kubernetes.NewForConfig(configWithWrapper(protoKubeConfig, kubescheme.Scheme, kubescheme.Codecs, c.middlewares, c.transportWrapper)) if err != nil { return nil, fmt.Errorf("could not initialize Kubernetes client: %w", err) } // Connect to the Kubernetes aggregation API. - aggregatorClient, err := aggregatorclient.NewForConfig(configWithWrapper(protoKubeConfig, aggregatorclientscheme.Scheme, aggregatorclientscheme.Codecs, c.middlewares)) + aggregatorClient, err := aggregatorclient.NewForConfig(configWithWrapper(protoKubeConfig, aggregatorclientscheme.Scheme, aggregatorclientscheme.Codecs, c.middlewares, c.transportWrapper)) if err != nil { return nil, fmt.Errorf("could not initialize aggregation client: %w", err) } @@ -65,7 +65,7 @@ func New(opts ...Option) (*Client, error) { // Connect to the pinniped concierge API. // We cannot use protobuf encoding here because we are using CRDs // (for which protobuf encoding is not yet supported). - pinnipedConciergeClient, err := pinnipedconciergeclientset.NewForConfig(configWithWrapper(jsonKubeConfig, pinnipedconciergeclientsetscheme.Scheme, pinnipedconciergeclientsetscheme.Codecs, c.middlewares)) + pinnipedConciergeClient, err := pinnipedconciergeclientset.NewForConfig(configWithWrapper(jsonKubeConfig, pinnipedconciergeclientsetscheme.Scheme, pinnipedconciergeclientsetscheme.Codecs, c.middlewares, c.transportWrapper)) if err != nil { return nil, fmt.Errorf("could not initialize pinniped client: %w", err) } @@ -73,7 +73,7 @@ func New(opts ...Option) (*Client, error) { // Connect to the pinniped supervisor API. // We cannot use protobuf encoding here because we are using CRDs // (for which protobuf encoding is not yet supported). - pinnipedSupervisorClient, err := pinnipedsupervisorclientset.NewForConfig(configWithWrapper(jsonKubeConfig, pinnipedsupervisorclientsetscheme.Scheme, pinnipedsupervisorclientsetscheme.Codecs, c.middlewares)) + pinnipedSupervisorClient, err := pinnipedsupervisorclientset.NewForConfig(configWithWrapper(jsonKubeConfig, pinnipedsupervisorclientsetscheme.Scheme, pinnipedsupervisorclientsetscheme.Codecs, c.middlewares, c.transportWrapper)) if err != nil { return nil, fmt.Errorf("could not initialize pinniped client: %w", err) } diff --git a/internal/kubeclient/kubeclient_test.go b/internal/kubeclient/kubeclient_test.go index a52699104..2bb4aaeeb 100644 --- a/internal/kubeclient/kubeclient_test.go +++ b/internal/kubeclient/kubeclient_test.go @@ -66,12 +66,6 @@ var ( middlewareLabels = map[string]string{"some-label": "thing 2"} ) -// TestKubeclient tests a subset of kubeclient functionality (from the public interface down). We -// intend for the following list of things to be tested with the integration tests: -// list (running in every informer cache) -// watch (running in every informer cache) -// discovery -// api errors func TestKubeclient(t *testing.T) { // plog.ValidateAndSetLogLevelGlobally(plog.LevelDebug) // uncomment me to get some more debug logs @@ -109,7 +103,7 @@ func TestKubeclient(t *testing.T) { CoreV1(). Pods(pod.Namespace). Get(context.Background(), "this-pod-does-not-exist", metav1.GetOptions{}) - require.EqualError(t, err, "the server could not find the requested resource (get pods this-pod-does-not-exist)") + require.EqualError(t, err, `couldn't find object for path "/api/v1/namespaces/good-namespace/pods/this-pod-does-not-exist"`) // update goodPodWithAnnotationsAndLabelsAndClusterName := with(goodPod, annotations(), labels(), clusterName()).(*corev1.Pod) @@ -546,16 +540,15 @@ func TestKubeclient(t *testing.T) { test.editRestConfig(t, restConfig) } - // our rt chain is: - // kubeclient -> wantCloseResp -> http.DefaultTransport -> wantCloseResp -> kubeclient - restConfig.Wrap(wantCloseRespWrapper(t)) - var middlewares []*spyMiddleware if test.middlewares != nil { middlewares = test.middlewares(t) } - opts := []Option{WithConfig(restConfig)} + // our rt chain is: + // wantCloseReq -> kubeclient -> wantCloseResp -> http.DefaultTransport -> wantCloseResp -> kubeclient -> wantCloseReq + restConfig.Wrap(wantCloseRespWrapper(t)) + opts := []Option{WithConfig(restConfig), WithTransportWrapper(wantCloseReqWrapper(t))} for _, middleware := range middlewares { opts = append(opts, WithMiddleware(middleware)) } @@ -675,11 +668,13 @@ func newSimpleMiddleware(t *testing.T, hasMutateReqFunc, mutatedReq, hasMutateRe type wantCloser struct { io.ReadCloser closeCount int + closeCalls []string couldReadBytesJustBeforeClosing bool } func (wc *wantCloser) Close() error { wc.closeCount++ + wc.closeCalls = append(wc.closeCalls, getCaller()) n, _ := wc.ReadCloser.Read([]byte{0}) if n > 0 { // there were still bytes left to be read @@ -688,14 +683,53 @@ func (wc *wantCloser) Close() error { return wc.ReadCloser.Close() } -// wantCloseRespWrapper returns a transport.WrapperFunc that validates that the http.Response -// returned by the underlying http.RoundTripper is closed properly. -func wantCloseRespWrapper(t *testing.T) transport.WrapperFunc { - _, file, line, ok := runtime.Caller(1) +func getCaller() string { + _, file, line, ok := runtime.Caller(2) if !ok { file = "???" line = 0 } + return fmt.Sprintf("%s:%d", file, line) +} + +// wantCloseReqWrapper returns a transport.WrapperFunc that validates that the http.Request +// passed to the underlying http.RoundTripper is closed properly. +func wantCloseReqWrapper(t *testing.T) transport.WrapperFunc { + caller := getCaller() + return func(rt http.RoundTripper) http.RoundTripper { + return roundTripperFunc(func(req *http.Request) (bool, *http.Response, error) { + if req.Body != nil { + wc := &wantCloser{ReadCloser: req.Body} + t.Cleanup(func() { + require.Equalf(t, wc.closeCount, 1, "did not close req body expected number of times at %s for req %#v; actual calls = %s", caller, req, wc.closeCalls) + }) + req.Body = wc + } + + if req.GetBody != nil { + originalBodyCopy, originalErr := req.GetBody() + req.GetBody = func() (io.ReadCloser, error) { + if originalErr != nil { + return nil, originalErr + } + wc := &wantCloser{ReadCloser: originalBodyCopy} + t.Cleanup(func() { + require.Equalf(t, wc.closeCount, 1, "did not close req body copy expected number of times at %s for req %#v; actual calls = %s", caller, req, wc.closeCalls) + }) + return wc, nil + } + } + + resp, err := rt.RoundTrip(req) + return false, resp, err + }) + } +} + +// wantCloseRespWrapper returns a transport.WrapperFunc that validates that the http.Response +// returned by the underlying http.RoundTripper is closed properly. +func wantCloseRespWrapper(t *testing.T) transport.WrapperFunc { + caller := getCaller() return func(rt http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (bool, *http.Response, error) { resp, err := rt.RoundTrip(req) @@ -705,8 +739,8 @@ func wantCloseRespWrapper(t *testing.T) transport.WrapperFunc { } wc := &wantCloser{ReadCloser: resp.Body} t.Cleanup(func() { - require.False(t, wc.couldReadBytesJustBeforeClosing, "did not consume all response body bytes before closing %s:%d", file, line) - require.Equalf(t, wc.closeCount, 1, "did not close resp body at %s:%d", file, line) + require.False(t, wc.couldReadBytesJustBeforeClosing, "did not consume all response body bytes before closing %s", caller) + require.Equalf(t, wc.closeCount, 1, "did not close resp body expected number of times at %s for req %#v; actual calls = %s", caller, req, wc.closeCalls) }) resp.Body = wc return false, resp, err diff --git a/internal/kubeclient/option.go b/internal/kubeclient/option.go index 789120fba..7c10bda98 100644 --- a/internal/kubeclient/option.go +++ b/internal/kubeclient/option.go @@ -3,13 +3,17 @@ package kubeclient -import restclient "k8s.io/client-go/rest" +import ( + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/transport" +) type Option func(*clientConfig) type clientConfig struct { - config *restclient.Config - middlewares []Middleware + config *restclient.Config + middlewares []Middleware + transportWrapper transport.WrapperFunc } func WithConfig(config *restclient.Config) Option { @@ -27,3 +31,12 @@ func WithMiddleware(middleware Middleware) Option { c.middlewares = append(c.middlewares, middleware) } } + +// WithTransportWrapper will wrap the client-go http.RoundTripper chain *after* the middleware +// wrapper is applied. I.e., this wrapper has the opportunity to supply an http.RoundTripper that +// runs first in the client-go http.RoundTripper chain. +func WithTransportWrapper(wrapper transport.WrapperFunc) Option { + return func(c *clientConfig) { + c.transportWrapper = wrapper + } +} diff --git a/internal/kubeclient/roundtrip.go b/internal/kubeclient/roundtrip.go index 5caf2a39b..3d6029dc5 100644 --- a/internal/kubeclient/roundtrip.go +++ b/internal/kubeclient/roundtrip.go @@ -23,7 +23,7 @@ import ( "go.pinniped.dev/internal/plog" ) -func configWithWrapper(config *restclient.Config, scheme *runtime.Scheme, negotiatedSerializer runtime.NegotiatedSerializer, middlewares []Middleware) *restclient.Config { +func configWithWrapper(config *restclient.Config, scheme *runtime.Scheme, negotiatedSerializer runtime.NegotiatedSerializer, middlewares []Middleware, wrapper transport.WrapperFunc) *restclient.Config { hostURL, apiPathPrefix, err := getHostAndAPIPathPrefix(config) if err != nil { plog.DebugErr("invalid rest config", err) @@ -49,6 +49,9 @@ func configWithWrapper(config *restclient.Config, scheme *runtime.Scheme, negoti cc := restclient.CopyConfig(config) cc.Wrap(f) + if wrapper != nil { + cc.Wrap(wrapper) + } return cc } @@ -173,20 +176,20 @@ func handleOtherVerbs( resp, err := rt.RoundTrip(newReq) if err != nil { - return true, nil, fmt.Errorf("middleware request for %#v failed: %w", middlewareReq, err) + return false, nil, fmt.Errorf("middleware request for %#v failed: %w", middlewareReq, err) } switch v { case VerbDelete, VerbDeleteCollection: - return true, resp, nil // we do not need to fix the response on delete + return false, resp, nil // we do not need to fix the response on delete case VerbWatch: resp, err := handleWatchResponseNewGVK(config, negotiatedSerializer, resp, middlewareReq, result) - return true, resp, err + return false, resp, err default: // VerbGet, VerbList, VerbPatch resp, err := handleResponseNewGVK(config, negotiatedSerializer, resp, middlewareReq, result) - return true, resp, err + return false, resp, err } } diff --git a/internal/testutil/fakekubeapi/fakekubeapi.go b/internal/testutil/fakekubeapi/fakekubeapi.go index e3096e5e4..66450f581 100644 --- a/internal/testutil/fakekubeapi/fakekubeapi.go +++ b/internal/testutil/fakekubeapi/fakekubeapi.go @@ -19,6 +19,7 @@ package fakekubeapi import ( "encoding/pem" + "fmt" "io/ioutil" "mime" "net/http" @@ -39,20 +40,6 @@ import ( "go.pinniped.dev/internal/multierror" ) -// Unlike the standard httperr.New(), this one does not prepend error messages with any prefix. -type plainHTTPErr struct { - code int - msg string -} - -func (e plainHTTPErr) Error() string { - return e.msg -} - -func (e plainHTTPErr) Respond(w http.ResponseWriter) { - http.Error(w, e.msg, e.code) -} - // Start starts an httptest.Server (with TLS) that pretends to be a Kube API server. // // The server uses the provided resources map to store API Object's. The map should be from API path @@ -62,9 +49,9 @@ func (e plainHTTPErr) Respond(w http.ResponseWriter) { // to the server. // // Note! Only these following verbs are (partially) supported: create, get, update, delete. -func Start(t *testing.T, resources map[string]metav1.Object) (*httptest.Server, *restclient.Config) { +func Start(t *testing.T, resources map[string]runtime.Object) (*httptest.Server, *restclient.Config) { if resources == nil { - resources = make(map[string]metav1.Object) + resources = make(map[string]runtime.Object) } server := httptest.NewTLSServer(httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (err error) { @@ -78,12 +65,8 @@ func Start(t *testing.T, resources map[string]metav1.Object) (*httptest.Server, return err } - if r.Method != http.MethodDelete && obj == nil { - return &plainHTTPErr{ - code: http.StatusNotFound, - // This is representative of a real Kube 404 message body. - msg: `{"kind":"Status","apiVersion":"v1","metadata":{},"status":"Failure","message":"not found","reason":"NotFound","details":{"name":"not-found","kind":"pods"},"code":404}`, - } + if obj == nil { + obj = newNotFoundStatus(r.URL.Path) } if err := encodeObj(w, r, obj); err != nil { @@ -101,7 +84,7 @@ func Start(t *testing.T, resources map[string]metav1.Object) (*httptest.Server, return server, restConfig } -func decodeObj(r *http.Request) (metav1.Object, error) { +func decodeObj(r *http.Request) (runtime.Object, error) { switch r.Method { case http.MethodPut, http.MethodPost: default: @@ -123,7 +106,7 @@ func decodeObj(r *http.Request) (metav1.Object, error) { return nil, httperr.Wrap(http.StatusInternalServerError, "read body", err) } - var obj metav1.Object + var obj runtime.Object multiErr := multierror.New() codecsThatWeUseInOurCode := []runtime.NegotiatedSerializer{ kubescheme.Codecs, @@ -145,7 +128,7 @@ func tryDecodeObj( mediaType string, body []byte, negotiatedSerializer runtime.NegotiatedSerializer, -) (metav1.Object, error) { +) (runtime.Object, error) { serializerInfo, ok := runtime.SerializerInfoForMediaType(negotiatedSerializer.SupportedMediaTypes(), mediaType) if !ok { return nil, httperr.Newf(http.StatusInternalServerError, "unable to find serialier with content-type %s", mediaType) @@ -156,19 +139,17 @@ func tryDecodeObj( return nil, httperr.Wrap(http.StatusInternalServerError, "decode obj", err) } - return obj.(metav1.Object), nil + return obj, nil } -func handleObj(r *http.Request, obj metav1.Object, resources map[string]metav1.Object) (metav1.Object, error) { +func handleObj(r *http.Request, obj runtime.Object, resources map[string]runtime.Object) (runtime.Object, error) { switch r.Method { case http.MethodGet: obj = resources[r.URL.Path] case http.MethodPost, http.MethodPut: - resources[path.Join(r.URL.Path, obj.GetName())] = obj + resources[path.Join(r.URL.Path, obj.(metav1.Object).GetName())] = obj case http.MethodDelete: - if _, ok := resources[r.URL.Path]; !ok { - return nil, httperr.Newf(http.StatusNotFound, "no resource with path %q", r.URL.Path) - } + obj = resources[r.URL.Path] delete(resources, r.URL.Path) default: return nil, httperr.New(http.StatusMethodNotAllowed, "check source code for methods supported") @@ -177,7 +158,18 @@ func handleObj(r *http.Request, obj metav1.Object, resources map[string]metav1.O return obj, nil } -func encodeObj(w http.ResponseWriter, r *http.Request, obj metav1.Object) error { +func newNotFoundStatus(path string) runtime.Object { + status := &metav1.Status{ + Status: metav1.StatusFailure, + Message: fmt.Sprintf("couldn't find object for path %q", path), + Reason: metav1.StatusReasonNotFound, + Code: http.StatusNotFound, + } + status.APIVersion, status.Kind = metav1.SchemeGroupVersion.WithKind("Status").ToAPIVersionAndKind() + return status +} + +func encodeObj(w http.ResponseWriter, r *http.Request, obj runtime.Object) error { if r.Method == http.MethodDelete { return nil }