diff --git a/shows.go b/shows.go index 349edb9..fff7897 100644 --- a/shows.go +++ b/shows.go @@ -39,8 +39,17 @@ var Schema = sqly.Schema{ created timestamp DEFAULT current_timestamp ); `}, + sqly.SchemaTable{ + Name: "shows_tracked", + Sql: ` + CREATE TABLE shows_tracked ( + shows_id integer NOT NULL REFERENCES shows (id) ON DELETE CASCADE, + users_id integer NOT NULL REFERENCES users (id) ON DELETE CASCADE + ); + `}, }, Drop: ` + DROP TABLE shows_tracked; DROP TABLE episodes; DROP TABLE shows; `, @@ -53,12 +62,19 @@ const ( addEpisodeQuery = `INSERT INTO episodes (shows_id, title, season, episode) VALUES ($1,$2,$3,$4);` getEpisodesQuery = `SELECT title, season, episode FROM episodes WHERE shows_id=$1;` + + getShowWithUserQuery = ` + SELECT id, imdbid, title, + EXISTS (SELECT 1 FROM shows_tracked WHERE shows_id=shows.id AND users_id=$2) AS tracked + FROM shows WHERE imdbid=$1; + ` ) type Show struct { sqly.BaseTable polochon.Show Episodes []*Episode + Tracked bool } func Get(db *sqlx.DB, imdbID string) (*Show, error) { @@ -70,6 +86,16 @@ func Get(db *sqlx.DB, imdbID string) (*Show, error) { return s, nil } +// GetAsUser returns a show with user info like tracked +func GetAsUser(db *sqlx.DB, user *users.User, imdbID string) (*Show, error) { + s := &Show{} + err := db.QueryRowx(getShowWithUserQuery, imdbID, user.ID).StructScan(s) + if err != nil { + return nil, err + } + return s, nil +} + func (s *Show) Add(db *sqlx.DB) error { var id int r, err := db.NamedQuery(addShowQuery, s) diff --git a/shows_test.go b/shows_test.go index fe6816d..f2f484e 100644 --- a/shows_test.go +++ b/shows_test.go @@ -8,6 +8,7 @@ import ( "testing" "gitlab.quimbo.fr/odwrtw/canape-sql/sqly" + "gitlab.quimbo.fr/odwrtw/canape-sql/users" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" @@ -141,3 +142,39 @@ func TestAddRemoveShow(t *testing.T) { }) } +func TestTrackedShow(t *testing.T) { + sqly.RunWithSchema(db, Schema, t, func(db *sqlx.DB, t *testing.T) { + nfo := strings.NewReader(showNFO1) + s := &polochon.Show{} + polochon.ReadNFO(nfo, s) + show := Show{Show: *s} + show.Add(db) + + u := &users.User{Name: "plop"} + err := u.Add(db) + if err != nil { + t.Fatal(err) + } + + show1, err := GetAsUser(db, u, "tt2357547") + if err != nil { + t.Fatal(err) + } + if show1.Tracked { + t.Fatal("Tracked must be false here") + } + + q := `INSERT INTO shows_tracked (shows_id, users_id) VALUES ($1, $2);` + _, err = db.Exec(q, show1.ID, u.ID) + if err != nil { + t.Fatal(err) + } + show2, err := GetAsUser(db, u, "tt2357547") + if err != nil { + t.Fatal(err) + } + if !show2.Tracked { + t.Fatal("Tracked must be true here") + } + }) +} diff --git a/sqly/sqly.go b/sqly/sqly.go index 35eb820..234184e 100644 --- a/sqly/sqly.go +++ b/sqly/sqly.go @@ -71,16 +71,16 @@ func (s Schema) Create(db *sqlx.DB) error { } func (s Schema) Delete(db *sqlx.DB) error { + err := MultiExec(db, s.Drop) + if err != nil { + return err + } for _, sch := range s.Require { err := sch.Delete(db) if err != nil { return err } } - err := MultiExec(db, s.Drop) - if err != nil { - return err - } return nil }