199 lines
4.6 KiB
Go
199 lines
4.6 KiB
Go
package auth
|
|
|
|
import (
	"context"
	"errors"
	"net"
	"net/http"
	"strings"

	"git.quimbo.fr/odwrtw/canape/backend/models"
	"github.com/gofrs/uuid"
	"github.com/sirupsen/logrus"
)
|
|
|
|
// Context key types for the values this package stores in a request context.
// Each value gets its own unexported type, so a key from this package can never
// be equal to a key from another package, even if the underlying strings match.
type (
	// ipContextKey keys the client IP stored by IPMiddleware.
	ipContextKey string
	// requestIDContextKey keys the request ID stored by RequestIDMiddleware.
	requestIDContextKey string
	// authContextKey keys the *models.User stored by Middleware.
	authContextKey string
)
|
|
|
|
// Middleware get User from session and put it in context
|
|
type Middleware struct {
|
|
authorizer *Authorizer
|
|
log *logrus.Entry
|
|
}
|
|
|
|
// NewMiddleware returns a new authentication middleware
|
|
func NewMiddleware(authorizer *Authorizer, log *logrus.Entry) *Middleware {
|
|
return &Middleware{
|
|
authorizer: authorizer,
|
|
log: log.WithField("middleware", "auth"),
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
|
user, err := m.authorizer.CurrentUser(w, r)
|
|
switch err {
|
|
case nil:
|
|
m.log.Debugf("setting user %s in the context", user.Name)
|
|
ctxKey := authContextKey("auth.user")
|
|
ctx := context.WithValue(r.Context(), ctxKey, user)
|
|
r = r.WithContext(ctx)
|
|
case ErrUnauthenticatedUser:
|
|
m.log.Debugf("unauthenticated user")
|
|
case ErrInvalidToken:
|
|
m.log.Debugf("user has an invalid token")
|
|
default:
|
|
m.log.Error(err)
|
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
next(w, r)
|
|
}
|
|
|
|
// MiddlewareRole handles the role checking for the current user
|
|
type MiddlewareRole struct {
|
|
authorizer *Authorizer
|
|
log *logrus.Entry
|
|
role string
|
|
}
|
|
|
|
// NewMiddlewareRole returns a new MiddlewareRole
|
|
func NewMiddlewareRole(authorizer *Authorizer, log *logrus.Entry, role string) *MiddlewareRole {
|
|
return &MiddlewareRole{
|
|
authorizer: authorizer,
|
|
log: log.WithField("middleware", "role"),
|
|
role: role,
|
|
}
|
|
}
|
|
|
|
func (m *MiddlewareRole) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
|
user := GetCurrentUser(r, m.log)
|
|
|
|
if user == nil || !user.HasRole(m.role) {
|
|
if user == nil {
|
|
m.log.Debug("user is nil in the context")
|
|
} else {
|
|
m.log.Debug("user doesn't have the role")
|
|
}
|
|
|
|
// return unauthorized
|
|
http.Error(w, "Invalid user role", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
if !user.IsActivated() {
|
|
// return unauthorized
|
|
http.Error(w, "User is not activated", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
m.log.Debug("user has the role and is activated, continuing")
|
|
|
|
next(w, r)
|
|
}
|
|
|
|
// GetCurrentUser gets the current user from the request context
|
|
func GetCurrentUser(r *http.Request, log *logrus.Entry) *models.User {
|
|
log.Debug("getting user from context")
|
|
|
|
ctxKey := authContextKey("auth.user")
|
|
u := r.Context().Value(ctxKey)
|
|
if u == nil {
|
|
return nil
|
|
}
|
|
user, ok := u.(*models.User)
|
|
if !ok {
|
|
panic("invalid user type")
|
|
}
|
|
return user
|
|
}
|
|
|
|
// IPMiddleware set the IP in the request context
|
|
type IPMiddleware struct {
|
|
log *logrus.Entry
|
|
}
|
|
|
|
// NewIPMiddleware returns a new ip middleware
|
|
func NewIPMiddleware(log *logrus.Entry) *IPMiddleware {
|
|
return &IPMiddleware{
|
|
log: log.WithField("middleware", "ip"),
|
|
}
|
|
}
|
|
|
|
func (m *IPMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
|
ip := getRequestIP(r)
|
|
|
|
ctxKey := ipContextKey("ip")
|
|
ctx := context.WithValue(r.Context(), ctxKey, ip)
|
|
r = r.WithContext(ctx)
|
|
|
|
next(w, r)
|
|
}
|
|
|
|
// getRequestIP returns a best-guess client IP address for req.
//
// Sources are checked in this order:
//   - the X-Real-IP header, returned as-is;
//   - the X-Forwarded-For header, from which only the first (left-most) entry
//     of the comma-separated list is returned, with whitespace trimmed;
//   - the host part of req.RemoteAddr, normalised through net.ParseIP.
//
// It returns "0.0.0.0" when RemoteAddr cannot be split into host and port, or
// when its host is not a literal IP address.
//
// NOTE(review): both headers are client-controlled unless a trusted reverse
// proxy overwrites them. Confirm the deployment guarantees this before relying
// on the result for anything security-sensitive.
func getRequestIP(req *http.Request) string {
	if ip := req.Header.Get("X-Real-IP"); ip != "" {
		return ip
	}

	// X-Forwarded-For may be "client, proxy1, proxy2"; only the first entry
	// identifies the original client.
	if fwd := req.Header.Get("X-Forwarded-For"); fwd != "" {
		if i := strings.IndexByte(fwd, ','); i >= 0 {
			fwd = fwd[:i]
		}
		if ip := strings.TrimSpace(fwd); ip != "" {
			return ip
		}
	}

	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		// fake result
		return "0.0.0.0"
	}

	// ParseIP returns nil for anything that is not a literal IP. Check it
	// before calling String(), which would otherwise yield "<nil>".
	hostIP := net.ParseIP(host)
	if hostIP == nil {
		return "0.0.0.0"
	}

	// Default to the IP from the request
	return hostIP.String()
}
|
|
|
|
func getIPFromRequest(r *http.Request) string {
|
|
return r.Context().Value(ipContextKey("ip")).(string)
|
|
}
|
|
|
|
// RequestIDMiddleware set the request ID in the request context
|
|
type RequestIDMiddleware struct {
|
|
log *logrus.Entry
|
|
}
|
|
|
|
//TODO: Move this somewhere else
|
|
|
|
// NewRequestIDMiddleware returns a new requestID middleware
|
|
func NewRequestIDMiddleware(log *logrus.Entry) *RequestIDMiddleware {
|
|
return &RequestIDMiddleware{
|
|
log: log.WithField("middleware", "request_id"),
|
|
}
|
|
}
|
|
|
|
func (m *RequestIDMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
|
ctxKey := requestIDContextKey("request_id")
|
|
reqID, err := uuid.NewV4()
|
|
if err == nil {
|
|
ctx := context.WithValue(r.Context(), ctxKey, reqID.String())
|
|
r = r.WithContext(ctx)
|
|
r.Header.Set("X-Request-Id", reqID.String())
|
|
w.Header().Set("X-Request-Id", reqID.String())
|
|
}
|
|
|
|
next(w, r)
|
|
}
|
|
|
|
// GetRequestIDFromRequest returns the request id from the request
|
|
func GetRequestIDFromRequest(r *http.Request) string {
|
|
req := r.Context().Value(requestIDContextKey("request_id"))
|
|
if req != nil {
|
|
return req.(string)
|
|
}
|
|
return ""
|
|
}
|