package auth

import (
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/cookiejar"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gorilla/mux"
	"github.com/kr/pretty"
)

// Fixtures shared by the test backend and handlers. Every handler builds its
// auth instance from these same values; a handler using a different cookie key
// would presumably be unable to read cookies issued by the others. hash is a
// precomputed bcrypt hash ($2a$, cost 10) — presumably of password combined
// with pepper; TODO confirm against New's hashing scheme.
const (
	pepper     = "polp"
	ckey       = "plop"
	cookieName = "auth"
	cost       = 10
	username   = "plop"
	password   = "ploppwd"
	hash       = "$2a$10$eVye8xbs6nj4TWnlTmifRuBsAU3F2hkxEcFz9WXdYjUuE6uKLVuzK"
)

// user is the in-memory account served by Backend.
type user struct {
	username string
	password string
	hash     string
}

// GetHash returns the user's stored password hash.
func (u *user) GetHash() string {
	return u.hash
}

// Backend is a stub user store holding a single account.
type Backend struct {
	user *user
}

// Get returns the stored user when username matches it, and an error
// otherwise.
func (b *Backend) Get(username string) (User, error) {
	if username == b.user.username {
		return b.user, nil
	}
	return nil, fmt.Errorf("invalid username")
}

// getBackend returns a Backend populated with the fixture account.
func getBackend() *Backend {
	return &Backend{
		user: &user{
			username: username,
			password: password,
			hash:     hash,
		},
	}
}

// writeHandlerError replies with 500 Internal Server Error and err's text as
// the body. The status must be written before the body: the first Write
// implicitly sends 200 OK, after which WriteHeader has no effect.
func writeHandlerError(w http.ResponseWriter, err error) {
	w.WriteHeader(http.StatusInternalServerError)
	fmt.Fprintf(w, "%s", err)
}

// login logs in the fixture user; it replies 200 on success or 500 with the
// error text on failure.
func login(w http.ResponseWriter, r *http.Request) {
	a := New(getBackend(), pepper, cookieName, ckey, cost)
	if err := a.Login(w, r, username, password); err != nil {
		writeHandlerError(w, err)
		return
	}
	w.WriteHeader(http.StatusOK)
}

// logout logs out the current user; it replies 200 on success or 500 with the
// error text on failure.
func logout(w http.ResponseWriter, r *http.Request) {
	a := New(getBackend(), pepper, cookieName, ckey, cost)
	if err := a.Logout(w, r); err != nil {
		writeHandlerError(w, err)
		return
	}
	w.WriteHeader(http.StatusOK)
}

// check replies 200 with the current user's name, or 200 with an empty body
// when CurrentUser returns no user. An error from CurrentUser yields 500.
func check(w http.ResponseWriter, r *http.Request) {
	// Same key as login/logout, so cookies they set can be read here.
	a := New(getBackend(), pepper, cookieName, ckey, cost)
	u, err := a.CurrentUser(w, r)
	if err != nil {
		writeHandlerError(w, err)
		return
	}
	w.WriteHeader(http.StatusOK)
	if u == nil {
		return
	}
	usr, ok := u.(*user)
	if !ok {
		fmt.Fprintf(w, "Invalid user type")
		return
	}
	fmt.Fprintf(w, "%s", usr.username)
}

// handlers wires the test endpoints into a router; all accept GET only.
func handlers() *mux.Router {
	r := mux.NewRouter()
	r.HandleFunc("/login", login).Methods("GET")
	r.HandleFunc("/logout", logout).Methods("GET")
	r.HandleFunc("/check", check).Methods("GET")
	return r
}

// getOK performs a GET on baseURL+path with client, reads and closes the
// response body, and fails the test unless the status is 200 OK. It returns
// the body as a string.
func getOK(t *testing.T, client *http.Client, baseURL, path string) string {
	t.Helper()
	res, err := client.Get(baseURL + path)
	if err != nil {
		t.Fatalf("GET %s: %v", path, err)
	}
	defer res.Body.Close()
	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("GET %s: reading body: %v", path, err)
	}
	if res.StatusCode != http.StatusOK {
		t.Fatalf("GET %s: status %d\nheaders: %s\nbody: %s",
			path, res.StatusCode, pretty.Sprint(res.Header), body)
	}
	return string(body)
}

// TestAuth drives a check → login → check → logout → check cycle through a
// real HTTP server. The client's cookie jar carries cookies set by one request
// into the next, which is what links the steps together.
func TestAuth(t *testing.T) {
	ts := httptest.NewServer(handlers())
	defer ts.Close()

	jar, err := cookiejar.New(nil)
	if err != nil {
		t.Fatal(err)
	}
	// The timeout keeps a wedged handler from hanging the whole test run.
	client := &http.Client{Jar: jar, Timeout: 10 * time.Second}

	// No user is logged in yet.
	if got := getOK(t, client, ts.URL, "/check"); got != "" {
		t.Fatalf("no user logged in expected before login, but found: %q", got)
	}

	// Log in, then confirm the session identifies the fixture user.
	getOK(t, client, ts.URL, "/login")
	if got := getOK(t, client, ts.URL, "/check"); got != username {
		t.Fatalf("expected to be logged in as %q after login, but got: %q", username, got)
	}

	// Log out, then confirm no user is logged in anymore.
	getOK(t, client, ts.URL, "/logout")
	if got := getOK(t, client, ts.URL, "/check"); got != "" {
		t.Fatalf("no user logged in expected after logout, but found: %q", got)
	}
}