diff --git a/shared/keystore/keystore.go b/shared/keystore/keystore.go index 637c938b2..6849074d3 100644 --- a/shared/keystore/keystore.go +++ b/shared/keystore/keystore.go @@ -28,6 +28,7 @@ import ( "fmt" "io" "io/ioutil" + "os" "path/filepath" "strings" @@ -96,6 +97,15 @@ func (ks Store) GetKeys(directory, fileprefix, password string, warnOnFail bool) n := f.Name() filePath := filepath.Join(directory, n) filePath = filepath.Clean(filePath) + if f.Mode()&os.ModeSymlink == os.ModeSymlink { + if targetFilePath, err := filepath.EvalSymlinks(filePath); err == nil { + filePath = targetFilePath + // Override link stats with target file's stats. + if f, err = os.Stat(filePath); err != nil { + return nil, err + } + } + } cp := strings.Contains(n, strings.TrimPrefix(fileprefix, "/")) if f.Mode().IsRegular() && cp { // #nosec G304 diff --git a/shared/keystore/keystore_test.go b/shared/keystore/keystore_test.go index eac8ca0d2..b9ab9efe8 100644 --- a/shared/keystore/keystore_test.go +++ b/shared/keystore/keystore_test.go @@ -122,3 +122,42 @@ func TestEncryptDecryptKey(t *testing.T) { } } + +func TestGetSymlinkedKeys(t *testing.T) { + tmpdir := testutil.TempDir() + "/symlinked-keystore" + defer func() { + if err := os.RemoveAll(tmpdir); err != nil { + t.Logf("unable to remove temporary files: %v", err) + } + }() + ks := &Store{ + scryptN: LightScryptN, + scryptP: LightScryptP, + } + + key, err := NewKey() + if err != nil { + t.Fatalf("key generation failed %v", err) + } + + if err := ks.StoreKey(tmpdir+"/files/test-1", key, "password"); err != nil { + t.Fatalf("unable to store key %v", err) + } + + if err := os.Symlink(tmpdir+"/files/test-1", tmpdir+"/test-1"); err != nil { + t.Fatalf("unable to create symlink: %v", err) + } + + newkeys, err := ks.GetKeys(tmpdir, "test", "password", false) + if err != nil { + t.Fatalf("unable to get key %v", err) + } + if len(newkeys) != 1 { + t.Errorf("unexpected number of keys returned, want: %d, got: %d", 1, len(newkeys)) + } + for _, s := range newkeys { + if !bytes.Equal(s.SecretKey.Marshal(), key.SecretKey.Marshal()) { + t.Fatalf("retrieved secret keys are not equal %v ", s.SecretKey.Marshal()) + } + } +}