package sqly import ( "fmt" "strings" "testing" "time" "github.com/jmoiron/sqlx" ) const initDBQuery = ` CREATE OR REPLACE FUNCTION update_modified_column() RETURNS TRIGGER AS $$ BEGIN NEW.updated = now(); RETURN NEW; END; $$ language 'plpgsql'; ` // InitDB add some function to the database and template func InitDB(db *sqlx.DB) error { err := MultiExec(db, initDBQuery) if err != nil { return err } return nil } // BaseTable have to be embeded in all your struct which reflect a table type BaseTable struct { ID int Updated time.Time Created time.Time } type SchemaTable struct { Name string Sql string } type Schema struct { // Create contains all tables, the key is the name and the // value the sql Tables []SchemaTable // Drop contains the drop sql Drop string // Require add schema before create and delete after Require []Schema } func (s Schema) Create(db *sqlx.DB) error { for _, sch := range s.Require { err := sch.Create(db) if err != nil { return err } } for _, table := range s.Tables { _, err := db.Exec(table.Sql) if err != nil { return fmt.Errorf("%s\n%s", err, table.Sql) } trigger := fmt.Sprintf("CREATE TRIGGER update_%s BEFORE UPDATE ON %s FOR EACH ROW EXECUTE PROCEDURE update_modified_column();", table.Name, table.Name) _, err = db.Exec(trigger) if err != nil { return fmt.Errorf("%s\n%s", err, trigger) } } return nil } func (s Schema) Delete(db *sqlx.DB) error { for _, sch := range s.Require { err := sch.Delete(db) if err != nil { return err } } err := MultiExec(db, s.Drop) if err != nil { return err } return nil } func MultiExec(e sqlx.Execer, query string) error { stmts := strings.Split(query, ";\n") if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { stmts = stmts[:len(stmts)-1] } for _, s := range stmts { _, err := e.Exec(s) if err != nil { return fmt.Errorf("%s\n%s", err, s) } } return nil } func RunWithSchema(db *sqlx.DB, schema Schema, t *testing.T, test func(db *sqlx.DB, t *testing.T)) { defer func() { err := schema.Delete(db) if err != nil { t.Fatalf("%s", err.Error()) } }() err := schema.Create(db) if err != nil { t.Fatalf("%s", err.Error()) } test(db, t) }