diff --git a/internal/web/handlers_auth.go b/internal/web/handlers_auth.go index 022c81c..d1ae4b3 100644 --- a/internal/web/handlers_auth.go +++ b/internal/web/handlers_auth.go @@ -17,26 +17,20 @@ import ( "golang.org/x/crypto/bcrypt" ) -const ISSUER = "nfeeder-app" - -type AuthResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` -} - -type RefreshTokenReq struct { - RefreshToken string `json:"refresh_token"` -} - func (s *Server) handleRegister() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { email := r.FormValue("email") password := r.FormValue("password") + if email == "" || password == "" { + WriteJSON(w, http.StatusBadRequest, ErrorResponse{Error: "Email and password required"}) + return + } + // hash password hashpw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - http.Error(w, "Internal Error", http.StatusInternalServerError) + WriteJSON(w, http.StatusInternalServerError, ErrorResponse{Error: "Internal Error"}) return } @@ -47,8 +41,8 @@ func (s *Server) handleRegister() http.HandlerFunc { }) if err != nil { // Log the actual error for yourself, send a generic one to the user - log.Printf("failed to create user: %v", err) - http.Error(w, "Could not create user", http.StatusBadRequest) + log.Printf("Failed to create user: %v", err) + WriteJSON(w, http.StatusBadRequest, ErrorResponse{Error: "Failed to create user"}) return } @@ -64,7 +58,7 @@ func (s *Server) handleLogin() http.HandlerFunc { user, err := s.store.GetUserByEmail(r.Context(), email) if err != nil { log.Printf("failed to get user: %v", err) - http.Error(w, "Invalid credentials", http.StatusUnauthorized) + WriteJSON(w, http.StatusUnauthorized, ErrorResponse{Error: "Invalid credentials"}) return } @@ -72,7 +66,7 @@ func (s *Server) handleLogin() http.HandlerFunc { err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) if err != nil { log.Printf("incorrect password entered: %v", err) - http.Error(w, "Invalid credentials", http.StatusUnauthorized) + WriteJSON(w, http.StatusUnauthorized, ErrorResponse{Error: "Invalid credentials"}) return } @@ -87,7 +81,7 @@ func (s *Server) handleLogout() http.HandlerFunc { err := json.NewDecoder(r.Body).Decode(&req) if err != nil { log.Printf("handleRefresh failed to parse request body: %v", err) - http.Error(w, "Invalid JSON", http.StatusBadRequest) + WriteJSON(w, http.StatusBadRequest, ErrorResponse{Error: "Invalid JSON"}) return } @@ -97,8 +91,7 @@ func (s *Server) handleLogout() http.HandlerFunc { tokenRecord, err := s.store.GetValidRefreshToken(r.Context(), hex_token) if err != nil { log.Printf("Logout succes, warning: token not found or expired: %v", err) - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{"message": "logout success"}) + WriteJSON(w, http.StatusOK, SuccessResponse{Message: "logout success"}) return } @@ -110,8 +103,9 @@ func (s *Server) handleLogout() http.HandlerFunc { log.Printf("Failed to delete old refresh token: %v", err) } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{"message": "logout success"}) + WriteJSON(w, http.StatusOK, SuccessResponse{ + Message: "logout success", + }) } } @@ -122,7 +116,7 @@ func (s *Server) handleRefresh() http.HandlerFunc { err := json.NewDecoder(r.Body).Decode(&req) if err != nil { log.Printf("handleRefresh failed to parse request body: %v", err) - http.Error(w, "Invalid JSON", http.StatusBadRequest) + WriteJSON(w, http.StatusBadRequest, ErrorResponse{Error: "Invalid JSON"}) return } @@ -132,7 +126,7 @@ func (s *Server) handleRefresh() http.HandlerFunc { tokenRecord, err := s.store.GetValidRefreshToken(r.Context(), hex_token) if err != nil { log.Printf("Refresh failed - token not found or expired: %v", err) - http.Error(w, "Invalid or expired session", http.StatusUnauthorized) + WriteJSON(w, http.StatusUnauthorized, ErrorResponse{Error: "Invalid or expired session"}) return } @@ -144,7 +138,7 @@ func (s *Server) handleRefresh() http.HandlerFunc { log.Printf("Failed to delete old refresh token: %v", err) } - s.issueToken(w, r, tokenRecord.UserID, true) + s.issueToken(w, r, tokenRecord.UserID, false) } } @@ -162,27 +156,22 @@ func (s *Server) issueToken(w http.ResponseWriter, r *http.Request, userID int64 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString(s.jwtSecret) if err != nil { - http.Error(w, "Internal Error", http.StatusInternalServerError) + WriteJSON(w, http.StatusInternalServerError, ErrorResponse{ + Error: "Internal Server Error", + }) return } refreshToken := s.issueRefreshToken(r.Context(), userID, nowTime, checkLimit) if refreshToken == "" { - http.Error(w, "Internal Error", http.StatusInternalServerError) + WriteJSON(w, http.StatusInternalServerError, ErrorResponse{Error: "Internal Server Error"}) return } - resp := AuthResponse{ + WriteJSON(w, http.StatusOK, AuthResponse{ AccessToken: tokenString, RefreshToken: refreshToken, - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(resp); err != nil { - log.Printf("json encoding failed: %v", err) - return - } + }) } func (s *Server) issueRefreshToken(ctx context.Context, userID int64, issuedTime time.Time, checkLimit bool) string { diff --git a/internal/web/middleware.go b/internal/web/middleware.go index ba7d1a6..4df797d 100644 --- a/internal/web/middleware.go +++ b/internal/web/middleware.go @@ -2,9 +2,9 @@ package web import ( "context" - "encoding/json" "errors" "fmt" + "log" "net/http" "github.com/golang-jwt/jwt/v5" @@ -32,9 +32,9 @@ func (s *Server) hasAuth(next http.Handler) http.Handler { } if tokenString == "" { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"}) + WriteJSON(w, http.StatusUnauthorized, ErrorResponse{ + Error: "unauthorized", + }) return } @@ -54,27 +54,26 @@ func (s *Server) hasAuth(next http.Handler) http.Handler { /// Not Before Check: It checks if the nbf (Not Before) time has passed. /// Issued At Check: It ensures the iat isn't in the future. if err != nil || !token.Valid { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - // Check if the error is specifically because the token expired if errors.Is(err, jwt.ErrTokenExpired) { - json.NewEncoder(w).Encode(map[string]string{ - "error": "token_expired", - "message": "Please use your refresh token to get a new session", + WriteJSON(w, http.StatusUnauthorized, ErrorResponse{ + Error: "Token expired: Please use your refresh token to get a new session", }) return } - json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"}) + WriteJSON(w, http.StatusUnauthorized, ErrorResponse{ + Error: "unauthorized", + }) return } // Verify issuer if claims.Issuer != ISSUER { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - json.NewEncoder(w).Encode(map[string]string{"error": "invalid issuer"}) + log.Printf("Invalid Token, issuer incorrect or tampered with") + WriteJSON(w, http.StatusUnauthorized, ErrorResponse{ + Error: "Invalid Token", + }) return } diff --git a/internal/web/responses.go b/internal/web/responses.go new file mode 100644 index 0000000..1fa9c92 --- /dev/null +++ b/internal/web/responses.go @@ -0,0 +1,18 @@ +package web + +type ErrorResponse struct { + Error string `json:"error"` +} + +type SuccessResponse struct { + Message string `json:"message"` +} + +type AuthResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +type RefreshTokenReq struct { + RefreshToken string `json:"refresh_token"` +} diff --git a/internal/web/utils.go b/internal/web/utils.go index e0646a0..f61ec4a 100644 --- a/internal/web/utils.go +++ b/internal/web/utils.go @@ -2,9 +2,25 @@ package web import ( "context" + "encoding/json" + "log" + "net/http" "strconv" ) +// Server will respond with an expected struct of information. +// Otherwise will return the correct error status code. +// Client responsibility to check status then unmarshal to +// error response or data +func WriteJSON[T any](w http.ResponseWriter, status int, data T) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + + if err := json.NewEncoder(w).Encode(data); err != nil { + log.Printf("JSON encode error %v", err) + } +} + func userIDFromContext(ctx context.Context) (int64, bool) { val := ctx.Value(userIDKey) if val == nil {