Clarify the use of a global shared database handle.

Replace the global db pointer with a hook that test cases can call to access
the shared database. The result is still global, but the name is a little more
self-documenting, and the use of call syntax makes it clear where we are
accessing something with more than trivial structure.
This commit is contained in:
M. J. Fromberger
2021-08-30 11:48:18 -07:00
parent 8667d60bd6
commit 5970fb3cb8

View File

@@ -33,7 +33,9 @@ var (
doPauseAtExit = flag.Bool("pause-at-exit", false,
"If true, pause the test until interrupted at shutdown, to allow debugging")
db *sql.DB
// A hook that test cases can call to obtain the shared database instance
// used for testing the sink. This is initialized in TestMain (see below).
testDB func() *sql.DB
)
const (
@@ -86,6 +88,7 @@ func TestMain(m *testing.M) {
// Connect to the database, clear any leftover data, and install the
// indexing schema.
conn := fmt.Sprintf(dsn, user, password, resource.GetPort(port+"/tcp"), dbName)
var db *sql.DB
if err := pool.Retry(func() error {
sink, err := NewEventSink(conn, chainID)
@@ -109,6 +112,9 @@ func TestMain(m *testing.M) {
log.Fatalf("Applying schema: %v", err)
}
// Set up the hook for tests to get the shared database handle.
testDB = func() *sql.DB { return db }
// Run the selected test cases.
code := m.Run()
@@ -130,12 +136,12 @@ func TestMain(m *testing.M) {
}
func TestType(t *testing.T) {
psqlSink := &EventSink{store: db, chainID: chainID}
psqlSink := &EventSink{store: testDB(), chainID: chainID}
assert.Equal(t, indexer.PSQL, psqlSink.Type())
}
func TestBlockFuncs(t *testing.T) {
indexer := &EventSink{store: db, chainID: chainID}
indexer := &EventSink{store: testDB(), chainID: chainID}
require.NoError(t, indexer.IndexBlockEvents(newTestBlockHeader()))
verifyBlock(t, 1)
@@ -156,7 +162,7 @@ func TestBlockFuncs(t *testing.T) {
}
func TestTxFuncs(t *testing.T) {
indexer := &EventSink{store: db, chainID: chainID}
indexer := &EventSink{store: testDB(), chainID: chainID}
txResult := txResultWithEvents([]abci.Event{
makeIndexedEvent("account.number", "1"),
@@ -189,7 +195,7 @@ func TestTxFuncs(t *testing.T) {
}
func TestStop(t *testing.T) {
indexer := &EventSink{store: db}
indexer := &EventSink{store: testDB()}
require.NoError(t, indexer.Stop())
}
@@ -259,7 +265,7 @@ func txResultWithEvents(events []abci.Event) *abci.TxResult {
func loadTxResult(hash []byte) (*abci.TxResult, error) {
hashString := fmt.Sprintf("%X", hash)
var resultData []byte
if err := db.QueryRow(`
if err := testDB().QueryRow(`
SELECT tx_result FROM `+tableTxResults+` WHERE tx_hash = $1;
`, hashString).Scan(&resultData); err != nil {
return nil, fmt.Errorf("lookup transaction for hash %q failed: %v", hashString, err)
@@ -274,7 +280,7 @@ SELECT tx_result FROM `+tableTxResults+` WHERE tx_hash = $1;
}
func verifyTimeStamp(tableName string) error {
return db.QueryRow(fmt.Sprintf(`
return testDB().QueryRow(fmt.Sprintf(`
SELECT DISTINCT %[1]s.created_at
FROM %[1]s
WHERE %[1]s.created_at >= $1;
@@ -283,7 +289,7 @@ SELECT DISTINCT %[1]s.created_at
func verifyBlock(t *testing.T, height int64) {
// Check that the blocks table contains an entry for this height.
if err := db.QueryRow(`
if err := testDB().QueryRow(`
SELECT height FROM `+tableBlocks+` WHERE height = $1;
`, height).Err(); err == sql.ErrNoRows {
t.Errorf("No block found for height=%d", height)
@@ -292,7 +298,7 @@ SELECT height FROM `+tableBlocks+` WHERE height = $1;
}
// Verify the presence of begin_block and end_block events.
if err := db.QueryRow(`
if err := testDB().QueryRow(`
SELECT type, height, chain_id FROM `+viewBlockEvents+`
WHERE height = $1 AND type = $2 AND chain_id = $3;
`, height, types.EventTypeBeginBlock, chainID).Err(); err == sql.ErrNoRows {
@@ -301,7 +307,7 @@ SELECT type, height, chain_id FROM `+viewBlockEvents+`
t.Fatalf("Database query failed: %v", err)
}
if err := db.QueryRow(`
if err := testDB().QueryRow(`
SELECT type, height, chain_id FROM `+viewBlockEvents+`
WHERE height = $1 AND type = $2 AND chain_id = $3;
`, height, types.EventTypeEndBlock, chainID).Err(); err == sql.ErrNoRows {