mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-04 09:48:19 +01:00
131 lines
2.5 KiB
Go
131 lines
2.5 KiB
Go
// +build kvdb_postgres
|
|
|
|
package postgres
|
|
|
|
import (
	"context"
	"database/sql"
	"strings"
	"testing"

	"github.com/btcsuite/btcwallet/walletdb"
	embeddedpostgres "github.com/fergusstrange/embedded-postgres"
	"github.com/stretchr/testify/require"
)
|
|
|
|
// testDsn is the connection string used by the helpers in this file to reach
// the embedded postgres instance that NewFixture starts on port 9876.
const testDsn = "postgres://postgres:postgres@localhost:9876/postgres?sslmode=disable"

// prefix is the prefix argument that openTestDb hands to newPostgresBackend.
const prefix = "test"
|
|
|
|
func clearTestDb(t *testing.T) {
|
|
dbConn, err := sql.Open("pgx", testDsn)
|
|
require.NoError(t, err)
|
|
|
|
_, err = dbConn.ExecContext(context.Background(), "DROP SCHEMA IF EXISTS public CASCADE;")
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func openTestDb(t *testing.T) *db {
|
|
clearTestDb(t)
|
|
|
|
db, err := newPostgresBackend(
|
|
context.Background(),
|
|
&Config{
|
|
Dsn: testDsn,
|
|
},
|
|
prefix,
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
return db
|
|
}
|
|
|
|
// fixture bundles the state needed to hand out postgres-backed walletdb.DB
// instances to a test: the owning test and the embedded postgres server.
type fixture struct {
	// t is the test that created the fixture; the fixture's methods report
	// failures against it.
	t *testing.T

	// tempDir is never set or read anywhere in this file.
	// NOTE(review): possibly a leftover field — confirm no other file in the
	// package uses it before removing.
	tempDir string

	// postgres is the embedded server started in NewFixture and stopped in
	// Cleanup.
	postgres *embeddedpostgres.EmbeddedPostgres
}
|
|
|
|
func NewFixture(t *testing.T) *fixture {
|
|
postgres := embeddedpostgres.NewDatabase(
|
|
embeddedpostgres.DefaultConfig().
|
|
Port(9876))
|
|
|
|
err := postgres.Start()
|
|
require.NoError(t, err)
|
|
|
|
return &fixture{
|
|
t: t,
|
|
postgres: postgres,
|
|
}
|
|
}
|
|
|
|
func (b *fixture) Cleanup() {
|
|
b.postgres.Stop()
|
|
}
|
|
|
|
func (b *fixture) NewBackend() walletdb.DB {
|
|
clearTestDb(b.t)
|
|
db := openTestDb(b.t)
|
|
|
|
return db
|
|
}
|
|
|
|
func (b *fixture) Dump() map[string]interface{} {
|
|
dbConn, err := sql.Open("pgx", testDsn)
|
|
require.NoError(b.t, err)
|
|
|
|
rows, err := dbConn.Query(
|
|
"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'",
|
|
)
|
|
require.NoError(b.t, err)
|
|
|
|
var tables []string
|
|
for rows.Next() {
|
|
var table string
|
|
err := rows.Scan(&table)
|
|
require.NoError(b.t, err)
|
|
|
|
tables = append(tables, table)
|
|
}
|
|
|
|
result := make(map[string]interface{})
|
|
|
|
for _, table := range tables {
|
|
rows, err := dbConn.Query("SELECT * FROM " + table)
|
|
require.NoError(b.t, err)
|
|
|
|
cols, err := rows.Columns()
|
|
require.NoError(b.t, err)
|
|
colCount := len(cols)
|
|
|
|
var tableRows []map[string]interface{}
|
|
for rows.Next() {
|
|
values := make([]interface{}, colCount)
|
|
valuePtrs := make([]interface{}, colCount)
|
|
for i := range values {
|
|
valuePtrs[i] = &values[i]
|
|
}
|
|
|
|
err := rows.Scan(valuePtrs...)
|
|
require.NoError(b.t, err)
|
|
|
|
tableData := make(map[string]interface{})
|
|
for i, v := range values {
|
|
// Cast byte slices to string to keep the
|
|
// expected database contents in test code more
|
|
// readable.
|
|
if ar, ok := v.([]uint8); ok {
|
|
v = string(ar)
|
|
}
|
|
tableData[cols[i]] = v
|
|
}
|
|
|
|
tableRows = append(tableRows, tableData)
|
|
}
|
|
|
|
result[table] = tableRows
|
|
}
|
|
|
|
return result
|
|
}
|