canape/backend/auth/middleware.go

199 lines
4.6 KiB
Go

package auth
import (
	"context"
	"errors"
	"net"
	"net/http"
	"strings"

	"git.quimbo.fr/odwrtw/canape/backend/models"
	"github.com/gofrs/uuid"
	"github.com/sirupsen/logrus"
)
// Unexported key types for the values this package stores in a request
// context. Because each key has its own named type, it cannot collide with a
// key set by another package even if the underlying strings are equal.
type (
	// ipContextKey indexes the client IP stored by IPMiddleware.
	ipContextKey string
	// requestIDContextKey indexes the ID stored by RequestIDMiddleware.
	requestIDContextKey string
	// authContextKey indexes the user stored by Middleware.
	authContextKey string
)
// Middleware resolves the current user through an Authorizer and, on
// success, stores it in the request context for later handlers.
type Middleware struct {
	authorizer *Authorizer
	log        *logrus.Entry
}

// NewMiddleware builds an authentication Middleware backed by authorizer.
// Its logger is derived from log with the field middleware=auth.
func NewMiddleware(authorizer *Authorizer, log *logrus.Entry) *Middleware {
	m := new(Middleware)
	m.authorizer = authorizer
	m.log = log.WithField("middleware", "auth")
	return m
}
// ServeHTTP asks the authorizer for the current user. On success it stores
// the user in the request context under authContextKey("auth.user"), where
// GetCurrentUser reads it, and then calls next.
//
// Unauthenticated requests and requests with an invalid token are not
// rejected here. They continue to next with no user in the context, and a
// later check (e.g. MiddlewareRole) decides whether that is acceptable. Any
// other authorizer error is logged and answered with 401 Unauthorized.
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	user, err := m.authorizer.CurrentUser(w, r)
	switch {
	case err == nil:
		m.log.Debugf("setting user %s in the context", user.Name)
		ctx := context.WithValue(r.Context(), authContextKey("auth.user"), user)
		r = r.WithContext(ctx)
	// errors.Is (rather than ==) still matches the sentinels when the
	// authorizer returns them wrapped with extra context.
	case errors.Is(err, ErrUnauthenticatedUser):
		m.log.Debug("unauthenticated user")
	case errors.Is(err, ErrInvalidToken):
		m.log.Debug("user has an invalid token")
	default:
		m.log.Error(err)
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}
	next(w, r)
}
// MiddlewareRole only lets a request through when the context holds a user
// that has the configured role and is activated.
type MiddlewareRole struct {
	authorizer *Authorizer
	log        *logrus.Entry
	role       string
}

// NewMiddlewareRole returns a MiddlewareRole that requires role. Its logger
// is derived from log with the field middleware=role.
func NewMiddlewareRole(authorizer *Authorizer, log *logrus.Entry, role string) *MiddlewareRole {
	m := &MiddlewareRole{authorizer: authorizer, role: role}
	m.log = log.WithField("middleware", "role")
	return m
}
// ServeHTTP takes the user from the request context and calls next only when
// that user exists, has m.role, and is activated. Otherwise it answers with
// 401 Unauthorized.
func (m *MiddlewareRole) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	current := GetCurrentUser(r, m.log)

	switch {
	case current == nil:
		m.log.Debug("user is nil in the context")
		http.Error(w, "Invalid user role", http.StatusUnauthorized)
		return
	case !current.HasRole(m.role):
		m.log.Debug("user doesn't have the role")
		http.Error(w, "Invalid user role", http.StatusUnauthorized)
		return
	case !current.IsActivated():
		http.Error(w, "User is not activated", http.StatusUnauthorized)
		return
	}

	m.log.Debug("user has the role and is activated, continuing")
	next(w, r)
}
// GetCurrentUser returns the user stored in r's context under
// authContextKey("auth.user"), or nil when no user is stored there.
// It panics if the stored value is not a *models.User.
func GetCurrentUser(r *http.Request, log *logrus.Entry) *models.User {
	log.Debug("getting user from context")

	stored := r.Context().Value(authContextKey("auth.user"))
	if stored == nil {
		return nil
	}

	if user, ok := stored.(*models.User); ok {
		return user
	}
	panic("invalid user type")
}
// IPMiddleware stores the client IP (see getRequestIP) in the request
// context under ipContextKey("ip").
type IPMiddleware struct {
	log *logrus.Entry
}

// NewIPMiddleware returns an IPMiddleware whose logger is derived from log
// with the field middleware=ip.
func NewIPMiddleware(log *logrus.Entry) *IPMiddleware {
	entry := log.WithField("middleware", "ip")
	return &IPMiddleware{log: entry}
}
// ServeHTTP attaches the client IP to the request context and calls next
// with the derived request.
func (m *IPMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	withIP := context.WithValue(r.Context(), ipContextKey("ip"), getRequestIP(r))
	next(w, r.WithContext(withIP))
}
// getRequestIP returns the client IP address for req. It checks, in order:
//
//  1. the X-Real-IP header, if non-empty;
//  2. the first (left-most) entry of the X-Forwarded-For header, if non-empty;
//  3. the host part of req.RemoteAddr, if it parses as an IP.
//
// X-Forwarded-For carries a comma-separated chain ("client, proxy1, proxy2"),
// and only the first entry names the originating client. If no usable
// address is found, "0.0.0.0" is returned.
//
// NOTE(review): both headers are client-controlled unless a trusted reverse
// proxy overwrites them. Values taken from them are not validated here.
func getRequestIP(req *http.Request) string {
	const unknownIP = "0.0.0.0"

	if ip := strings.TrimSpace(req.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	if fwd := req.Header.Get("X-Forwarded-For"); fwd != "" {
		if i := strings.IndexByte(fwd, ','); i >= 0 {
			fwd = fwd[:i]
		}
		if ip := strings.TrimSpace(fwd); ip != "" {
			return ip
		}
	}

	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return unknownIP
	}
	// ParseIP returns nil for anything that is not an IP literal. Checking
	// that nil avoids returning the string "<nil>".
	hostIP := net.ParseIP(host)
	if hostIP == nil {
		return unknownIP
	}
	return hostIP.String()
}
// getIPFromRequest returns the client IP that IPMiddleware stored in r's
// context. It returns "" when no IP string is present, for example when the
// handler was reached without going through IPMiddleware. The comma-ok
// assertion avoids a panic in that case.
func getIPFromRequest(r *http.Request) string {
	ip, _ := r.Context().Value(ipContextKey("ip")).(string)
	return ip
}
// RequestIDMiddleware tags each request with a generated request ID, stored
// in the request context under requestIDContextKey("request_id").
//
// TODO: Move this somewhere else
type RequestIDMiddleware struct {
	log *logrus.Entry
}

// NewRequestIDMiddleware returns a RequestIDMiddleware whose logger is
// derived from log with the field middleware=request_id.
func NewRequestIDMiddleware(log *logrus.Entry) *RequestIDMiddleware {
	entry := log.WithField("middleware", "request_id")
	return &RequestIDMiddleware{log: entry}
}
// ServeHTTP generates a version-4 UUID for the request and then:
//   - stores it in the request context under requestIDContextKey("request_id");
//   - sets it as the X-Request-Id header on both the request and the response;
//   - calls next.
//
// If the UUID cannot be generated, the error is logged and the request
// proceeds without an ID instead of being rejected.
func (m *RequestIDMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	reqID, err := uuid.NewV4()
	if err != nil {
		// Best effort: a missing request ID must not block the request, but
		// the failure should not go unnoticed either.
		m.log.WithError(err).Error("could not generate a request id")
		next(w, r)
		return
	}

	id := reqID.String()
	ctx := context.WithValue(r.Context(), requestIDContextKey("request_id"), id)
	r = r.WithContext(ctx)
	r.Header.Set("X-Request-Id", id)
	w.Header().Set("X-Request-Id", id)
	next(w, r)
}
// GetRequestIDFromRequest returns the request ID that RequestIDMiddleware
// stored in r's context, or "" when none is present.
func GetRequestIDFromRequest(r *http.Request) string {
	if id := r.Context().Value(requestIDContextKey("request_id")); id != nil {
		return id.(string)
	}
	return ""
}