lnd/kvdb/postgres/fixture.go
2021-09-21 10:44:18 +02:00

131 lines
2.5 KiB
Go

// +build kvdb_postgres
package postgres
import (
"context"
"database/sql"
"testing"
"github.com/btcsuite/btcwallet/walletdb"
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
"github.com/stretchr/testify/require"
)
// Connection settings shared by the test helpers in this file.
const (
// testDsn is the connection string used for every test database
// connection. Its port matches the port the embedded server is started on
// in NewFixture.
testDsn = "postgres://postgres:postgres@localhost:9876/postgres?sslmode=disable"
// prefix is the value passed to newPostgresBackend when the test backend
// is opened.
prefix = "test"
)
// clearTestDb drops the public schema, and everything contained in it, from
// the database at testDsn. Any failure to connect or to execute the statement
// fails the test immediately.
func clearTestDb(t *testing.T) {
	dbConn, err := sql.Open("pgx", testDsn)
	require.NoError(t, err)

	// sql.Open hands back a connection pool that stays open until it is
	// closed explicitly. Close it so that repeated calls do not pile up
	// connections to the test server. The deferred call also runs when a
	// require assertion below aborts the test.
	defer dbConn.Close()

	_, err = dbConn.ExecContext(
		context.Background(), "DROP SCHEMA IF EXISTS public CASCADE;",
	)
	require.NoError(t, err)
}
// openTestDb clears the test database and then opens a new postgres backend
// on it using testDsn and the test prefix. The test fails immediately if the
// backend cannot be created.
func openTestDb(t *testing.T) *db {
	clearTestDb(t)

	cfg := &Config{Dsn: testDsn}

	backend, err := newPostgresBackend(context.Background(), cfg, prefix)
	require.NoError(t, err)

	return backend
}
// fixture bundles an embedded postgres server with the test it serves.
type fixture struct {
// t is the test against which helper failures are reported.
t *testing.T
// tempDir is neither set by NewFixture nor read by any method visible in
// this file.
tempDir string
// postgres is the embedded server started by NewFixture and stopped by
// Cleanup.
postgres *embeddedpostgres.EmbeddedPostgres
}
// NewFixture starts an embedded postgres server on port 9876, the port that
// testDsn points at, and returns a fixture wrapping it. The test fails
// immediately if the server cannot be started.
func NewFixture(t *testing.T) *fixture {
	cfg := embeddedpostgres.DefaultConfig().Port(9876)
	server := embeddedpostgres.NewDatabase(cfg)

	require.NoError(t, server.Start())

	return &fixture{
		t:        t,
		postgres: server,
	}
}
// Cleanup stops the embedded postgres server held by the fixture. Shutdown
// stays best-effort: a failure to stop is logged on the test instead of being
// silently discarded, but it does not fail the test.
func (b *fixture) Cleanup() {
	if err := b.postgres.Stop(); err != nil {
		b.t.Logf("unable to stop embedded postgres: %v", err)
	}
}
// NewBackend returns a backend opened on a freshly cleared test database.
func (b *fixture) NewBackend() walletdb.DB {
	// openTestDb drops the public schema itself before connecting, so a
	// separate clearTestDb call here would only repeat that work and open
	// another connection pool.
	return openTestDb(b.t)
}
// Dump returns the raw contents of every table in the public schema of the
// test database. The result maps each table name to a
// []map[string]interface{} holding one map per row, keyed by column name.
// Byte slice values are converted to strings. Any database error fails the
// test immediately.
func (b *fixture) Dump() map[string]interface{} {
	dbConn, err := sql.Open("pgx", testDsn)
	require.NoError(b.t, err)

	// sql.Open hands back a connection pool that stays open until it is
	// closed explicitly, so release it once the dump is complete.
	defer dbConn.Close()

	rows, err := dbConn.Query(
		"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'",
	)
	require.NoError(b.t, err)
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var table string
		require.NoError(b.t, rows.Scan(&table))
		tables = append(tables, table)
	}

	// Next returns false both at the end of the result set and on an
	// iteration error; check which one it was so that a failure cannot
	// produce a silently truncated table list.
	require.NoError(b.t, rows.Err())

	result := make(map[string]interface{})
	for _, table := range tables {
		result[table] = b.dumpTable(dbConn, table)
	}

	return result
}

// dumpTable reads every row of the given table through dbConn and returns one
// map per row, keyed by column name. The returned slice is nil when the table
// has no rows.
func (b *fixture) dumpTable(dbConn *sql.DB,
	table string) []map[string]interface{} {

	// The table name was read from pg_catalog.pg_tables by the caller, and
	// identifiers cannot be bound as query parameters, so it is
	// concatenated into the statement directly.
	rows, err := dbConn.Query("SELECT * FROM " + table)
	require.NoError(b.t, err)
	defer rows.Close()

	cols, err := rows.Columns()
	require.NoError(b.t, err)
	colCount := len(cols)

	var tableRows []map[string]interface{}
	for rows.Next() {
		// Scan needs pointers as destinations; point each one at its
		// own slot in values so the row can be read generically.
		values := make([]interface{}, colCount)
		valuePtrs := make([]interface{}, colCount)
		for i := range values {
			valuePtrs[i] = &values[i]
		}
		require.NoError(b.t, rows.Scan(valuePtrs...))

		tableData := make(map[string]interface{}, colCount)
		for i, v := range values {
			// Cast byte slices to string to keep the expected
			// database contents in test code more readable.
			if ar, ok := v.([]uint8); ok {
				v = string(ar)
			}
			tableData[cols[i]] = v
		}
		tableRows = append(tableRows, tableData)
	}

	// Distinguish the normal end of the result set from an iteration
	// error.
	require.NoError(b.t, rows.Err())

	return tableRows
}