canape/sqly/sqly.go

116 lines
2.2 KiB
Go

package sqly
import (
"fmt"
"strings"
"testing"
"time"
"github.com/jmoiron/sqlx"
)
// initDBQuery is the SQL script executed by InitDB. It creates (or replaces)
// the update_modified_column() trigger function, which assigns now() to
// NEW.updated and returns NEW.
//
// NOTE: MultiExec splits scripts on ";\n", so a newline placed directly after
// any semicolon inside the function body would cut this statement apart.
const initDBQuery = `
CREATE OR REPLACE FUNCTION update_modified_column()
RETURNS TRIGGER AS $$
BEGIN NEW.updated = now(); RETURN NEW; END; $$ language 'plpgsql';
`
// InitDB installs the helper SQL objects this package relies on by running
// initDBQuery against db through MultiExec.
//
// The error reported by MultiExec, if any, is returned unchanged; nil means
// the whole script was executed.
func InitDB(db *sqlx.DB) error {
	return MultiExec(db, initDBQuery)
}
// BaseTable has to be embedded in every struct that reflects a table.
type BaseTable struct {
// ID identifies the row (presumably the table's primary key; confirm
// against the table definitions).
ID int
// Updated is the last-modification time. Presumably mapped to the "updated"
// column, which the update_modified_column() trigger sets to now() before
// each row update.
Updated time.Time
// Created is presumably the row creation time; nothing in this file sets
// it, so the table definition is expected to provide it.
Created time.Time
}
// SchemaTable describes one table of a Schema.
type SchemaTable struct {
// Name is the table name. Schema.Create interpolates it into the
// CREATE TRIGGER statement it issues for the table.
Name string
// Sql is the statement Schema.Create executes for this table before it
// adds the trigger (expected to be the table's CREATE statement).
Sql string
}
// Schema groups a set of tables with the SQL needed to drop them and the
// other schemas they depend on.
type Schema struct {
// Tables lists the tables of the schema; each entry carries the table
// name and its SQL. Create processes the entries in slice order.
Tables []SchemaTable
// Drop contains the drop SQL. Delete runs it through MultiExec, so it may
// hold several statements separated by ";\n".
Drop string
// Require lists schemas that Create sets up before this one and that
// Delete removes after this schema's Drop script has run.
Require []Schema
}
// Create builds the schema in db.
//
// Required schemas are created first, in the order they appear in s.Require.
// Then, for each entry of s.Tables in order, the table's SQL is executed,
// followed by a CREATE TRIGGER statement that attaches
// update_modified_column() (the function installed by InitDB) as a
// BEFORE UPDATE row trigger named update_<table name>.
//
// Create stops at the first failure. An error from a table or trigger
// statement wraps the underlying error, so errors.Is and errors.As still see
// it, and appends the failing SQL after a newline. Errors coming from a
// required schema are returned as they are.
func (s Schema) Create(db *sqlx.DB) error {
	for _, sch := range s.Require {
		if err := sch.Create(db); err != nil {
			return err
		}
	}

	for _, table := range s.Tables {
		if _, err := db.Exec(table.Sql); err != nil {
			return fmt.Errorf("%w\n%s", err, table.Sql)
		}

		// The table name is interpolated verbatim: identifiers cannot be
		// bound as query parameters, so Name must be a trusted identifier.
		trigger := fmt.Sprintf("CREATE TRIGGER update_%s BEFORE UPDATE ON %s FOR EACH ROW EXECUTE PROCEDURE update_modified_column();", table.Name, table.Name)
		if _, err := db.Exec(trigger); err != nil {
			return fmt.Errorf("%w\n%s", err, trigger)
		}
	}
	return nil
}
// Delete tears the schema down in db.
//
// The schema's own Drop script is run first through MultiExec; afterwards
// every schema listed in s.Require is deleted, in slice order. The first
// error encountered aborts the teardown and is returned unchanged.
func (s Schema) Delete(db *sqlx.DB) error {
	if err := MultiExec(db, s.Drop); err != nil {
		return err
	}

	for i := range s.Require {
		if err := s.Require[i].Delete(db); err != nil {
			return err
		}
	}
	return nil
}
// MultiExec runs a script made of several SQL statements through e, issuing
// one Exec call per statement.
//
// The script is split on the literal separator ";\n" (a semicolon directly
// followed by a newline). No SQL parsing is done, so a ";\n" inside a string
// literal or a function body also splits the script. Each statement is passed
// to Exec exactly as it appears between separators. Statements that contain
// only whitespace are skipped wherever they occur, so a trailing newline,
// blank lines or doubled separators never reach the database.
//
// Execution stops at the first failing statement. The returned error wraps
// the underlying error, so errors.Is and errors.As still see it, and appends
// the text of the failing statement after a newline.
func MultiExec(e sqlx.Execer, query string) error {
	for _, stmt := range strings.Split(query, ";\n") {
		if strings.TrimSpace(stmt) == "" {
			continue
		}
		if _, err := e.Exec(stmt); err != nil {
			return fmt.Errorf("%w\n%s", err, stmt)
		}
	}
	return nil
}
// RunWithSchema runs test against db with schema created for its duration.
//
// The schema is created before test is called and deleted in a deferred step
// once RunWithSchema unwinds, including when creation itself failed. A failure
// to create or to delete the schema aborts the test through t.Fatalf with the
// error text.
func RunWithSchema(db *sqlx.DB, schema Schema, t *testing.T, test func(db *sqlx.DB, t *testing.T)) {
	cleanup := func() {
		if err := schema.Delete(db); err != nil {
			t.Fatalf("%s", err.Error())
		}
	}
	defer cleanup()

	if err := schema.Create(db); err != nil {
		t.Fatalf("%s", err.Error())
	}
	test(db, t)
}