40 lines
706 B
Go
40 lines
706 B
Go
package sqltest
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
func MultiExec(e sqlx.Execer, query string) error {
|
|
stmts := strings.Split(query, ";\n")
|
|
if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 {
|
|
stmts = stmts[:len(stmts)-1]
|
|
}
|
|
for _, s := range stmts {
|
|
_, err := e.Exec(s)
|
|
if err != nil {
|
|
return fmt.Errorf("%s\n%s", err, s)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func RunWithSchema(db *sqlx.DB, create, drop string, t *testing.T, test func(db *sqlx.DB, t *testing.T)) {
|
|
defer func() {
|
|
err := MultiExec(db, drop)
|
|
if err != nil {
|
|
t.Fatalf("%s", err.Error())
|
|
}
|
|
}()
|
|
|
|
err := MultiExec(db, create)
|
|
if err != nil {
|
|
t.Fatalf("%s", err.Error())
|
|
}
|
|
|
|
test(db, t)
|
|
}
|