Skip to content

Add a way to connect through UNIX socket. #152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package embeddedpostgres
import (
"fmt"
"io"
"net/url"
"os"
"time"
)
Expand All @@ -11,6 +12,8 @@ import (
type Config struct {
version PostgresVersion
port uint32
useUnixSocket bool
unixSocketDirectory string
database string
username string
password string
Expand Down Expand Up @@ -38,6 +41,8 @@ func DefaultConfig() Config {
return Config{
version: V16,
port: 5432,
useUnixSocket: false,
unixSocketDirectory: "/tmp/",
database: "postgres",
username: "postgres",
password: "postgres",
Expand All @@ -59,6 +64,17 @@ func (c Config) Port(port uint32) Config {
return c
}

// WithoutTcp makes Posgres listen on a UNIX socket instead of opening a TCP port.
func (c Config) WithoutTcp() Config {
c.useUnixSocket = true
return c
}

func (c Config) WithUnixSocketDirectory(dir string) Config {
c.unixSocketDirectory = dir
return c
}

// Database sets the database name that will be created.
func (c Config) Database(database string) Config {
c.database = database
Expand Down Expand Up @@ -145,7 +161,23 @@ func (c Config) BinaryRepositoryURL(binaryRepositoryURL string) Config {
}

func (c Config) GetConnectionURL() string {
return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", c.username, c.password, "localhost", c.port, c.database)
u := &url.URL{
Scheme: "postgresql",
User: url.UserPassword(c.username, c.password),
Path: "/" + c.database,
}

if c.useUnixSocket {
u.Host = fmt.Sprintf(":%d", c.port)

q := url.Values{}
q.Set("host", c.unixSocketDirectory)
u.RawQuery = q.Encode()
} else {
u.Host = fmt.Sprintf("localhost:%d", c.port)
}

return u.String()
}

// PostgresVersion represents the semantic version used to fetch and run the Postgres process.
Expand Down
32 changes: 32 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package embeddedpostgres

import (
"testing"
)

func TestGetConnectionURL(t *testing.T) {
config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass")
expect := "postgresql://myuser:mypass@localhost:5432/mydb"

if got := config.GetConnectionURL(); got != expect {
t.Errorf("expected \"%s\" got \"%s\"", expect, got)
}
}

func TestGetConnectionURLWithUnixSocket(t *testing.T) {
config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass").WithoutTcp()
expect := "postgresql://myuser:mypass@:5432/mydb?host=%2Ftmp%2F"

if got := config.GetConnectionURL(); got != expect {
t.Errorf("expected \"%s\" got \"%s\"", expect, got)
}
}

func TestGetConnectionURLWithUnixSocketInCustomDir(t *testing.T) {
config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass").WithoutTcp().WithUnixSocketDirectory("/path/to/socks")
expect := "postgresql://myuser:mypass@:5432/mydb?host=%2Fpath%2Fto%2Fsocks"

if got := config.GetConnectionURL(); got != expect {
t.Errorf("expected \"%s\" got \"%s\"", expect, got)
}
}
24 changes: 23 additions & 1 deletion embedded_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,12 @@ func (ep *EmbeddedPostgres) Start() error {
ep.started = true

if !reuseData {
if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil {
host := "localhost"
if ep.config.useUnixSocket {
host = ep.config.unixSocketDirectory
}

if err := ep.createDatabase(host, ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil {
if stopErr := stopPostgres(ep); stopErr != nil {
return fmt.Errorf("unable to stop database caused by error %s", err)
}
Expand Down Expand Up @@ -167,6 +172,10 @@ func (ep *EmbeddedPostgres) downloadAndExtractBinary(cacheExists bool, cacheLoca
return nil
}

func (ep *EmbeddedPostgres) GetConnectionURL() string {
return ep.config.GetConnectionURL()
}

func (ep *EmbeddedPostgres) cleanDataDirectoryAndInit() error {
if err := os.RemoveAll(ep.config.dataPath); err != nil {
return fmt.Errorf("unable to clean up data directory %s with error: %s", ep.config.dataPath, err)
Expand Down Expand Up @@ -210,7 +219,20 @@ func encodeOptions(port uint32, parameters map[string]string) string {
}

func startPostgres(ep *EmbeddedPostgres) error {
if ep.config.startParameters == nil {
ep.config.startParameters = make(map[string]string)
}

if ep.config.useUnixSocket {
ep.config.startParameters["listen_addresses"] = ""
ep.config.startParameters["unix_socket_directories"] = ep.config.unixSocketDirectory
}

postgresBinary := filepath.Join(ep.config.binariesPath, "bin/pg_ctl")
fmt.Println(postgresBinary, "start", "-w",
"-D", ep.config.dataPath,
"-o", encodeOptions(ep.config.port, ep.config.startParameters))

postgresProcess := exec.Command(postgresBinary, "start", "-w",
"-D", ep.config.dataPath,
"-o", encodeOptions(ep.config.port, ep.config.startParameters))
Expand Down
35 changes: 33 additions & 2 deletions embedded_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func Test_ErrorWhenUnableToCreateDatabase(t *testing.T) {
RuntimePath(extractPath).
StartTimeout(10 * time.Second))

database.createDatabase = func(port uint32, username, password, database string) error {
database.createDatabase = func(host string, port uint32, username, password, database string) error {
return errors.New("ah noes")
}

Expand All @@ -176,7 +176,7 @@ func Test_TimesOutWhenCannotStart(t *testing.T) {
Database("something-fancy").
StartTimeout(500 * time.Millisecond))

database.createDatabase = func(port uint32, username, password, database string) error {
database.createDatabase = func(host string, port uint32, username, password, database string) error {
return nil
}

Expand Down Expand Up @@ -802,3 +802,34 @@ func Test_RunningInParallel(t *testing.T) {

waitGroup.Wait()
}

func Test_RunOnUnixSocket(t *testing.T) {
database := NewDatabase(DefaultConfig().Port(9876).WithoutTcp())
if err := database.Start(); err != nil {
shutdownDBAndFail(t, err, database)
}

defer database.Stop()

if _, err := os.Stat("/tmp/.s.PGSQL.9876"); err != nil {
shutdownDBAndFail(t, err, database)
}
}

func Test_RunOnUnixSocketOnCustomPath(t *testing.T) {
tempPath, err := os.MkdirTemp("", "custom_dir_socks")
if err != nil {
panic(err)
}

database := NewDatabase(DefaultConfig().Port(9876).WithoutTcp().WithUnixSocketDirectory(tempPath))
if err := database.Start(); err != nil {
shutdownDBAndFail(t, err, database)
}

defer database.Stop()

if _, err := os.Stat(fmt.Sprintf("%s/.s.PGSQL.9876", tempPath)); err != nil {
shutdownDBAndFail(t, err, database)
}
}
53 changes: 47 additions & 6 deletions examples/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing"

Expand All @@ -27,7 +28,7 @@ func Test_GooseMigrations(t *testing.T) {
}
}()

db, err := connect()
db, err := connect(database.GetConnectionURL())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -57,7 +58,7 @@ func Test_ZapioLogger(t *testing.T) {
}
}()

db, err := connect()
db, err := connect(database.GetConnectionURL())
if err != nil {
t.Fatal(err)
}
Expand All @@ -79,7 +80,36 @@ func Test_Sqlx_SelectOne(t *testing.T) {
}
}()

db, err := connect()
db, err := connect(database.GetConnectionURL())
if err != nil {
t.Fatal(err)
}

rows := make([]int32, 0)

err = db.Select(&rows, "SELECT 1")
if err != nil {
t.Fatal(err)
}

if len(rows) != 1 {
t.Fatal("Expected one row returned")
}
}

func Test_UnixSocket_Sqlx_SelectOne(t *testing.T) {
database := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig().WithoutTcp())
if err := database.Start(); err != nil {
t.Fatal(err)
}

defer func() {
if err := database.Stop(); err != nil {
t.Fatal(err)
}
}()

db, err := connect(database.GetConnectionURL())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -108,7 +138,7 @@ func Test_ManyTestsAgainstOneDatabase(t *testing.T) {
}
}()

db, err := connect()
db, err := connect(database.GetConnectionURL())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -188,7 +218,18 @@ func Test_SimpleHttpWebApp(t *testing.T) {
}
}

func connect() (*sqlx.DB, error) {
db, err := sqlx.Connect("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
func connect(u string) (*sqlx.DB, error) {
parsed, err := url.Parse(u)
if err != nil {
return nil, err
}

q := parsed.Query()
if q.Get("sock") == "" {
q.Set("sslmode", "disable")
}
parsed.RawQuery = q.Encode()

db, err := sqlx.Connect("postgres", parsed.String())
return db, err
}
22 changes: 14 additions & 8 deletions prepare_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const (
)

type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error
type createDatabase func(port uint32, username, password, database string) error
type createDatabase func(host string, port uint32, username, password, database string) error

func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error {
passwordFile, err := createPasswordFile(runtimePath, password)
Expand Down Expand Up @@ -71,12 +71,12 @@ func createPasswordFile(runtimePath, password string) (string, error) {
return passwordFileLocation, nil
}

func defaultCreateDatabase(port uint32, username, password, database string) (err error) {
func defaultCreateDatabase(host string, port uint32, username, password, database string) (err error) {
if database == "postgres" {
return nil
}

conn, err := openDatabaseConnection(port, username, password, "postgres")
conn, err := openDatabaseConnection(host, port, username, password, "postgres")
if err != nil {
return errorCustomDatabase(database, err)
}
Expand Down Expand Up @@ -120,7 +120,12 @@ func healthCheckDatabaseOrTimeout(config Config) error {

go func() {
for timeout.Err() == nil {
if err := healthCheckDatabase(config.port, config.database, config.username, config.password); err != nil {
host := "localhost"
if config.useUnixSocket {
host = config.unixSocketDirectory
}

if err := healthCheckDatabase(host, config.port, config.database, config.username, config.password); err != nil {
continue
}
healthCheckSignal <- true
Expand All @@ -137,8 +142,8 @@ func healthCheckDatabaseOrTimeout(config Config) error {
}
}

func healthCheckDatabase(port uint32, database, username, password string) (err error) {
conn, err := openDatabaseConnection(port, username, password, database)
func healthCheckDatabase(host string, port uint32, database, username, password string) (err error) {
conn, err := openDatabaseConnection(host, port, username, password, database)
if err != nil {
return err
}
Expand All @@ -155,8 +160,9 @@ func healthCheckDatabase(port uint32, database, username, password string) (err
return nil
}

func openDatabaseConnection(port uint32, username string, password string, database string) (*pq.Connector, error) {
conn, err := pq.NewConnector(fmt.Sprintf("host=localhost port=%d user=%s password=%s dbname=%s sslmode=disable",
func openDatabaseConnection(host string, port uint32, username string, password string, database string) (*pq.Connector, error) {
conn, err := pq.NewConnector(fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
host,
port,
username,
password,
Expand Down
12 changes: 9 additions & 3 deletions prepare_database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,13 @@ func Test_defaultInitDatabase_PwFileRemoved(t *testing.T) {
}

func Test_defaultCreateDatabase_ErrorWhenSQLOpenError(t *testing.T) {
err := defaultCreateDatabase(1234, "user client_encoding=lol", "password", "database")
err := defaultCreateDatabase("localhost", 1234, "user client_encoding=lol", "password", "database")

assert.EqualError(t, err, "unable to connect to create database with custom name database with the following error: client_encoding must be absent or 'UTF8'")
}

func Test_defaultCreateDatabase_ErrorWhenSQLOpenError_UnixSocket(t *testing.T) {
err := defaultCreateDatabase("/tmp", 1234, "user client_encoding=lol", "password", "database")

assert.EqualError(t, err, "unable to connect to create database with custom name database with the following error: client_encoding must be absent or 'UTF8'")
}
Expand Down Expand Up @@ -165,13 +171,13 @@ func Test_defaultCreateDatabase_ErrorWhenQueryError(t *testing.T) {
}
}()

err := defaultCreateDatabase(9831, "postgres", "postgres", "b33r")
err := defaultCreateDatabase("localhost", 9831, "postgres", "postgres", "b33r")

assert.EqualError(t, err, `unable to connect to create database with custom name b33r with the following error: pq: database "b33r" already exists`)
}

func Test_healthCheckDatabase_ErrorWhenSQLConnectingError(t *testing.T) {
err := healthCheckDatabase(1234, "tom client_encoding=lol", "more", "b33r")
err := healthCheckDatabase("localhost", 1234, "tom client_encoding=lol", "more", "b33r")

assert.EqualError(t, err, "client_encoding must be absent or 'UTF8'")
}
Expand Down
12 changes: 0 additions & 12 deletions test_config.go

This file was deleted.

Loading