From d50049a0cff703b214f6cbb994c60dfcd2cbd623 Mon Sep 17 00:00:00 2001 From: William Banfield Date: Mon, 9 May 2022 17:37:46 -0400 Subject: [PATCH] extract metrics package name --- scripts/metricsgen/metricsgen.go | 45 ++++++++++++++++++++------- scripts/metricsgen/metricsgen_test.go | 39 +++++++++++++++++++++++ 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/scripts/metricsgen/metricsgen.go b/scripts/metricsgen/metricsgen.go index 4b122f975..f10418ea8 100644 --- a/scripts/metricsgen/metricsgen.go +++ b/scripts/metricsgen/metricsgen.go @@ -15,9 +15,11 @@ import ( "io/fs" "log" "os" + "path" "path/filepath" "reflect" "regexp" + "strconv" "strings" "text/template" ) @@ -36,6 +38,8 @@ Options: } } +const metricsPackageName = "github.com/go-kit/kit/metrics" + const ( metricNameTag = "metricsgen_name" labelsTag = "metricsgen_labels" @@ -173,12 +177,12 @@ func ParseMetricsDir(dir string, structName string) (TemplateData, error) { Package: pkgName, } // Grab the metrics struct - m, err := fetchMetricsStruct(pkg.Files, structName) + m, mPkgName, err := fetchMetricsStruct(pkg.Files, structName) if err != nil { return TemplateData{}, err } for _, f := range m.Fields.List { - if !isMetric(f.Type) { + if !isMetric(f.Type, mPkgName) { continue } pmf := parseMetricField(f) @@ -208,12 +212,15 @@ func GenerateMetricsFile(w io.Writer, td TemplateData) error { return nil } -func fetchMetricsStruct(files map[string]*ast.File, structName string) (*ast.StructType, error) { +func fetchMetricsStruct(files map[string]*ast.File, structName string) (*ast.StructType, string, error) { var ( - err error - st *ast.StructType + st *ast.StructType ) for _, file := range files { + mPkgName, err := extractMetricsPackageName(file.Imports) + if err != nil { + return nil, "", fmt.Errorf("unable to determine metrics package name: %v", err) + } if !ast.FilterFile(file, func(name string) bool { return name == structName }) { @@ -235,13 +242,13 @@ func fetchMetricsStruct(files map[string]*ast.File, structName string) (*ast.Str } }) if err != nil { - return nil, err + return nil, "", err } if st != nil { - return st, nil + return st, mPkgName, nil } } - return nil, fmt.Errorf("target struct %q not found in dir", structName) + return nil, "", fmt.Errorf("target struct %q not found in dir", structName) } func parseMetricField(f *ast.Field) ParsedMetricField { @@ -265,11 +272,11 @@ func parseMetricField(f *ast.Field) ParsedMetricField { } func extractTypeName(e ast.Expr) string { - return strings.TrimPrefix(types.ExprString(e), "metrics.") + return strings.TrimPrefix(path.Ext(types.ExprString(e)), ".") } -func isMetric(e ast.Expr) bool { - return strings.Contains(types.ExprString(e), "metrics.") +func isMetric(e ast.Expr, mPkgName string) bool { + return strings.Contains(types.ExprString(e), fmt.Sprintf("%s.", mPkgName)) } func extractLabels(bl *ast.BasicLit) string { @@ -306,6 +313,22 @@ func extractHistogramOptions(tag *ast.BasicLit) HistogramOpts { return h } +func extractMetricsPackageName(imports []*ast.ImportSpec) (string, error) { + for _, i := range imports { + u, err := strconv.Unquote(i.Path.Value) + if err != nil { + return "", err + } + if u == metricsPackageName { + if i.Name != nil { + return i.Name.Name, nil + } + return path.Base(u), nil + } + } + return "", nil +} + var capitalChange = regexp.MustCompile("([a-z0-9])([A-Z])") func toSnakeCase(str string) string { diff --git a/scripts/metricsgen/metricsgen_test.go b/scripts/metricsgen/metricsgen_test.go index bffaeabb7..fadd0ec65 100644 --- a/scripts/metricsgen/metricsgen_test.go +++ b/scripts/metricsgen/metricsgen_test.go @@ -5,6 +5,7 @@ import ( "fmt" "go/parser" "go/token" + "io" "io/ioutil" "os" "path" @@ -201,3 +202,41 @@ func TestParseMetricsStruct(t *testing.T) { }) } } + +func TestParseAliasedMetric(t *testing.T) { + aliasedData := ` + package mypkg + + import( + mymetrics "github.com/go-kit/kit/metrics" + ) + type Metrics struct { + m mymetrics.Gauge + } + ` + dir, err := os.MkdirTemp(os.TempDir(), "metricsdir") + if err != nil { + t.Fatalf("unable to create directory: %v", err) + } + defer os.Remove(dir) + f, err := os.Create(filepath.Join(dir, "metrics.go")) + if err != nil { + t.Fatalf("unable to open file: %v", err) + } + _, err = io.WriteString(f, aliasedData) + td, err := metricsgen.ParseMetricsDir(dir, "Metrics") + require.NoError(t, err) + + expected := + metricsgen.TemplateData{ + Package: "mypkg", + ParsedMetrics: []metricsgen.ParsedMetricField{ + { + TypeName: "Gauge", + FieldName: "m", + MetricName: "m", + }, + }, + } + require.Equal(t, expected, td) +}