Skip to content

Add sync.Mutex and os.Rename to prevent corrupted file when downloading the Postgres archive #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
May 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
118f8fb
Add sync.Mutex to prevent collisions
alecsammon Mar 7, 2023
b67856e
* add atomic download, and use defer to ensure mutex unlock
alecsammon Mar 8, 2023
1f9492f
* move mutex to global
alecsammon Mar 8, 2023
53177fc
* fix tests
alecsammon Mar 8, 2023
7918aaa
* update platform-test
alecsammon Mar 8, 2023
7eaa871
* update examples
alecsammon Mar 8, 2023
f81ff71
* remove atomic dependency
alecsammon Mar 11, 2023
3ea19a9
* remove code duplication
alecsammon Mar 11, 2023
ea01184
* reduce test parallel run count
alecsammon Mar 11, 2023
61f98c3
* fix race condition in tests
alecsammon Mar 11, 2023
d7dc5f4
* run tests
alecsammon Mar 11, 2023
fe4e475
* attempt to fix windows
alecsammon Mar 11, 2023
ad1a924
* attempt a different solution for windows
alecsammon Mar 11, 2023
2f1edc7
* revert changes
alecsammon Mar 11, 2023
4b23ae8
* try additional fix for windows
alecsammon Mar 16, 2023
a0d4846
* add another test
alecsammon Mar 16, 2023
f335964
* catch syscall.EEXIST
alecsammon Mar 16, 2023
3c7ed67
* fix test
alecsammon Mar 16, 2023
d60bb7f
* fix test
alecsammon Mar 22, 2023
3041195
* add extra debugging
alecsammon Mar 22, 2023
88f352e
* attempt to fix windows
alecsammon Mar 22, 2023
7425087
* add additional error message
alecsammon Mar 22, 2023
e1f9687
* fix race in decompression
alecsammon Mar 22, 2023
8e0d6b4
* more fixes
alecsammon Mar 22, 2023
a0f5116
* use atomic
alecsammon Mar 22, 2023
18c6245
* add extra debug
alecsammon Mar 22, 2023
bffd2d1
* try catching the error
alecsammon Mar 22, 2023
87300c8
* try different permissions
alecsammon Mar 22, 2023
d7964cd
* add more debugging
alecsammon Mar 22, 2023
d74f4a0
* more debug
alecsammon Mar 22, 2023
e68a942
* more debug
alecsammon Mar 22, 2023
f275198
* test dest
alecsammon Mar 22, 2023
c563225
* attempt to close temp file
alecsammon Mar 22, 2023
47b796e
* simplify
alecsammon Mar 22, 2023
76114a7
* remove atomic
alecsammon Mar 22, 2023
62710a4
* clean up code
alecsammon Mar 22, 2023
a3e6ebd
* add more tests
alecsammon Mar 22, 2023
c767c32
* clean up temporary files
alecsammon Mar 23, 2023
30c954d
* prevent file being opened twice
alecsammon Mar 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions decompression.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,19 @@ func defaultTarReader(xzReader *xz.Reader) (func() (*tar.Header, error), func()
}

func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), func() io.Reader), path, extractPath string) error {
tempExtractPath, err := os.MkdirTemp(filepath.Dir(extractPath), "temp_")
if err != nil {
return errorUnableToExtract(path, extractPath, err)
}
defer func() {
if err := os.RemoveAll(tempExtractPath); err != nil {
panic(err)
}
}()

tarFile, err := os.Open(path)
if err != nil {
return errorUnableToExtract(path, extractPath)
return errorUnableToExtract(path, extractPath, err)
}

defer func() {
Expand All @@ -34,7 +44,7 @@ func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), fu

xzReader, err := xz.NewReader(tarFile, 0)
if err != nil {
return errorUnableToExtract(path, extractPath)
return errorUnableToExtract(path, extractPath, err)
}

readNext, reader := tarReader(xzReader)
Expand All @@ -43,16 +53,21 @@ func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), fu
header, err := readNext()

if err == io.EOF {
return nil
break
}

if err != nil {
return errorExtractingPostgres(err)
}

targetPath := filepath.Join(extractPath, header.Name)
targetPath := filepath.Join(tempExtractPath, header.Name)
finalPath := filepath.Join(extractPath, header.Name)

if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
if err := os.MkdirAll(filepath.Dir(targetPath), os.ModePerm); err != nil {
return errorExtractingPostgres(err)
}

if err := os.MkdirAll(filepath.Dir(finalPath), os.ModePerm); err != nil {
return errorExtractingPostgres(err)
}

Expand All @@ -78,10 +93,26 @@ func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), fu
if err := os.Symlink(header.Linkname, targetPath); err != nil {
return errorExtractingPostgres(err)
}

case tar.TypeDir:
if err := os.MkdirAll(finalPath, os.FileMode(header.Mode)); err != nil {
return errorExtractingPostgres(err)
}
continue
}

if err := renameOrIgnore(targetPath, finalPath); err != nil {
return errorExtractingPostgres(err)
}
}

return nil
}

func errorUnableToExtract(cacheLocation, binariesPath string) error {
return fmt.Errorf("unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories", cacheLocation, binariesPath)
func errorUnableToExtract(cacheLocation, binariesPath string, err error) error {
return fmt.Errorf("unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories, %w",
cacheLocation,
binariesPath,
err,
)
}
43 changes: 42 additions & 1 deletion decompression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ package embeddedpostgres
import (
"archive/tar"
"errors"
"fmt"
"io"
"os"
"path"
"path/filepath"
"syscall"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xi2/xz"
)

Expand All @@ -17,6 +21,9 @@ func Test_decompressTarXz(t *testing.T) {
if err != nil {
panic(err)
}
if err := syscall.Rmdir(tempDir); err != nil {
panic(err)
}

archive, cleanUp := createTempXzArchive()
defer cleanUp()
Expand All @@ -37,14 +44,22 @@ func Test_decompressTarXz(t *testing.T) {
func Test_decompressTarXz_ErrorWhenFileNotExists(t *testing.T) {
err := decompressTarXz(defaultTarReader, "/does-not-exist", "/also-fake")

assert.EqualError(t, err, "unable to extract postgres archive /does-not-exist to /also-fake, if running parallel tests, configure RuntimePath to isolate testing directories")
assert.Error(t, err)
assert.Contains(
t,
err.Error(),
"unable to extract postgres archive /does-not-exist to /also-fake, if running parallel tests, configure RuntimePath to isolate testing directories",
)
}

func Test_decompressTarXz_ErrorWhenErrorDuringRead(t *testing.T) {
tempDir, err := os.MkdirTemp("", "temp_tar_test")
if err != nil {
panic(err)
}
if err := syscall.Rmdir(tempDir); err != nil {
panic(err)
}

archive, cleanUp := createTempXzArchive()
defer cleanUp()
Expand Down Expand Up @@ -103,6 +118,9 @@ func Test_decompressTarXz_ErrorWhenFileToCopyToNotExists(t *testing.T) {
if err != nil {
panic(err)
}
if err := syscall.Rmdir(tempDir); err != nil {
panic(err)
}

archive, cleanUp := createTempXzArchive()
defer cleanUp()
Expand Down Expand Up @@ -137,6 +155,9 @@ func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) {
if err != nil {
panic(err)
}
if err := syscall.Rmdir(tempDir); err != nil {
panic(err)
}

archive, cleanup := createTempXzArchive()

Expand All @@ -163,3 +184,23 @@ func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) {

assert.EqualError(t, err, "unable to extract postgres archive: xz: data is corrupt")
}

// Test_decompressTarXz_ErrorWithInvalidDestination checks that decompressTarXz
// surfaces the underlying mkdir failure when the extraction path cannot be
// created.
func Test_decompressTarXz_ErrorWithInvalidDestination(t *testing.T) {
	archive, cleanUp := createTempXzArchive()
	defer cleanUp()

	tempDir, err := os.MkdirTemp("", "temp_tar_test")
	require.NoError(t, err)
	defer func() {
		// Best-effort cleanup; a failure here must not mask the assertion below.
		_ = os.RemoveAll(tempDir)
	}()

	// The destination ends in a NUL byte, which the assertion below expects
	// mkdir to reject with "invalid argument". The path is built directly
	// rather than by using the joined path as a Sprintf format string, so a
	// '%' in the temp directory name cannot be misread as a formatting verb.
	op := path.Join(tempDir, "\x00")

	err = decompressTarXz(defaultTarReader, archive, op)
	assert.EqualError(
		t,
		err,
		fmt.Sprintf("unable to extract postgres archive: mkdir %s: invalid argument", op),
	)
}
38 changes: 26 additions & 12 deletions embedded_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ import (
"path/filepath"
"runtime"
"strings"
"sync"
)

var mu sync.Mutex

// EmbeddedPostgres maintains all configuration and runtime functions for maintaining the lifecycle of one Postgres process.
type EmbeddedPostgres struct {
config Config
Expand Down Expand Up @@ -92,20 +95,11 @@ func (ep *EmbeddedPostgres) Start() error {
ep.config.binariesPath = ep.config.runtimePath
}

_, binDirErr := os.Stat(filepath.Join(ep.config.binariesPath, "bin"))
if os.IsNotExist(binDirErr) {
if !cacheExists {
if err := ep.remoteFetchStrategy(); err != nil {
return err
}
}

if err := decompressTarXz(defaultTarReader, cacheLocation, ep.config.binariesPath); err != nil {
return err
}
if err := ep.downloadAndExtractBinary(cacheExists, cacheLocation); err != nil {
return err
}

if err := os.MkdirAll(ep.config.runtimePath, 0755); err != nil {
if err := os.MkdirAll(ep.config.runtimePath, os.ModePerm); err != nil {
return fmt.Errorf("unable to create runtime directory %s with error: %s", ep.config.runtimePath, err)
}

Expand Down Expand Up @@ -148,6 +142,26 @@ func (ep *EmbeddedPostgres) Start() error {
return nil
}

// downloadAndExtractBinary makes sure the Postgres binaries are available
// under ep.config.binariesPath.
//
// When the "bin" directory below binariesPath does not exist, the archive is
// fetched via ep.remoteFetchStrategy (only if cacheExists is false) and then
// extracted from cacheLocation into binariesPath. The whole operation runs
// while holding the package-level mutex mu, so callers within this process
// execute it one at a time.
//
// It returns the first error reported by the fetch or the extraction, or nil.
func (ep *EmbeddedPostgres) downloadAndExtractBinary(cacheExists bool, cacheLocation string) error {
	// lock to prevent collisions with duplicate downloads
	mu.Lock()
	defer mu.Unlock()

	binDir := filepath.Join(ep.config.binariesPath, "bin")

	// Only a "does not exist" result triggers work; an existing directory or
	// any other stat outcome is treated as "nothing to do".
	if _, statErr := os.Stat(binDir); !os.IsNotExist(statErr) {
		return nil
	}

	if !cacheExists {
		if fetchErr := ep.remoteFetchStrategy(); fetchErr != nil {
			return fetchErr
		}
	}

	return decompressTarXz(defaultTarReader, cacheLocation, ep.config.binariesPath)
}

func (ep *EmbeddedPostgres) cleanDataDirectoryAndInit() error {
if err := os.RemoveAll(ep.config.dataPath); err != nil {
return fmt.Errorf("unable to clean up data directory %s with error: %s", ep.config.dataPath, err)
Expand Down
63 changes: 62 additions & 1 deletion embedded_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_DefaultConfig(t *testing.T) {
Expand Down Expand Up @@ -99,7 +100,7 @@ func Test_ErrorWhenUnableToUnArchiveFile_WrongFormat(t *testing.T) {
}
}

assert.EqualError(t, err, fmt.Sprintf(`unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories`, jarFile, filepath.Join(filepath.Dir(jarFile), "extracted")))
assert.EqualError(t, err, fmt.Sprintf(`unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories, xz: file format not recognized`, jarFile, filepath.Join(filepath.Dir(jarFile), "extracted")))
}

func Test_ErrorWhenUnableToInitDatabase(t *testing.T) {
Expand Down Expand Up @@ -355,6 +356,66 @@ func Test_CustomLocaleConfig(t *testing.T) {
}
}

// Test_ConcurrentStart starts three databases at the same time, each with its
// own temporary runtime path and its own port (5433-5435), and checks that
// every instance can be started, pinged and stopped.
func Test_ConcurrentStart(t *testing.T) {
	const (
		firstPort = 5433
		instances = 3
	)

	// Remove whatever sits at the default cache location before starting,
	// presumably so the concurrent starts cannot reuse a cached archive.
	probe := NewDatabase()
	cacheLocation, _ := probe.cacheLocator()
	require.NoError(t, os.RemoveAll(cacheLocation))

	var wg sync.WaitGroup
	wg.Add(instances)

	for offset := 0; offset < instances; offset++ {
		go func(p int) {
			defer wg.Done()

			runtimeDir, err := os.MkdirTemp("", "embedded_postgres_test")
			if err != nil {
				panic(err)
			}

			defer func() {
				if rmErr := os.RemoveAll(runtimeDir); rmErr != nil {
					panic(rmErr)
				}
			}()

			config := DefaultConfig().
				RuntimePath(runtimeDir).
				Port(uint32(p))
			database := NewDatabase(config)

			if startErr := database.Start(); startErr != nil {
				shutdownDBAndFail(t, startErr, database)
			}

			dsn := fmt.Sprintf("host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", p)

			db, err := sql.Open("postgres", dsn)
			if err != nil {
				shutdownDBAndFail(t, err, database)
			}

			if pingErr := db.Ping(); pingErr != nil {
				shutdownDBAndFail(t, pingErr, database)
			}

			if closeErr := db.Close(); closeErr != nil {
				shutdownDBAndFail(t, closeErr, database)
			}

			if stopErr := database.Stop(); stopErr != nil {
				shutdownDBAndFail(t, stopErr, database)
			}
		}(firstPort + offset)
	}

	wg.Wait()
}

func Test_CanStartAndStopTwice(t *testing.T) {
database := NewDatabase()

Expand Down
4 changes: 2 additions & 2 deletions platform-test/platform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ func Test_AllMajorVersions(t *testing.T) {
}

func shutdownDBAndFail(t *testing.T, err error, db *embeddedpostgres.EmbeddedPostgres, version embeddedpostgres.PostgresVersion) {
if err := db.Stop(); err != nil {
t.Fatalf("Failed for version %s with error %s", version, err)
if err2 := db.Stop(); err2 != nil {
t.Fatalf("Failed for version %s with error %s, original error %s", version, err2, err)
}

t.Fatalf("Failed for version %s with error %s", version, err)
Expand Down
Loading