95 lines
1.9 KiB
Go
95 lines
1.9 KiB
Go
package auth
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"github.com/gorilla/context"
|
|
)
|
|
|
|
// key is an unexported type for values this package stores in the request
// context, so its keys cannot collide with keys defined by other packages.
type key int

// ukey is the context key under which the current user is stored.
const ukey key = iota
|
|
|
|
// Middleware get User from session and put it in context
|
|
type Middleware struct {
|
|
authorizer *Authorizer
|
|
}
|
|
|
|
func NewMiddleware(authorizer *Authorizer) *Middleware {
|
|
return &Middleware{authorizer}
|
|
}
|
|
|
|
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
|
user, err := m.authorizer.CurrentUser(w, r)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
context.Set(r, ukey, user)
|
|
next(w, r)
|
|
}
|
|
|
|
type MiddlewareRole struct {
|
|
authorizer *Authorizer
|
|
role string
|
|
loginPageGetter func() string
|
|
}
|
|
|
|
func NewMiddlewareRole(authorizer *Authorizer, loginPageGetter func() string, role string) *MiddlewareRole {
|
|
return &MiddlewareRole{authorizer, role, loginPageGetter}
|
|
}
|
|
|
|
func (m *MiddlewareRole) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
|
user := GetCurrentUser(r)
|
|
|
|
if user == nil || !user.HasRole(m.role) {
|
|
cookie, err := m.authorizer.Cookiejar.Get(r, "rlogin")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
cookie.Values["redirect"] = r.URL.Path
|
|
err = cookie.Save(r, w)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
http.Redirect(w, r, m.loginPageGetter(), http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
|
|
next(w, r)
|
|
}
|
|
|
|
func GetPostLoginRedirect(a *Authorizer, w http.ResponseWriter, r *http.Request) (string, error) {
|
|
cookie, err := a.Cookiejar.Get(r, "rlogin")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
val := cookie.Values["redirect"]
|
|
if val == nil {
|
|
return "", nil
|
|
}
|
|
path, ok := val.(string)
|
|
if !ok {
|
|
return "", fmt.Errorf("invalid redirect type")
|
|
}
|
|
cookie.Values["rlogin"] = ""
|
|
err = cookie.Save(r, w)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return path, nil
|
|
|
|
}
|
|
|
|
func GetCurrentUser(r *http.Request) User {
|
|
u := context.Get(r, ukey)
|
|
if u == nil {
|
|
return nil
|
|
}
|
|
user, ok := u.(User)
|
|
if !ok {
|
|
panic("invalid user type")
|
|
}
|
|
return user
|
|
}
|