Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions ioutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,69 @@ import (
"fmt"
"io"
"os"
"path"
)

type TmpFile struct {
file *os.File
path string
}

func NewTmpFile() (*TmpFile, error) {
file, err := os.CreateTemp(os.TempDir(), "k8s-secret-editor-")
func NewTmpFile(suffix string) (*TmpFile, error) {
name := fmt.Sprintf("k8s-secret-editor-%s", suffix)
p := path.Join(os.TempDir(), name)
f, err := os.Create(p)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating temp file: %w", err)
}
if err := f.Close(); err != nil {
_ = os.Remove(f.Name())
return nil, fmt.Errorf("error closing temp file: %w", err)
}

// Set restrictive permissions to protect sensitive secret data
if err := os.Chmod(file.Name(), 0600); err != nil {
_ = file.Close()
_ = os.Remove(file.Name())
if err := os.Chmod(p, 0o600); err != nil {
_ = os.Remove(p)
return nil, fmt.Errorf("error setting file permissions: %w", err)
}

return &TmpFile{file: file}, nil
return &TmpFile{path: p}, nil
}

func (t *TmpFile) Write(data []byte) error {
if _, err := t.file.Write(data); err != nil {
f, err := os.OpenFile(t.path, os.O_WRONLY|os.O_TRUNC, 0o600)
if err != nil {
return fmt.Errorf("error opening temp file for writing: %w", err)
}
if _, err := f.Write(data); err != nil {
return fmt.Errorf("error writing to temp file: %w", err)
}
if err := f.Sync(); err != nil {
return fmt.Errorf("error syncing temp file: %w", err)
}
return nil
}

func (t *TmpFile) Read() ([]byte, error) {
// Move the file pointer back to the beginning before reading
if _, err := t.file.Seek(0, io.SeekStart); err != nil {
return nil, fmt.Errorf("error seeking to beginning of temp file: %w", err)
f, err := os.Open(t.path)
if err != nil {
return nil, fmt.Errorf("error opening temp file for reading: %w", err)
}

data, err := io.ReadAll(t.file)
data, err := io.ReadAll(f)
if err != nil {
return nil, fmt.Errorf("error reading from temp file: %w", err)
}
return data, nil
}

func (t *TmpFile) OpenEditor(editor interface{ Open(filePath string) error }) error {
return editor.Open(t.file.Name())
if err := editor.Open(t.path); err != nil {
return fmt.Errorf("error opening editor: %w", err)
}
return nil
}

func (t *TmpFile) Close() error {
name := t.file.Name()
if err := t.file.Close(); err != nil {
return err
}
// Remove temp file after closing to avoid leaking secrets
if err := os.Remove(name); err != nil {
if err := os.Remove(t.path); err != nil {
return fmt.Errorf("error removing temp file: %w", err)
}
return nil
Expand Down
86 changes: 33 additions & 53 deletions ioutils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,10 @@ import (
)

func TestNewTmpFile(t *testing.T) {
tmp, err := NewTmpFile()
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
defer tmp.file.Close()
defer os.Remove(tmp.file.Name())
tmp := newTestTmpFile(t)

// Verify file exists
fi, err := os.Stat(tmp.file.Name())
fi, err := os.Stat(tmp.path)
if err != nil {
t.Fatalf("failed to stat temp file: %v", err)
}
Expand All @@ -26,14 +21,9 @@ func TestNewTmpFile(t *testing.T) {
}

func TestTmpFilePermissions(t *testing.T) {
tmp, err := NewTmpFile()
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
defer tmp.file.Close()
defer os.Remove(tmp.file.Name())
tmp := newTestTmpFile(t)

fi, err := os.Stat(tmp.file.Name())
fi, err := os.Stat(tmp.path)
if err != nil {
t.Fatalf("failed to stat temp file: %v", err)
}
Expand All @@ -47,28 +37,20 @@ func TestTmpFilePermissions(t *testing.T) {
}

func TestTmpFileWrite(t *testing.T) {
tmp, err := NewTmpFile()
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
defer tmp.Close()
tmp := newTestTmpFile(t)

testData := []byte("test data for secret")
err = tmp.Write(testData)
err := tmp.Write(testData)
if err != nil {
t.Fatalf("failed to write to temp file: %v", err)
}
}

func TestTmpFileRead(t *testing.T) {
tmp, err := NewTmpFile()
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
defer tmp.Close()
tmp := newTestTmpFile(t)

testData := []byte("test secret data")
err = tmp.Write(testData)
err := tmp.Write(testData)
if err != nil {
t.Fatalf("failed to write to temp file: %v", err)
}
Expand All @@ -85,16 +67,12 @@ func TestTmpFileRead(t *testing.T) {
}

func TestTmpFileReadAfterMultipleWrites(t *testing.T) {
tmp, err := NewTmpFile()
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
defer tmp.Close()
tmp := newTestTmpFile(t)

// Write multiple times
data1 := []byte("first")
data2 := []byte("second")
err = tmp.Write(data1)
err := tmp.Write(data1)
if err != nil {
t.Fatalf("failed to write first data: %v", err)
}
Expand All @@ -104,28 +82,22 @@ func TestTmpFileReadAfterMultipleWrites(t *testing.T) {
t.Fatalf("failed to write second data: %v", err)
}

// Read should return both
readData, err := tmp.Read()
if err != nil {
t.Fatalf("failed to read from temp file: %v", err)
}

expected := "firstsecond"
expected := "second"
if string(readData) != expected {
t.Errorf("expected data %s, got %s", expected, string(readData))
}
}

func TestTmpFileClose(t *testing.T) {
tmp, err := NewTmpFile()
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}

filePath := tmp.file.Name()
tmp := newTestTmpFile(t)

// Write some data
err = tmp.Write([]byte("test"))
err := tmp.Write([]byte("test"))
if err != nil {
t.Fatalf("failed to write to temp file: %v", err)
}
Expand All @@ -137,20 +109,17 @@ func TestTmpFileClose(t *testing.T) {
}

// Verify file is deleted
_, err = os.Stat(filePath)
_, err = os.Stat(tmp.path)
if !os.IsNotExist(err) {
t.Errorf("temp file was not deleted: %s", filePath)
t.Errorf("temp file was not deleted: %s", tmp.path)
}
}

func TestTmpFileCloseMultipleTimes(t *testing.T) {
tmp, err := NewTmpFile()
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
tmp := newTestTmpFile(t)

// First close
err = tmp.Close()
err := tmp.Close()
if err != nil {
t.Fatalf("first close failed: %v", err)
}
Expand All @@ -163,11 +132,7 @@ func TestTmpFileCloseMultipleTimes(t *testing.T) {
}

func TestTmpFileOpenEditor(t *testing.T) {
tmp, err := NewTmpFile()
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
defer tmp.Close()
tmp := newTestTmpFile(t)

editorPath := "/bin/sh"
if _, err := os.Stat(editorPath); os.IsNotExist(err) {
Expand All @@ -185,3 +150,18 @@ func TestTmpFileOpenEditor(t *testing.T) {
t.Logf("OpenEditor returned error: %v (expected for sh with no input)", err)
}
}

func newTestTmpFile(t *testing.T) *TmpFile {
t.Helper()

tmp, err := NewTmpFile("test")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
t.Cleanup(func() {
if err := tmp.Close(); err != nil {
t.Logf("failed to clean up temp file: %v", err)
}
})
return tmp
}
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func main() { //nolint:gocyclo
fatalf("Key '%s' not found in secret '%s' in namespace '%s'", selectedKey, selectedSecret, selectedNamespace)
}

tmpFile, err := NewTmpFile()
tmpFile, err := NewTmpFile(selectedKey)
if err != nil {
fatalf("Error creating temp file: %v", err)
}
Expand Down
Loading