diff --git a/internal/pkg/agent/application/upgrade/step_unpack.go b/internal/pkg/agent/application/upgrade/step_unpack.go index 219c2cb154a..a422dea4912 100644 --- a/internal/pkg/agent/application/upgrade/step_unpack.go +++ b/internal/pkg/agent/application/upgrade/step_unpack.go @@ -34,7 +34,6 @@ type UnpackResult struct { } type copyFunc func(dst io.Writer, src io.Reader) (written int64, err error) -type mkdirAllFunc func(name string, perm fs.FileMode) error type openFileFunc func(name string, flag int, perm fs.FileMode) (*os.File, error) type unarchiveFunc func(log *logger.Logger, archivePath, dataDir string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error) diff --git a/internal/pkg/agent/application/upgrade/upgrade.go b/internal/pkg/agent/application/upgrade/upgrade.go index bc1a8061f18..f61131a3602 100644 --- a/internal/pkg/agent/application/upgrade/upgrade.go +++ b/internal/pkg/agent/application/upgrade/upgrade.go @@ -67,6 +67,16 @@ type unpackHandler interface { getPackageMetadata(archivePath string) (packageMetadata, error) } +// Types used to abstract copyActionStore, copyRunDirectory and github.com/otiai10/copy.Copy +type copyActionStoreFunc func(log *logger.Logger, newHome string) error +type copyRunDirectoryFunc func(log *logger.Logger, oldRunPath, newRunPath string) error +type fileDirCopyFunc func(from, to string, opts ...copy.Options) error + +// Types used to abstract stdlib functions +type mkdirAllFunc func(name string, perm fs.FileMode) error +type readFileFunc func(name string) ([]byte, error) +type writeFileFunc func(name string, data []byte, perm fs.FileMode) error + // Upgrader performs an upgrade type Upgrader struct { log *logger.Logger @@ -80,7 +90,8 @@ type Upgrader struct { artifactDownloader artifactDownloadHandler unpacker unpackHandler isDiskSpaceErrorFunc func(err error) bool - extractAgentVersion func(metadata packageMetadata, upgradeVersion string) agentVersion + copyActionStore copyActionStoreFunc + copyRunDirectory copyRunDirectoryFunc } // IsUpgradeable when agent is installed and running as a service or flag was provided. @@ -101,6 +112,8 @@ func NewUpgrader(log *logger.Logger, settings *artifact.Config, agentInfo info.A artifactDownloader: newArtifactDownloader(settings, log), unpacker: newUnpacker(log), isDiskSpaceErrorFunc: upgradeErrors.IsDiskSpaceError, + copyActionStore: copyActionStoreProvider(os.ReadFile, os.WriteFile), + copyRunDirectory: copyRunDirectoryProvider(os.MkdirAll, copy.Copy), }, nil } @@ -278,15 +291,15 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, sourceURI string newHome := filepath.Join(paths.Top(), unpackRes.VersionedHome) - if err := copyActionStore(u.log, newHome); err != nil { - return nil, errors.New(err, "failed to copy action store") + if err := u.copyActionStore(u.log, newHome); err != nil { + return nil, fmt.Errorf("failed to copy action store: %w", err) } newRunPath := filepath.Join(newHome, "run") oldRunPath := filepath.Join(paths.Home(), "run") - if err := copyRunDirectory(u.log, oldRunPath, newRunPath); err != nil { - return nil, errors.New(err, "failed to copy run directory") + if err := u.copyRunDirectory(u.log, oldRunPath, newRunPath); err != nil { + return nil, fmt.Errorf("failed to copy run directory: %w", err) } det.SetState(details.StateReplacing) @@ -503,50 +516,55 @@ func rollbackInstall(ctx context.Context, log *logger.Logger, topDirPath, versio return nil } -func copyActionStore(log *logger.Logger, newHome string) error { - // copies legacy action_store.yml, state.yml and state.enc encrypted file if exists - storePaths := []string{paths.AgentActionStoreFile(), paths.AgentStateStoreYmlFile(), paths.AgentStateStoreFile()} - log.Infow("Copying action store", "new_home_path", newHome) +func copyActionStoreProvider(readFile readFileFunc, writeFile writeFileFunc) copyActionStoreFunc { + return func(log *logger.Logger, newHome string) error { + // copies legacy action_store.yml, state.yml and state.enc encrypted file if exists + storePaths := []string{paths.AgentActionStoreFile(), paths.AgentStateStoreYmlFile(), paths.AgentStateStoreFile()} + log.Infow("Copying action store", "new_home_path", newHome) + + for _, currentActionStorePath := range storePaths { + newActionStorePath := filepath.Join(newHome, filepath.Base(currentActionStorePath)) + log.Infow("Copying action store path", "from", currentActionStorePath, "to", newActionStorePath) + // using readfile instead of os.ReadFile for testability + currentActionStore, err := readFile(currentActionStorePath) + if os.IsNotExist(err) { + // nothing to copy + continue + } + if err != nil { + return err + } - for _, currentActionStorePath := range storePaths { - newActionStorePath := filepath.Join(newHome, filepath.Base(currentActionStorePath)) - log.Infow("Copying action store path", "from", currentActionStorePath, "to", newActionStorePath) - currentActionStore, err := os.ReadFile(currentActionStorePath) - if os.IsNotExist(err) { - // nothing to copy - continue - } - if err != nil { - return err + // using writeFile instead of os.WriteFile for testability + if err := writeFile(newActionStorePath, currentActionStore, 0o600); err != nil { + return fmt.Errorf("failed to write action store at %q: %w", newActionStorePath, err) + } } - if err := os.WriteFile(newActionStorePath, currentActionStore, 0o600); err != nil { - return err - } + return nil } - - return nil } -func copyRunDirectory(log *logger.Logger, oldRunPath, newRunPath string) error { +func copyRunDirectoryProvider(mkdirAll mkdirAllFunc, fileDirCopy fileDirCopyFunc) copyRunDirectoryFunc { + return func(log *logger.Logger, oldRunPath, newRunPath string) error { + log.Infow("Copying run directory", "new_run_path", newRunPath, "old_run_path", oldRunPath) - log.Infow("Copying run directory", "new_run_path", newRunPath, "old_run_path", oldRunPath) + if err := mkdirAll(newRunPath, runDirMod); err != nil { + return fmt.Errorf("failed to create run directory: %w", err) + } - if err := os.MkdirAll(newRunPath, runDirMod); err != nil { - return errors.New(err, "failed to create run directory") - } + err := copyDir(log, oldRunPath, newRunPath, true, fileDirCopy) + if os.IsNotExist(err) { + // nothing to copy, operation ok + log.Infow("Run directory not present", "old_run_path", oldRunPath) + return nil + } + if err != nil { + return fmt.Errorf("failed to copy %q to %q: %w", oldRunPath, newRunPath, err) + } - err := copyDir(log, oldRunPath, newRunPath, true) - if os.IsNotExist(err) { - // nothing to copy, operation ok - log.Infow("Run directory not present", "old_run_path", oldRunPath) return nil } - if err != nil { - return errors.New(err, "failed to copy %q to %q", oldRunPath, newRunPath) - } - - return nil } // shutdownCallback returns a callback function to be executing during shutdown once all processes are closed. @@ -574,7 +592,7 @@ func shutdownCallback(l *logger.Logger, homePath, prevVersion, newVersion, newHo newRelPath = strings.ReplaceAll(newRelPath, oldHome, newHome) newDir := filepath.Join(newHome, newRelPath) l.Debugf("copying %q -> %q", processDir, newDir) - if err := copyDir(l, processDir, newDir, true); err != nil { + if err := copyDir(l, processDir, newDir, true, copy.Copy); err != nil { return err } } @@ -620,7 +638,7 @@ func readDirs(dir string) ([]string, error) { return dirs, nil } -func copyDir(l *logger.Logger, from, to string, ignoreErrs bool) error { +func copyDir(l *logger.Logger, from, to string, ignoreErrs bool, fileDirCopy fileDirCopyFunc) error { var onErr func(src, dst string, err error) error if ignoreErrs { @@ -646,7 +664,7 @@ func copyDir(l *logger.Logger, from, to string, ignoreErrs bool) error { copyConcurrency = runtime.NumCPU() * 4 } - return copy.Copy(from, to, copy.Options{ + return fileDirCopy(from, to, copy.Options{ OnSymlink: func(_ string) copy.SymlinkAction { return copy.Shallow }, diff --git a/internal/pkg/agent/application/upgrade/upgrade_test.go b/internal/pkg/agent/application/upgrade/upgrade_test.go index 2d77524edb2..52cbe38d972 100644 --- a/internal/pkg/agent/application/upgrade/upgrade_test.go +++ b/internal/pkg/agent/application/upgrade/upgrade_test.go @@ -9,6 +9,7 @@ import ( "crypto/tls" "fmt" "io" + "io/fs" "net/http" "net/url" "os" @@ -19,6 +20,7 @@ import ( "time" "github.com/gofrs/flock" + "github.com/otiai10/copy" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -117,7 +119,7 @@ func Test_CopyFile(t *testing.T) { } - err := copyDir(l, tc.From, tc.To, tc.IgnoreErr) + err := copyDir(l, tc.From, tc.To, tc.IgnoreErr, copy.Copy) require.Equal(t, tc.ExpectedErr, err != nil, err) }) } @@ -1200,13 +1202,6 @@ func TestUpgradeErrorHandling(t *testing.T) { expectedError: testError, upgraderMocker: func(upgrader *Upgrader) { upgrader.artifactDownloader = &mockArtifactDownloader{} - upgrader.extractAgentVersion = func(metadata packageMetadata, upgradeVersion string) agentVersion { - return agentVersion{ - version: upgradeVersion, - snapshot: false, - hash: metadata.hash, - } - } upgrader.unpacker = &mockUnpacker{ returnPackageMetadata: packageMetadata{ manifest: &v1.PackageManifest{}, @@ -1216,6 +1211,50 @@ func TestUpgradeErrorHandling(t *testing.T) { } }, }, + "should return error if copyActionStore fails": { + isDiskSpaceErrorResult: false, + expectedError: testError, + upgraderMocker: func(upgrader *Upgrader) { + upgrader.artifactDownloader = &mockArtifactDownloader{} + upgrader.unpacker = &mockUnpacker{ + returnPackageMetadata: packageMetadata{ + manifest: &v1.PackageManifest{}, + hash: "hash", + }, + returnUnpackResult: UnpackResult{ + Hash: "hash", + VersionedHome: "versionedHome", + }, + } + upgrader.copyActionStore = func(log *logger.Logger, newHome string) error { + return testError + } + }, + }, + "should return error if copyRunDirectory fails": { + isDiskSpaceErrorResult: false, + expectedError: testError, + upgraderMocker: func(upgrader *Upgrader) { + upgrader.artifactDownloader = &mockArtifactDownloader{} + upgrader.artifactDownloader = &mockArtifactDownloader{} + upgrader.unpacker = &mockUnpacker{ + returnPackageMetadata: packageMetadata{ + manifest: &v1.PackageManifest{}, + hash: "hash", + }, + returnUnpackResult: UnpackResult{ + Hash: "hash", + VersionedHome: "versionedHome", + }, + } + upgrader.copyActionStore = func(log *logger.Logger, newHome string) error { + return nil + } + upgrader.copyRunDirectory = func(log *logger.Logger, oldRunPath, newRunPath string) error { + return testError + } + }, + }, "should add disk space error to the error chain if downloadArtifact fails with disk space error": { isDiskSpaceErrorResult: true, expectedError: upgradeErrors.ErrInsufficientDiskSpace, @@ -1266,3 +1305,188 @@ func TestSetClient(t *testing.T) { upgrader.SetClient(&mockSender{}) require.Equal(t, "mockURI", upgrader.artifactDownloader.(*mockArtifactDownloader).fleetServerURI) } + +func TestCopyActionStore(t *testing.T) { + log, _ := loggertest.New("TestCopyActionStore") + + actionStoreContent := "initial agent action_store.yml content" + actionStateStoreYamlContent := "initial agent state.yml content" + actionStateStoreFileContent := "initial agent state.enc content" + + type testFile struct { + name string + content string + } + + type testCase struct { + files []testFile + copyActionStore copyActionStoreFunc + expectedError error + } + + testError := errors.New("test error") + + testCases := map[string]testCase{ + "should copy all action store files": { + files: []testFile{ + {name: "action_store", content: actionStoreContent}, + {name: "state_yaml", content: actionStateStoreYamlContent}, + }, + copyActionStore: copyActionStoreProvider(os.ReadFile, os.WriteFile), + expectedError: nil, + }, + "should skip copying action store file that does not exist": { + files: []testFile{ + {name: "action_store", content: actionStoreContent}, + {name: "state_yaml", content: actionStateStoreYamlContent}, + }, + copyActionStore: copyActionStoreProvider(os.ReadFile, os.WriteFile), + expectedError: nil, + }, + "should return error if it cannot read the action store files": { + files: []testFile{ + {name: "action_store", content: actionStoreContent}, + {name: "state_yaml", content: actionStateStoreYamlContent}, + {name: "state_enc", content: actionStateStoreFileContent}, + }, + copyActionStore: copyActionStoreProvider(func(name string) ([]byte, error) { + return nil, testError + }, os.WriteFile), + expectedError: testError, + }, + "should return error if it cannot write the action store files": { + files: []testFile{ + {name: "action_store", content: actionStoreContent}, + {name: "state_yaml", content: actionStateStoreYamlContent}, + {name: "state_enc", content: actionStateStoreFileContent}, + }, + copyActionStore: copyActionStoreProvider(os.ReadFile, func(name string, data []byte, perm os.FileMode) error { + return testError + }), + expectedError: testError, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + baseDir := t.TempDir() + newHome := filepath.Join(baseDir, "new_home") + paths.SetTop(baseDir) + + actionStorePath := paths.AgentActionStoreFile() + actionStateStoreYamlPath := paths.AgentStateStoreYmlFile() + actionStateStoreFilePath := paths.AgentStateStoreFile() + + newActionStorePaths := []string{} + + for _, file := range testCase.files { + path := "" + + switch file.name { + case "action_store": + path = actionStorePath + case "state_yaml": + path = actionStateStoreYamlPath + case "state_enc": + path = actionStateStoreFilePath + } + + // Create the action store directories and files + dir := filepath.Dir(path) + err := os.MkdirAll(dir, 0o755) + require.NoError(t, err, "error creating directory %s", dir) + + err = os.WriteFile(path, []byte(file.content), 0o600) + require.NoError(t, err, "error writing to %s", path) + + // Create the new action store directories + newActionStorePath := filepath.Join(newHome, filepath.Base(path)) + newActionStorePaths = append(newActionStorePaths, newActionStorePath) + err = os.MkdirAll(filepath.Dir(newActionStorePath), 0o755) + require.NoError(t, err, "error creating directory %s", filepath.Dir(newActionStorePath)) + } + + err := testCase.copyActionStore(log, newHome) + if testCase.expectedError != nil { + require.Error(t, err, "copyActionStoreFunc should return error") + require.ErrorIs(t, err, testCase.expectedError, "copyActionStoreFunc error mismatch") + return + } + + require.NoError(t, err, "error copying action store") + + for i, path := range newActionStorePaths { + require.FileExists(t, path, "file %s does not exist", path) + + content, err := os.ReadFile(path) + require.NoError(t, err, "error reading from %s", path) + require.Equal(t, []byte(testCase.files[i].content), content, "content of %s is not as expected", path) + } + }) + } +} + +func TestCopyRunDirectory(t *testing.T) { + log, _ := loggertest.New("TestCopyRunDirectory") + + type testCase struct { + expectedError error + copyRunDirectory copyRunDirectoryFunc + } + + testCases := map[string]testCase{ + "should copy old run directory to new run directory": { + expectedError: nil, + copyRunDirectory: copyRunDirectoryProvider(os.MkdirAll, copy.Copy), + }, + "should return error if it cannot create the new run directory": { + expectedError: fs.ErrPermission, + copyRunDirectory: copyRunDirectoryProvider(func(path string, perm os.FileMode) error { + return fs.ErrPermission + }, copy.Copy), + }, + "should return error if it cannot copy the old run directory": { + expectedError: errors.New("test error"), + copyRunDirectory: copyRunDirectoryProvider(os.MkdirAll, func(src, dest string, opts ...copy.Options) error { + return errors.New("test error") + }), + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + baseDir := t.TempDir() + paths.SetTop(baseDir) + + oldRunPath := filepath.Join(baseDir, "old_dir", "run") + oldRunFile := filepath.Join(oldRunPath, "file.txt") + + err := os.MkdirAll(oldRunPath, 0o700) + require.NoError(t, err, "error creating old run directory") + + err = os.WriteFile(oldRunFile, []byte("content for old run file"), 0o600) + require.NoError(t, err, "error writing to %s", oldRunFile) + + newRunPath := filepath.Join(baseDir, "new_dir", "run") + + err = os.MkdirAll(newRunPath, 0o700) + require.NoError(t, err, "error creating new run directory") + + err = testCase.copyRunDirectory(log, oldRunPath, newRunPath) + if testCase.expectedError != nil { + require.Error(t, err, "copyRunDirectoryFunc should return error") + require.ErrorIs(t, err, testCase.expectedError, "copyRunDirectoryFunc should return test error") + return + } + + require.NoError(t, err, "error copying run directory") + require.DirExists(t, newRunPath, "new run directory does not exist") + + require.FileExists(t, filepath.Join(newRunPath, "file.txt"), "file.txt does not exist in new run directory") + + content, err := os.ReadFile(filepath.Join(newRunPath, "file.txt")) + require.NoError(t, err, "error reading from %s", filepath.Join(newRunPath, "file.txt")) + require.Equal(t, []byte("content for old run file"), content, "content of %s is not as expected", filepath.Join(newRunPath, "file.txt")) + }) + } +}