116 lines
2.2 KiB
Go
116 lines
2.2 KiB
Go
package sqly
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
// initDBQuery is the SQL run by InitDB. It defines the trigger function
// update_modified_column(), which sets the "updated" column of the row being
// written to now() and returns the row.
//
// MultiExec splits its input on ";\n", so the function body is kept on a
// single line: the only semicolon followed by a newline is the final one and
// the whole definition is sent to the server as one statement.
const initDBQuery = `
CREATE OR REPLACE FUNCTION update_modified_column()
RETURNS TRIGGER AS $$
BEGIN NEW.updated = now(); RETURN NEW; END; $$ language 'plpgsql';
`
|
|
|
|
// InitDB add some function to the database and template
|
|
func InitDB(db *sqlx.DB) error {
|
|
err := MultiExec(db, initDBQuery)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// BaseTable must be embedded in every struct that mirrors a database table.
type BaseTable struct {
	// ID is the row identifier (presumably the primary key; confirm against
	// the table definitions).
	ID int

	// Updated and Created are the row timestamps. Updated presumably maps to
	// the "updated" column that update_modified_column() sets to now() on
	// every UPDATE; verify against the table definitions.
	Updated, Created time.Time
}
|
|
|
|
// SchemaTable pairs a table name with the SQL that creates the table.
type SchemaTable struct {
	// Name is the table name. Schema.Create interpolates it into the
	// CREATE TRIGGER statement, both in the trigger name and as the target
	// table.
	Name string

	// Sql is the statement Schema.Create executes to create the table.
	Sql string
}
|
|
|
|
type Schema struct {
|
|
// Create contains all tables, the key is the name and the
|
|
// value the sql
|
|
Tables []SchemaTable
|
|
|
|
// Drop contains the drop sql
|
|
Drop string
|
|
|
|
// Require add schema before create and delete after
|
|
Require []Schema
|
|
}
|
|
|
|
func (s Schema) Create(db *sqlx.DB) error {
|
|
for _, sch := range s.Require {
|
|
err := sch.Create(db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
for _, table := range s.Tables {
|
|
_, err := db.Exec(table.Sql)
|
|
if err != nil {
|
|
return fmt.Errorf("%s\n%s", err, table.Sql)
|
|
}
|
|
trigger := fmt.Sprintf("CREATE TRIGGER update_%s BEFORE UPDATE ON %s FOR EACH ROW EXECUTE PROCEDURE update_modified_column();", table.Name, table.Name)
|
|
_, err = db.Exec(trigger)
|
|
if err != nil {
|
|
return fmt.Errorf("%s\n%s", err, trigger)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s Schema) Delete(db *sqlx.DB) error {
|
|
for _, sch := range s.Require {
|
|
err := sch.Delete(db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
err := MultiExec(db, s.Drop)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func MultiExec(e sqlx.Execer, query string) error {
|
|
stmts := strings.Split(query, ";\n")
|
|
if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 {
|
|
stmts = stmts[:len(stmts)-1]
|
|
}
|
|
for _, s := range stmts {
|
|
_, err := e.Exec(s)
|
|
if err != nil {
|
|
return fmt.Errorf("%s\n%s", err, s)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func RunWithSchema(db *sqlx.DB, schema Schema, t *testing.T, test func(db *sqlx.DB, t *testing.T)) {
|
|
defer func() {
|
|
err := schema.Delete(db)
|
|
if err != nil {
|
|
t.Fatalf("%s", err.Error())
|
|
}
|
|
}()
|
|
|
|
err := schema.Create(db)
|
|
if err != nil {
|
|
t.Fatalf("%s", err.Error())
|
|
}
|
|
|
|
test(db, t)
|
|
}
|