diff --git a/auth/auth.go b/auth/auth.go index 9c762a6..58bda33 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -11,12 +11,25 @@ import ( var ( // ErrInvalidPassword returned when password and hash don't match ErrInvalidPassword = fmt.Errorf("Invalid password") + // ErrInvalidSecret returned when cookie's secret is don't match + ErrInvalidSecret = fmt.Errorf("Invalid secret") // ErrCorrupted returned when session have been corrupted ErrCorrupted = fmt.Errorf("Corrupted session") ) +// UserBackend interface for user backend +type UserBackend interface { + Get(username string) (User, error) +} + +// User interface for user +type User interface { + GetHash() string +} + // Authorizer handle sesssion type Authorizer struct { + backend UserBackend cookiejar *sessions.CookieStore cookieName string peeper string @@ -25,8 +38,9 @@ type Authorizer struct { // New Authorizer peeper is like a salt but not stored in database, // cost is the bcrypt cost for hashing the password -func New(peeper, cookieName, cookieKey string, cost int) *Authorizer { +func New(backend UserBackend, peeper, cookieName, cookieKey string, cost int) *Authorizer { return &Authorizer{ + backend: backend, cookiejar: sessions.NewCookieStore([]byte(cookieKey)), cookieName: cookieName, peeper: peeper, @@ -44,18 +58,31 @@ func (a *Authorizer) GenHash(password string) (string, error) { } // Login cheks password and updates cookie info -func (a *Authorizer) Login(rw http.ResponseWriter, req *http.Request, username, hash, password string) error { +func (a *Authorizer) Login(rw http.ResponseWriter, req *http.Request, username, password string) error { cookie, err := a.cookiejar.Get(req, a.cookieName) if err != nil { return err } - err = bcrypt.CompareHashAndPassword([]byte(hash), []byte(password+a.peeper)) + u, err := a.backend.Get(username) + if err != nil { + return err + } + + err = bcrypt.CompareHashAndPassword([]byte(u.GetHash()), []byte(password+a.peeper)) if err != nil { return ErrInvalidPassword } cookie.Values["username"] = username + + // genereate secret + b, err := bcrypt.GenerateFromPassword([]byte(u.GetHash()), a.cost) + if err != nil { + return err + } + cookie.Values["secret"] = string(b) + err = cookie.Save(req, rw) if err != nil { return err @@ -69,7 +96,8 @@ func (a *Authorizer) Logout(rw http.ResponseWriter, req *http.Request) error { if err != nil { return err } - cookie.Values["username"] = "" + cookie.Values["username"] = nil + cookie.Values["secret"] = nil cookie.Options.MaxAge = -1 // kill the cookie err = cookie.Save(req, rw) if err != nil { @@ -79,20 +107,43 @@ func (a *Authorizer) Logout(rw http.ResponseWriter, req *http.Request) error { } // CurrentUser returns the logged in username from session -func (a *Authorizer) CurrentUser(rw http.ResponseWriter, req *http.Request) (string, error) { +func (a *Authorizer) CurrentUser(rw http.ResponseWriter, req *http.Request) (User, error) { cookie, err := a.cookiejar.Get(req, a.cookieName) if err != nil { - return "", err + return nil, err + } + if cookie.IsNew { + return nil, nil } - username := cookie.Values["username"] - - if !cookie.IsNew && username != nil { - str, ok := username.(string) - if !ok { - return "", ErrCorrupted - } - return str, nil + usernameTmp := cookie.Values["username"] + if usernameTmp == nil { + return nil, nil } - return "", nil + username, ok := usernameTmp.(string) + if !ok { + return nil, ErrCorrupted + } + + u, err := a.backend.Get(username) + if err != nil { + return nil, err + } + + // Check secret + hash := u.GetHash() + secretTmp := cookie.Values["secret"] + if secretTmp == nil { + return nil, nil + } + secret, ok := secretTmp.(string) + if !ok { + return nil, ErrCorrupted + } + err = bcrypt.CompareHashAndPassword([]byte(secret), []byte(hash)) + if err != nil { + return nil, ErrInvalidSecret + } + + return u, nil } diff --git a/auth/auth_test.go b/auth/auth_test.go index 8f95d39..55e7c1c 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -23,9 +23,40 @@ const ( hash = "$2a$10$eVye8xbs6nj4TWnlTmifRuBsAU3F2hkxEcFz9WXdYjUuE6uKLVuzK" ) +type user struct { + username string + password string + hash string +} + +func (u *user) GetHash() string { + return u.hash +} + +type Backend struct { + user *user +} + +func (b *Backend) Get(username string) (User, error) { + if username == b.user.username { + return b.user, nil + } + return nil, fmt.Errorf("invalid username") +} + +func getBackend() *Backend { + return &Backend{ + user: &user{ + username: username, + password: password, + hash: hash, + }, + } +} + func login(w http.ResponseWriter, r *http.Request) { - a := New(peeper, cookie, key, cost) - err := a.Login(w, r, username, hash, password) + a := New(getBackend(), peeper, cookie, key, cost) + err := a.Login(w, r, username, password) if err != nil { fmt.Fprintf(w, "%s", err) w.WriteHeader(http.StatusInternalServerError) @@ -35,7 +66,7 @@ func login(w http.ResponseWriter, r *http.Request) { } func logout(w http.ResponseWriter, r *http.Request) { - a := New(peeper, cookie, key, cost) + a := New(getBackend(), peeper, cookie, key, cost) err := a.Logout(w, r) if err != nil { fmt.Fprintf(w, "%s", err) @@ -45,7 +76,7 @@ func logout(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } func check(w http.ResponseWriter, r *http.Request) { - a := New(peeper, cookie, key, cost) + a := New(getBackend(), peeper, cookie, key, cost) u, err := a.CurrentUser(w, r) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -53,7 +84,14 @@ func check(w http.ResponseWriter, r *http.Request) { return } w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, "%s", u) + if u != nil { + usr, ok := u.(*user) + if !ok { + fmt.Fprintf(w, "Invalid user type") + return + } + fmt.Fprintf(w, "%s", usr.username) + } }