Skip to content

feat(analyzer): Automatically create databases #3376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ install:
test:
go test ./...

test-managed:
MYSQL_SERVER_URI="invalid" POSTGRESQL_SERVER_URI="postgres://postgres:mysecretpassword@localhost:5432/postgres" go test ./...

vet:
go vet ./...

Expand Down
2 changes: 2 additions & 0 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,11 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
}
return nil, true
}

if parserOpts.Debug.DumpCatalog {
debug.Dump(c.Catalog())
}

if err := c.ParseQueries(sql.Queries, parserOpts); err != nil {
fmt.Fprintf(stderr, "# package %s\n", name)
if parserErr, ok := err.(*multierr.Error); ok {
Expand Down
2 changes: 1 addition & 1 deletion internal/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err
if conf.Database != nil {
if conf.Analyzer.Database == nil || *conf.Analyzer.Database {
c.analyzer = analyzer.Cached(
pganalyze.New(c.client, *conf.Database),
pganalyze.New(c.client, combo.Global.Servers, *conf.Database),
combo.Global,
*conf.Database,
)
Expand Down
7 changes: 7 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ const (
type Config struct {
Version string `json:"version" yaml:"version"`
Cloud Cloud `json:"cloud" yaml:"cloud"`
Servers []Server `json:"servers" yaml:"servers"`
SQL []SQL `json:"sql" yaml:"sql"`
Overrides Overrides `json:"overrides,omitempty" yaml:"overrides"`
Plugins []Plugin `json:"plugins" yaml:"plugins"`
Expand All @@ -69,6 +70,7 @@ type Config struct {
type Database struct {
URI string `json:"uri" yaml:"uri"`
Managed bool `json:"managed" yaml:"managed"`
Auto bool `json:"auto" yaml:"auto"`
}

type Cloud struct {
Expand All @@ -78,6 +80,11 @@ type Cloud struct {
AuthToken string `json:"-" yaml:"-"`
}

type Server struct {
Name string `json:"name" yaml:"name"`
URI string `json:"uri" yaml:"uri"`
}

type Plugin struct {
Name string `json:"name" yaml:"name"`
Env []string `json:"env" yaml:"env"`
Expand Down
1 change: 1 addition & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func TestInvalidConfig(t *testing.T) {
Database: &Database{
URI: "",
Managed: false,
Auto: false,
},
}},
})
Expand Down
6 changes: 5 additions & 1 deletion internal/config/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ package config
func Validate(c *Config) error {
for _, sql := range c.SQL {
if sql.Database != nil {
if sql.Database.URI == "" && !sql.Database.Managed {
switch {
case sql.Database.URI != "":
case sql.Database.Managed:
case sql.Database.Auto:
default:
return ErrInvalidDatabase
}
}
Expand Down
21 changes: 13 additions & 8 deletions internal/endtoend/endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,26 @@ func TestReplay(t *testing.T) {
"managed-db": {
Mutate: func(t *testing.T, path string) func(*config.Config) {
return func(c *config.Config) {
c.Servers = []config.Server{
{
Name: "postgres",
URI: local.PostgreSQLServer(),
},

{
Name: "mysql",
URI: local.MySQLServer(),
},
}
for i := range c.SQL {
files := []string{}
for _, s := range c.SQL[i].Schema {
files = append(files, filepath.Join(path, s))
}
switch c.SQL[i].Engine {
case config.EnginePostgreSQL:
uri := local.ReadOnlyPostgreSQL(t, files)
c.SQL[i].Database = &config.Database{
URI: uri,
Auto: true,
}
case config.EngineMySQL:
uri := local.MySQL(t, files)
c.SQL[i].Database = &config.Database{
URI: uri,
Auto: true,
}
default:
// pass
Expand Down
50 changes: 37 additions & 13 deletions internal/engine/postgresql/analyzer/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@ import (
"context"
"errors"
"fmt"
"hash/fnv"
"io"
"strings"
"sync"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/sync/singleflight"

core "github.com/sqlc-dev/sqlc/internal/analysis"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/opts"
"github.com/sqlc-dev/sqlc/internal/pgx/poolcache"
pb "github.com/sqlc-dev/sqlc/internal/quickdb/v1"
"github.com/sqlc-dev/sqlc/internal/shfmt"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
Expand All @@ -22,22 +26,28 @@ import (
)

type Analyzer struct {
db config.Database
client pb.QuickClient
pool *pgxpool.Pool
dbg opts.Debug
replacer *shfmt.Replacer
formats sync.Map
columns sync.Map
tables sync.Map
db config.Database
client pb.QuickClient
pool *pgxpool.Pool
dbg opts.Debug
replacer *shfmt.Replacer
formats sync.Map
columns sync.Map
tables sync.Map
servers []config.Server
serverCache *poolcache.Cache
flight singleflight.Group
}

func New(client pb.QuickClient, db config.Database) *Analyzer {
func New(client pb.QuickClient, servers []config.Server, db config.Database) *Analyzer {
return &Analyzer{
db: db,
dbg: opts.DebugFromEnv(),
client: client,
replacer: shfmt.NewReplacer(nil),
// TODO: Pick first
servers: servers,
db: db,
dbg: opts.DebugFromEnv(),
client: client,
replacer: shfmt.NewReplacer(nil),
serverCache: poolcache.New(),
}
}

Expand Down Expand Up @@ -99,6 +109,14 @@ type columnKey struct {
Attr uint16
}

func (a *Analyzer) fnv(migrations []string) string {
h := fnv.New64()
for _, query := range migrations {
io.WriteString(h, query)
}
return fmt.Sprintf("%x", h.Sum(nil))
}

// Cache these types in memory
func (a *Analyzer) columnInfo(ctx context.Context, field pgconn.FieldDescription) (*pgColumn, error) {
key := columnKey{field.TableOID, field.TableAttributeNumber}
Expand Down Expand Up @@ -211,6 +229,12 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
uri = edb.Uri
} else if a.dbg.OnlyManagedDatabases {
return nil, fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed")
} else if a.db.Auto {
var err error
uri, err = a.createDb(ctx, migrations)
if err != nil {
return nil, err
}
} else {
uri = a.replacer.Replace(a.db.URI)
}
Expand Down
68 changes: 68 additions & 0 deletions internal/engine/postgresql/analyzer/createdb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package analyzer

import (
"context"
"fmt"
"log/slog"
"net/url"
"strings"

"github.com/jackc/pgx/v5"
)

func (a *Analyzer) createDb(ctx context.Context, migrations []string) (string, error) {
hash := a.fnv(migrations)
name := fmt.Sprintf("sqlc_%s", hash)

serverUri := a.replacer.Replace(a.servers[0].URI)
pool, err := a.serverCache.Open(ctx, serverUri)
if err != nil {
return "", err
}

uri, err := url.Parse(serverUri)
if err != nil {
return "", err
}
uri.Path = name

key := uri.String()
_, err, _ = a.flight.Do(key, func() (interface{}, error) {
// TODO: Use a parameterized query
row := pool.QueryRow(ctx,
fmt.Sprintf(`SELECT datname FROM pg_database WHERE datname = '%s'`, name))

var datname string
if err := row.Scan(&datname); err == nil {
slog.Info("database exists", "name", name)
return nil, nil
}

slog.Info("creating database", "name", name)
if _, err := pool.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, name)); err != nil {
return nil, err
}

conn, err := pgx.Connect(ctx, uri.String())
if err != nil {
return nil, fmt.Errorf("connect %s: %s", name, err)
}
defer conn.Close(ctx)

for _, q := range migrations {
if len(strings.TrimSpace(q)) == 0 {
continue
}
if _, err := conn.Exec(ctx, q); err != nil {
return nil, fmt.Errorf("%s: %s", q, err)
}
}
return nil, nil
})

if err != nil {
return "", err
}

return key, err
}
27 changes: 18 additions & 9 deletions internal/pgx/poolcache/poolcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,22 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
)

var lock sync.RWMutex
var pools = map[string]*pgxpool.Pool{}
type Cache struct {
lock sync.RWMutex
pools map[string]*pgxpool.Pool
}

func New() *Cache {
return &Cache{
pools: map[string]*pgxpool.Pool{},
}
}

func New(ctx context.Context, uri string) (*pgxpool.Pool, error) {
lock.RLock()
existing, found := pools[uri]
lock.RUnlock()
// Should only be used in testing contexts
func (c *Cache) Open(ctx context.Context, uri string) (*pgxpool.Pool, error) {
c.lock.RLock()
existing, found := c.pools[uri]
c.lock.RUnlock()

if found {
return existing, nil
Expand All @@ -24,9 +33,9 @@ func New(ctx context.Context, uri string) (*pgxpool.Pool, error) {
return nil, err
}

lock.Lock()
pools[uri] = pool
lock.Unlock()
c.lock.Lock()
c.pools[uri] = pool
c.lock.Unlock()

return pool, nil
}
4 changes: 4 additions & 0 deletions internal/sqltest/local/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import (
var mysqlSync sync.Once
var mysqlPool *sql.DB

func MySQLServer() string {
return os.Getenv("MYSQL_SERVER_URI")
}

func MySQL(t *testing.T, migrations []string) string {
ctx := context.Background()
t.Helper()
Expand Down
7 changes: 6 additions & 1 deletion internal/sqltest/local/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
)

var flight singleflight.Group
var cache = poolcache.New()

func PostgreSQL(t *testing.T, migrations []string) string {
return postgreSQL(t, migrations, true)
Expand All @@ -27,6 +28,10 @@ func ReadOnlyPostgreSQL(t *testing.T, migrations []string) string {
return postgreSQL(t, migrations, false)
}

func PostgreSQLServer() string {
return os.Getenv("POSTGRESQL_SERVER_URI")
}

func postgreSQL(t *testing.T, migrations []string, rw bool) string {
ctx := context.Background()
t.Helper()
Expand All @@ -36,7 +41,7 @@ func postgreSQL(t *testing.T, migrations []string, rw bool) string {
t.Skip("POSTGRESQL_SERVER_URI is empty")
}

postgresPool, err := poolcache.New(ctx, dburi)
postgresPool, err := cache.Open(ctx, dburi)
if err != nil {
t.Fatalf("PostgreSQL pool creation failed: %s", err)
}
Expand Down
Loading