diff --git a/internal/web/handlers_auth.go b/internal/web/handlers_auth.go index a74509d..022c81c 100644 --- a/internal/web/handlers_auth.go +++ b/internal/web/handlers_auth.go @@ -1,6 +1,7 @@ package web import ( + "context" "crypto/rand" "crypto/sha256" "encoding/hex" @@ -23,6 +24,10 @@ type AuthResponse struct { RefreshToken string `json:"refresh_token"` } +type RefreshTokenReq struct { + RefreshToken string `json:"refresh_token"` +} + func (s *Server) handleRegister() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { email := r.FormValue("email") @@ -47,8 +52,7 @@ func (s *Server) handleRegister() http.HandlerFunc { return } - // Auto-login after reg - s.issueToken(w, r, user.ID) + s.issueToken(w, r, user.ID, false) } } @@ -59,6 +63,7 @@ func (s *Server) handleLogin() http.HandlerFunc { user, err := s.store.GetUserByEmail(r.Context(), email) if err != nil { + log.Printf("failed to get user: %v", err) http.Error(w, "Invalid credentials", http.StatusUnauthorized) return } @@ -66,48 +71,86 @@ func (s *Server) handleLogin() http.HandlerFunc { // compare passwords err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) if err != nil { + log.Printf("incorrect password entered: %v", err) http.Error(w, "Invalid credentials", http.StatusUnauthorized) return } - s.issueToken(w, r, user.ID) + s.issueToken(w, r, user.ID, true) } } func (s *Server) handleLogout() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - // To logout, we simply instruct the browser to delete the cookie - http.SetCookie(w, &http.Cookie{ - Name: "nfeeder_token", - Value: "", - Path: "/", - Expires: time.Unix(0, 0), // Expire immediately - HttpOnly: true, + var req RefreshTokenReq + + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + log.Printf("handleRefresh failed to parse request body: %v", err) + http.Error(w, "Invalid JSON", http.StatusBadRequest) + return + } + + hash := sha256.Sum256([]byte(req.RefreshToken)) + hex_token := hex.EncodeToString(hash[:]) + + tokenRecord, err := s.store.GetValidRefreshToken(r.Context(), hex_token) + if err != nil { + log.Printf("Logout succes, warning: token not found or expired: %v", err) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"message": "logout success"}) + return + } + + err = s.store.DeleteRefreshToken(r.Context(), db.DeleteRefreshTokenParams{ + TokenHash: hex_token, + UserID: tokenRecord.UserID, }) - http.Redirect(w, r, "/login", http.StatusSeeOther) + if err != nil { + log.Printf("Failed to delete old refresh token: %v", err) + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"message": "logout success"}) } } func (s *Server) handleRefresh() http.HandlerFunc { - // TODO: refresh logic - /* - Read the refresh token cookie - Hash it with sha256 - Look up the hash in the DB - Check it's not expired - Delete the old DB record - Call issueTokenAndRedirect — which creates a new JWT and a new refresh token in one go - */ - return func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "/", http.StatusOK) + var req RefreshTokenReq + + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + log.Printf("handleRefresh failed to parse request body: %v", err) + http.Error(w, "Invalid JSON", http.StatusBadRequest) + return + } + + hash := sha256.Sum256([]byte(req.RefreshToken)) + hex_token := hex.EncodeToString(hash[:]) + + tokenRecord, err := s.store.GetValidRefreshToken(r.Context(), hex_token) + if err != nil { + log.Printf("Refresh failed - token not found or expired: %v", err) + http.Error(w, "Invalid or expired session", http.StatusUnauthorized) + return + } + + err = s.store.DeleteRefreshToken(r.Context(), db.DeleteRefreshTokenParams{ + TokenHash: hex_token, + UserID: tokenRecord.UserID, + }) + if err != nil { + log.Printf("Failed to delete old refresh token: %v", err) + } + + s.issueToken(w, r, tokenRecord.UserID, true) } } -func (s *Server) issueToken(w http.ResponseWriter, r *http.Request, userID int64) { +func (s *Server) issueToken(w http.ResponseWriter, r *http.Request, userID int64, checkLimit bool) { nowTime := time.Now() jwtExpireTime := nowTime.Add(24 * time.Hour) - refreshExpireTime := nowTime.Add((24 * time.Hour) * 3) claims := jwt.RegisteredClaims{ Issuer: ISSUER, @@ -123,26 +166,15 @@ func (s *Server) issueToken(w http.ResponseWriter, r *http.Request, userID int64 return } - // Generate refresh token - rawToken := rand.Text() - hash := sha256.Sum256([]byte(rawToken)) - hex_token := hex.EncodeToString(hash[:]) - - _, err = s.store.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{ - UserID: userID, - TokenHash: hex_token, - ExpiresAt: pgtype.Timestamptz{Time: refreshExpireTime, Valid: true}, - }) - if err != nil { - // Log the actual error for yourself, send a generic one to the user - log.Printf("failed to create refresh token: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) + refreshToken := s.issueRefreshToken(r.Context(), userID, nowTime, checkLimit) + if refreshToken == "" { + http.Error(w, "Internal Error", http.StatusInternalServerError) return } resp := AuthResponse{ AccessToken: tokenString, - RefreshToken: rawToken, + RefreshToken: refreshToken, } w.Header().Set("Content-Type", "application/json") @@ -152,3 +184,44 @@ func (s *Server) issueToken(w http.ResponseWriter, r *http.Request, userID int64 return } } + +func (s *Server) issueRefreshToken(ctx context.Context, userID int64, issuedTime time.Time, checkLimit bool) string { + refreshExpireTime := issuedTime.Add((24 * time.Hour) * 3) + + // Generate refresh token + rawToken := rand.Text() + hash := sha256.Sum256([]byte(rawToken)) + hex_token := hex.EncodeToString(hash[:]) + + // Check the user has not got the session limit. + if checkLimit { + sessionCount, err := s.store.CountUserRefreshTokens(ctx, userID) + if err != nil { + log.Printf("failed to count refresh token: %v", err) + return "" + } + + // Each user can only have 3 sessions, if they have 3 refresh tokens + // remove the oldest one and then create a new refresh token + if sessionCount >= int64(MAX_USER_SESSIONS) { + err = s.store.DeleteOldestRefreshToken(ctx, userID) + if err != nil { + log.Printf("failed to delete oldest refresh token: %v", err) + return "" + } + } + } + + _, err := s.store.CreateRefreshToken(ctx, db.CreateRefreshTokenParams{ + UserID: userID, + TokenHash: hex_token, + ExpiresAt: pgtype.Timestamptz{Time: refreshExpireTime, Valid: true}, + }) + if err != nil { + // Log the actual error for yourself, send a generic one to the user + log.Printf("failed to create refresh token: %v", err) + return "" + } + + return rawToken +} diff --git a/internal/web/middleware.go b/internal/web/middleware.go index 893da8e..ba7d1a6 100644 --- a/internal/web/middleware.go +++ b/internal/web/middleware.go @@ -3,6 +3,7 @@ package web import ( "context" "encoding/json" + "errors" "fmt" "net/http" @@ -55,7 +56,17 @@ func (s *Server) hasAuth(next http.Handler) http.Handler { if err != nil || !token.Valid { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) - json.NewEncoder(w).Encode(map[string]string{"error": "invalid or expired token"}) + + // Check if the error is specifically because the token expired + if errors.Is(err, jwt.ErrTokenExpired) { + json.NewEncoder(w).Encode(map[string]string{ + "error": "token_expired", + "message": "Please use your refresh token to get a new session", + }) + return + } + + json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"}) return }