Skip to content

Commit a69d9a2

Browse files
committed
internal/analysisinternal: AddImport helper for inserting imports
This CL defines a new helper for inserting import declarations, as needed, when a refactoring introduces a reference to an imported symbol. Also, a test. Change-Id: Icba17e6f76e67d2dad8f68b312db7111f4df817a Reviewed-on: https://go-review.googlesource.com/c/tools/+/592277 Reviewed-by: Robert Findley <[email protected]> Reviewed-by: Tim King <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent db99e8a commit a69d9a2

File tree

2 files changed

+312
-0
lines changed

2 files changed

+312
-0
lines changed
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package analysisinternal_test
6+
7+
import (
8+
"fmt"
9+
"go/ast"
10+
"go/importer"
11+
"go/parser"
12+
"go/token"
13+
"go/types"
14+
"runtime"
15+
"strings"
16+
"testing"
17+
18+
"github.com/google/go-cmp/cmp"
19+
"golang.org/x/tools/internal/analysisinternal"
20+
)
21+
22+
func TestAddImport(t *testing.T) {
23+
descr := func(s string) string {
24+
if _, _, line, ok := runtime.Caller(1); ok {
25+
return fmt.Sprintf("L%d %s", line, s)
26+
}
27+
panic("runtime.Caller failed")
28+
}
29+
30+
// Each test case contains a «name pkgpath»
31+
// section to be replaced with a reference
32+
// to a valid import of pkgpath,
33+
// ideally of the specified name.
34+
for _, test := range []struct {
35+
descr, src, want string
36+
}{
37+
{
38+
descr: descr("simple add import"),
39+
src: `package a
40+
func _() {
41+
«fmt fmt»
42+
}`,
43+
want: `package a
44+
import "fmt"
45+
46+
func _() {
47+
fmt
48+
}`,
49+
},
50+
{
51+
descr: descr("existing import"),
52+
src: `package a
53+
54+
import "fmt"
55+
56+
func _(fmt.Stringer) {
57+
«fmt fmt»
58+
}`,
59+
want: `package a
60+
61+
import "fmt"
62+
63+
func _(fmt.Stringer) {
64+
fmt
65+
}`,
66+
},
67+
{
68+
descr: descr("existing blank import"),
69+
src: `package a
70+
71+
import _ "fmt"
72+
73+
func _() {
74+
«fmt fmt»
75+
}`,
76+
want: `package a
77+
78+
import "fmt"
79+
80+
import _ "fmt"
81+
82+
func _() {
83+
fmt
84+
}`,
85+
},
86+
{
87+
descr: descr("existing renaming import"),
88+
src: `package a
89+
90+
import fmtpkg "fmt"
91+
92+
var fmt int
93+
94+
func _(fmtpkg.Stringer) {
95+
«fmt fmt»
96+
}`,
97+
want: `package a
98+
99+
import fmtpkg "fmt"
100+
101+
var fmt int
102+
103+
func _(fmtpkg.Stringer) {
104+
fmtpkg
105+
}`,
106+
},
107+
{
108+
descr: descr("existing import is shadowed"),
109+
src: `package a
110+
111+
import "fmt"
112+
113+
var _ fmt.Stringer
114+
115+
func _(fmt int) {
116+
«fmt fmt»
117+
}`,
118+
want: `package a
119+
120+
import fmt0 "fmt"
121+
122+
import "fmt"
123+
124+
var _ fmt.Stringer
125+
126+
func _(fmt int) {
127+
fmt0
128+
}`,
129+
},
130+
{
131+
descr: descr("preferred name is shadowed"),
132+
src: `package a
133+
134+
import "fmt"
135+
136+
func _(fmt fmt.Stringer) {
137+
«fmt fmt»
138+
}`,
139+
want: `package a
140+
141+
import fmt0 "fmt"
142+
143+
import "fmt"
144+
145+
func _(fmt fmt.Stringer) {
146+
fmt0
147+
}`,
148+
},
149+
{
150+
descr: descr("import inserted before doc comments"),
151+
src: `package a
152+
153+
// hello
154+
import ()
155+
156+
// world
157+
func _() {
158+
«fmt fmt»
159+
}`,
160+
want: `package a
161+
162+
import "fmt"
163+
164+
// hello
165+
import ()
166+
167+
// world
168+
func _() {
169+
fmt
170+
}`,
171+
},
172+
{
173+
descr: descr("arbitrary preferred name => renaming import"),
174+
src: `package a
175+
176+
func _() {
177+
«foo encoding/json»
178+
}`,
179+
want: `package a
180+
181+
import foo "encoding/json"
182+
183+
func _() {
184+
foo
185+
}`,
186+
},
187+
} {
188+
t.Run(test.descr, func(t *testing.T) {
189+
// splice marker
190+
before, mid, ok1 := strings.Cut(test.src, "«")
191+
mid, after, ok2 := strings.Cut(mid, "»")
192+
if !ok1 || !ok2 {
193+
t.Fatal("no «name path» marker")
194+
}
195+
src := before + "/*!*/" + after
196+
name, path, _ := strings.Cut(mid, " ")
197+
198+
// parse
199+
fset := token.NewFileSet()
200+
f, err := parser.ParseFile(fset, "a.go", src, parser.ParseComments)
201+
if err != nil {
202+
t.Log(err)
203+
}
204+
pos := fset.File(f.Pos()).Pos(len(before))
205+
206+
// type-check
207+
info := &types.Info{
208+
Types: make(map[ast.Expr]types.TypeAndValue),
209+
Scopes: make(map[ast.Node]*types.Scope),
210+
Defs: make(map[*ast.Ident]types.Object),
211+
Implicits: make(map[ast.Node]types.Object),
212+
}
213+
conf := &types.Config{
214+
Error: func(err error) { t.Log(err) },
215+
Importer: importer.Default(),
216+
}
217+
conf.Check(f.Name.Name, fset, []*ast.File{f}, info)
218+
219+
// add import
220+
name, edit := analysisinternal.AddImport(info, f, pos, path, name)
221+
222+
// apply patch
223+
start := fset.Position(edit.Pos)
224+
end := fset.Position(edit.End)
225+
output := src[:start.Offset] + string(edit.NewText) + src[end.Offset:]
226+
output = strings.ReplaceAll(output, "/*!*/", name)
227+
if output != test.want {
228+
t.Errorf("\n--got--\n%s\n--want--\n%s\n--diff--\n%s",
229+
output, test.want, cmp.Diff(test.want, output))
230+
}
231+
})
232+
}
233+
}

internal/analysisinternal/analysis.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"go/token"
1414
"go/types"
1515
"os"
16+
pathpkg "path"
1617
"strconv"
1718

1819
"golang.org/x/tools/go/analysis"
@@ -432,3 +433,81 @@ func slicesContains[S ~[]E, E comparable](slice S, x E) bool {
432433
}
433434
return false
434435
}
436+
437+
// AddImport checks whether this file already imports pkgpath and
438+
// that import is in scope at pos. If so, it returns the name under
439+
// which it was imported and a zero edit. Otherwise, it adds a new
440+
// import of pkgpath, using a name derived from the preferred name,
441+
// and returns the chosen name along with the edit for the new import.
442+
//
443+
// It does not mutate its arguments.
444+
func AddImport(info *types.Info, file *ast.File, pos token.Pos, pkgpath, preferredName string) (name string, newImport analysis.TextEdit) {
445+
// Find innermost enclosing lexical block.
446+
scope := info.Scopes[file].Innermost(pos)
447+
if scope == nil {
448+
panic("no enclosing lexical block")
449+
}
450+
451+
// Is there an existing import of this package?
452+
// If so, are we in its scope? (not shadowed)
453+
for _, spec := range file.Imports {
454+
pkgname, ok := importedPkgName(info, spec)
455+
if ok && pkgname.Imported().Path() == pkgpath {
456+
if _, obj := scope.LookupParent(pkgname.Name(), pos); obj == pkgname {
457+
return pkgname.Name(), analysis.TextEdit{}
458+
}
459+
}
460+
}
461+
462+
// We must add a new import.
463+
// Ensure we have a fresh name.
464+
newName := preferredName
465+
for i := 0; ; i++ {
466+
if _, obj := scope.LookupParent(newName, pos); obj == nil {
467+
break // fresh
468+
}
469+
newName = fmt.Sprintf("%s%d", preferredName, i)
470+
}
471+
472+
// For now, keep it real simple: create a new import
473+
// declaration before the first existing declaration (which
474+
// must exist), including its comments, and let goimports tidy it up.
475+
//
476+
// Use a renaming import whenever the preferred name is not
477+
// available, or the chosen name does not match the last
478+
// segment of its path.
479+
newText := fmt.Sprintf("import %q\n\n", pkgpath)
480+
if newName != preferredName || newName != pathpkg.Base(pkgpath) {
481+
newText = fmt.Sprintf("import %s %q\n\n", newName, pkgpath)
482+
}
483+
decl0 := file.Decls[0]
484+
var before ast.Node = decl0
485+
switch decl0 := decl0.(type) {
486+
case *ast.GenDecl:
487+
if decl0.Doc != nil {
488+
before = decl0.Doc
489+
}
490+
case *ast.FuncDecl:
491+
if decl0.Doc != nil {
492+
before = decl0.Doc
493+
}
494+
}
495+
return newName, analysis.TextEdit{
496+
Pos: before.Pos(),
497+
End: before.Pos(),
498+
NewText: []byte(newText),
499+
}
500+
}
501+
502+
// importedPkgName returns the PkgName object declared by an ImportSpec.
503+
// TODO(adonovan): use go1.22's Info.PkgNameOf.
504+
func importedPkgName(info *types.Info, imp *ast.ImportSpec) (*types.PkgName, bool) {
505+
var obj types.Object
506+
if imp.Name != nil {
507+
obj = info.Defs[imp.Name]
508+
} else {
509+
obj = info.Implicits[imp]
510+
}
511+
pkgname, ok := obj.(*types.PkgName)
512+
return pkgname, ok
513+
}

0 commit comments

Comments
 (0)