Skip to content

Commit 985c181

Browse files
committed
feat(golang-rewrite): shim generation 2
* Rename `versions.Installed` function to `IsInstalled` * Create `versions.Installed` function * Create `shims.GenerateForPluginVersions` and `shims.GenerateAll` functions * Address linter warnings * Create asdf reshim command * Run asdf hook from new `shims` functions
1 parent 62c2ba1 commit 985c181

File tree

5 files changed

+273
-26
lines changed

5 files changed

+273
-26
lines changed

cmd/cmd.go

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"asdf/config"
1212
"asdf/internal/info"
13+
"asdf/internal/shims"
1314
"asdf/internal/versions"
1415
"asdf/plugins"
1516

@@ -135,6 +136,12 @@ func Execute(version string) {
135136
},
136137
},
137138
},
139+
{
140+
Name: "reshim",
141+
Action: func(_ *cli.Context) error {
142+
return reshimCommand(logger)
143+
},
144+
},
138145
},
139146
Action: func(_ *cli.Context) error {
140147
// TODO: flesh this out
@@ -287,12 +294,15 @@ func installCommand(logger *log.Logger, toolName, version string) error {
287294
errs := versions.InstallAll(conf, dir, os.Stdout, os.Stderr)
288295
if len(errs) > 0 {
289296
for _, err := range errs {
290-
// write error stderr
291297
os.Stderr.Write([]byte(err.Error()))
292298
os.Stderr.Write([]byte("\n"))
293299
}
294300

295-
return errs[0]
301+
filtered := filterInstallErrors(errs)
302+
if len(filtered) > 0 {
303+
return filtered[0]
304+
}
305+
return nil
296306
}
297307
} else {
298308
// Install specific version
@@ -321,6 +331,16 @@ func installCommand(logger *log.Logger, toolName, version string) error {
321331
return err
322332
}
323333

334+
func filterInstallErrors(errs []error) []error {
335+
var filtered []error
336+
for _, err := range errs {
337+
if _, ok := err.(versions.NoVersionSetError); !ok {
338+
filtered = append(filtered, err)
339+
}
340+
}
341+
return filtered
342+
}
343+
324344
func parseInstallVersion(version string) (string, string) {
325345
segments := strings.Split(version, ":")
326346
if len(segments) > 1 && segments[0] == "latest" {
@@ -369,6 +389,26 @@ func latestCommand(logger *log.Logger, all bool, toolName, pattern string) (err
369389
return nil
370390
}
371391

392+
func reshimCommand(logger *log.Logger) (err error) {
393+
conf, err := config.LoadConfig()
394+
if err != nil {
395+
logger.Printf("error loading config: %s", err)
396+
return err
397+
}
398+
399+
err = shims.RemoveAll(conf)
400+
if err != nil {
401+
return err
402+
}
403+
404+
err = shims.GenerateAll(conf, os.Stdout, os.Stderr)
405+
if err != nil {
406+
return err
407+
}
408+
409+
return err
410+
}
411+
372412
func latestForPlugin(conf config.Config, toolName, pattern string, showStatus bool) error {
373413
// show single plugin
374414
plugin := plugins.New(conf, toolName)
@@ -385,7 +425,7 @@ func latestForPlugin(conf config.Config, toolName, pattern string, showStatus bo
385425
}
386426

387427
if showStatus {
388-
installed := versions.Installed(conf, plugin, latest)
428+
installed := versions.IsInstalled(conf, plugin, latest)
389429
fmt.Printf("%s\t%s\t%s\n", plugin.Name, latest, installedStatus(installed))
390430
} else {
391431
fmt.Printf("%s\n", latest)

internal/shims/shims.go

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,80 @@ package shims
33

44
import (
55
"fmt"
6+
"io"
67
"os"
8+
"path"
79
"path/filepath"
810
"strings"
911

1012
"asdf/config"
13+
"asdf/hook"
1114
"asdf/internal/toolversions"
1215
"asdf/internal/versions"
1316
"asdf/plugins"
1417

1518
"golang.org/x/sys/unix"
1619
)
1720

21+
const shimDirName = "shims"
22+
23+
// RemoveAll removes all shim scripts
24+
func RemoveAll(conf config.Config) error {
25+
shimDir := filepath.Join(conf.DataDir, shimDirName)
26+
entries, err := os.ReadDir(shimDir)
27+
if err != nil {
28+
return err
29+
}
30+
31+
for _, entry := range entries {
32+
os.RemoveAll(path.Join(shimDir, entry.Name()))
33+
}
34+
35+
return nil
36+
}
37+
38+
// GenerateAll generates shims for all executables of every version of every
39+
// plugin.
40+
func GenerateAll(conf config.Config, stdOut io.Writer, stdErr io.Writer) error {
41+
plugins, err := plugins.List(conf, false, false)
42+
if err != nil {
43+
return err
44+
}
45+
46+
for _, plugin := range plugins {
47+
err := GenerateForPluginVersions(conf, plugin, stdOut, stdErr)
48+
if err != nil {
49+
return err
50+
}
51+
}
52+
53+
return nil
54+
}
55+
56+
// GenerateForPluginVersions generates all shims for all installed versions of
57+
// a tool.
58+
func GenerateForPluginVersions(conf config.Config, plugin plugins.Plugin, stdOut io.Writer, stdErr io.Writer) error {
59+
installedVersions, err := versions.Installed(conf, plugin)
60+
if err != nil {
61+
return err
62+
}
63+
64+
for _, version := range installedVersions {
65+
err = hook.RunWithOutput(conf, fmt.Sprintf("pre_asdf_reshim_%s", plugin.Name), []string{version}, stdOut, stdErr)
66+
if err != nil {
67+
return err
68+
}
69+
70+
GenerateForVersion(conf, plugin, version)
71+
72+
err = hook.RunWithOutput(conf, fmt.Sprintf("post_asdf_reshim_%s", plugin.Name), []string{version}, stdOut, stdErr)
73+
if err != nil {
74+
return err
75+
}
76+
}
77+
return nil
78+
}
79+
1880
// GenerateForVersion loops over all the executable files found for a tool and
1981
// generates a shim for each one
2082
func GenerateForVersion(conf config.Config, plugin plugins.Plugin, version string) error {
@@ -59,11 +121,11 @@ func Write(conf config.Config, plugin plugins.Plugin, version, executablePath st
59121

60122
// Path returns the path for a shim script
61123
func Path(conf config.Config, shimName string) string {
62-
return filepath.Join(conf.DataDir, "shims", shimName)
124+
return filepath.Join(conf.DataDir, shimDirName, shimName)
63125
}
64126

65127
func ensureShimDirExists(conf config.Config) error {
66-
return os.MkdirAll(filepath.Join(conf.DataDir, "shims"), 0o777)
128+
return os.MkdirAll(filepath.Join(conf.DataDir, shimDirName), 0o777)
67129
}
68130

69131
// ToolExecutables returns a slice of executables for a given tool version
@@ -77,20 +139,20 @@ func ToolExecutables(conf config.Config, plugin plugins.Plugin, version string)
77139
paths := dirsToPaths(dirs, installPath)
78140

79141
for _, path := range paths {
80-
// Walk the directory and any sub directories
81-
err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
82-
if err != nil {
83-
return err
84-
}
85-
142+
entries, err := os.ReadDir(path)
143+
if err != nil {
144+
return executables, err
145+
}
146+
for _, entry := range entries {
86147
// If entry is dir or cannot be executed by the current user ignore it
87-
if info.IsDir() || unix.Access(path, unix.X_OK) != nil {
88-
return nil
148+
filePath := filepath.Join(path, entry.Name())
149+
if entry.IsDir() || unix.Access(filePath, unix.X_OK) != nil {
150+
return executables, nil
89151
}
90152

91-
executables = append(executables, path)
92-
return nil
93-
})
153+
executables = append(executables, filePath)
154+
return executables, nil
155+
}
94156
if err != nil {
95157
return executables, err
96158
}

internal/shims/shims_test.go

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package shims
22

33
import (
4+
"errors"
5+
"fmt"
46
"os"
57
"path/filepath"
68
"strings"
@@ -17,6 +19,93 @@ import (
1719

1820
const testPluginName = "lua"
1921

22+
func TestRemoveAll(t *testing.T) {
23+
version := "1.1.0"
24+
conf, plugin := generateConfig(t)
25+
installVersion(t, conf, plugin, version)
26+
executables, err := ToolExecutables(conf, plugin, version)
27+
assert.Nil(t, err)
28+
stdout, stderr := buildOutputs()
29+
30+
t.Run("removes all files in shim directory", func(t *testing.T) {
31+
assert.Nil(t, GenerateAll(conf, &stdout, &stderr))
32+
assert.Nil(t, RemoveAll(conf))
33+
34+
// check for generated shims
35+
for _, executable := range executables {
36+
_, err := os.Stat(Path(conf, filepath.Base(executable)))
37+
assert.True(t, errors.Is(err, os.ErrNotExist))
38+
}
39+
})
40+
}
41+
42+
func TestGenerateAll(t *testing.T) {
43+
version := "1.1.0"
44+
version2 := "2.0.0"
45+
conf, plugin := generateConfig(t)
46+
installVersion(t, conf, plugin, version)
47+
installPlugin(t, conf, "dummy_plugin", "ruby")
48+
installVersion(t, conf, plugin, version2)
49+
executables, err := ToolExecutables(conf, plugin, version)
50+
assert.Nil(t, err)
51+
stdout, stderr := buildOutputs()
52+
53+
t.Run("generates shim script for every executable in every version of every tool", func(t *testing.T) {
54+
assert.Nil(t, GenerateAll(conf, &stdout, &stderr))
55+
56+
// check for generated shims
57+
for _, executable := range executables {
58+
shimName := filepath.Base(executable)
59+
shimPath := Path(conf, shimName)
60+
assert.Nil(t, unix.Access(shimPath, unix.X_OK))
61+
62+
// shim exists and has expected contents
63+
content, err := os.ReadFile(shimPath)
64+
assert.Nil(t, err)
65+
want := fmt.Sprintf("#!/usr/bin/env bash\n# asdf-plugin: lua 2.0.0\n# asdf-plugin: lua 1.1.0\nexec asdf exec \"%s\" \"$@\"", shimName)
66+
assert.Equal(t, want, string(content))
67+
}
68+
})
69+
}
70+
71+
func TestGenerateForPluginVersions(t *testing.T) {
72+
t.Setenv("ASDF_CONFIG_FILE", "testdata/asdfrc")
73+
version := "1.1.0"
74+
version2 := "2.0.0"
75+
conf, plugin := generateConfig(t)
76+
installVersion(t, conf, plugin, version)
77+
installVersion(t, conf, plugin, version2)
78+
executables, err := ToolExecutables(conf, plugin, version)
79+
assert.Nil(t, err)
80+
stdout, stderr := buildOutputs()
81+
82+
t.Run("generates shim script for every executable in every version the tool", func(t *testing.T) {
83+
assert.Nil(t, GenerateForPluginVersions(conf, plugin, &stdout, &stderr))
84+
85+
// check for generated shims
86+
for _, executable := range executables {
87+
shimName := filepath.Base(executable)
88+
shimPath := Path(conf, shimName)
89+
assert.Nil(t, unix.Access(shimPath, unix.X_OK))
90+
91+
// shim exists and has expected contents
92+
content, err := os.ReadFile(shimPath)
93+
assert.Nil(t, err)
94+
95+
want := fmt.Sprintf("#!/usr/bin/env bash\n# asdf-plugin: lua 2.0.0\n# asdf-plugin: lua 1.1.0\nexec asdf exec \"%s\" \"$@\"", shimName)
96+
assert.Equal(t, want, string(content))
97+
}
98+
})
99+
100+
t.Run("runs pre and post reshim hooks", func(t *testing.T) {
101+
stdout, stderr := buildOutputs()
102+
assert.Nil(t, GenerateForPluginVersions(conf, plugin, &stdout, &stderr))
103+
104+
want := "pre_reshim 1.1.0\npost_reshim 1.1.0\npre_reshim 2.0.0\npost_reshim 2.0.0\n"
105+
assert.Equal(t, want, stdout.String())
106+
})
107+
}
108+
20109
func TestGenerateForVersion(t *testing.T) {
21110
version := "1.1.0"
22111
version2 := "2.0.0"
@@ -125,7 +214,7 @@ func TestToolExecutables(t *testing.T) {
125214
filenames = append(filenames, filepath.Base(executablePath))
126215
}
127216

128-
assert.Equal(t, filenames, []string{"dummy", "other_bin"})
217+
assert.Equal(t, filenames, []string{"dummy"})
129218
})
130219
}
131220

@@ -165,10 +254,14 @@ func generateConfig(t *testing.T) (config.Config, plugins.Plugin) {
165254
assert.Nil(t, err)
166255
conf.DataDir = testDataDir
167256

168-
_, err = repotest.InstallPlugin("dummy_plugin", testDataDir, testPluginName)
257+
return conf, installPlugin(t, conf, "dummy_plugin", testPluginName)
258+
}
259+
260+
func installPlugin(t *testing.T, conf config.Config, fixture, pluginName string) plugins.Plugin {
261+
_, err := repotest.InstallPlugin(fixture, conf.DataDir, pluginName)
169262
assert.Nil(t, err)
170263

171-
return conf, plugins.New(conf, testPluginName)
264+
return plugins.New(conf, testPluginName)
172265
}
173266

174267
func installVersion(t *testing.T, conf config.Config, plugin plugins.Plugin, version string) {

0 commit comments

Comments
 (0)