package web import ( "context" "crypto/rand" "crypto/sha256" "encoding/hex" "encoding/json" "log" "net/http" "nfeeder/internal/db" "strconv" "time" "github.com/golang-jwt/jwt/v5" "github.com/jackc/pgx/v5/pgtype" "golang.org/x/crypto/bcrypt" ) const ISSUER = "nfeeder-app" type AuthResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` } type RefreshTokenReq struct { RefreshToken string `json:"refresh_token"` } func (s *Server) handleRegister() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { email := r.FormValue("email") password := r.FormValue("password") // hash password hashpw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { http.Error(w, "Internal Error", http.StatusInternalServerError) return } // create the user user, err := s.store.CreateUser(r.Context(), db.CreateUserParams{ Email: email, Password: string(hashpw), }) if err != nil { // Log the actual error for yourself, send a generic one to the user log.Printf("failed to create user: %v", err) http.Error(w, "Could not create user", http.StatusBadRequest) return } s.issueToken(w, r, user.ID, false) } } func (s *Server) handleLogin() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { email := r.FormValue("email") password := r.FormValue("password") user, err := s.store.GetUserByEmail(r.Context(), email) if err != nil { log.Printf("failed to get user: %v", err) http.Error(w, "Invalid credentials", http.StatusUnauthorized) return } // compare passwords err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) if err != nil { log.Printf("incorrect password entered: %v", err) http.Error(w, "Invalid credentials", http.StatusUnauthorized) return } s.issueToken(w, r, user.ID, true) } } func (s *Server) handleLogout() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req RefreshTokenReq err := json.NewDecoder(r.Body).Decode(&req) if err != nil { log.Printf("handleRefresh failed to parse request body: %v", err) http.Error(w, "Invalid JSON", http.StatusBadRequest) return } hash := sha256.Sum256([]byte(req.RefreshToken)) hex_token := hex.EncodeToString(hash[:]) tokenRecord, err := s.store.GetValidRefreshToken(r.Context(), hex_token) if err != nil { log.Printf("Logout succes, warning: token not found or expired: %v", err) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]string{"message": "logout success"}) return } err = s.store.DeleteRefreshToken(r.Context(), db.DeleteRefreshTokenParams{ TokenHash: hex_token, UserID: tokenRecord.UserID, }) if err != nil { log.Printf("Failed to delete old refresh token: %v", err) } w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]string{"message": "logout success"}) } } func (s *Server) handleRefresh() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req RefreshTokenReq err := json.NewDecoder(r.Body).Decode(&req) if err != nil { log.Printf("handleRefresh failed to parse request body: %v", err) http.Error(w, "Invalid JSON", http.StatusBadRequest) return } hash := sha256.Sum256([]byte(req.RefreshToken)) hex_token := hex.EncodeToString(hash[:]) tokenRecord, err := s.store.GetValidRefreshToken(r.Context(), hex_token) if err != nil { log.Printf("Refresh failed - token not found or expired: %v", err) http.Error(w, "Invalid or expired session", http.StatusUnauthorized) return } err = s.store.DeleteRefreshToken(r.Context(), db.DeleteRefreshTokenParams{ TokenHash: hex_token, UserID: tokenRecord.UserID, }) if err != nil { log.Printf("Failed to delete old refresh token: %v", err) } s.issueToken(w, r, tokenRecord.UserID, true) } } func (s *Server) issueToken(w http.ResponseWriter, r *http.Request, userID int64, checkLimit bool) { nowTime := time.Now() jwtExpireTime := nowTime.Add(24 * time.Hour) claims := jwt.RegisteredClaims{ Issuer: ISSUER, Subject: strconv.FormatInt(userID, 10), ExpiresAt: jwt.NewNumericDate(jwtExpireTime), IssuedAt: jwt.NewNumericDate(nowTime), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString(s.jwtSecret) if err != nil { http.Error(w, "Internal Error", http.StatusInternalServerError) return } refreshToken := s.issueRefreshToken(r.Context(), userID, nowTime, checkLimit) if refreshToken == "" { http.Error(w, "Internal Error", http.StatusInternalServerError) return } resp := AuthResponse{ AccessToken: tokenString, RefreshToken: refreshToken, } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) if err := json.NewEncoder(w).Encode(resp); err != nil { log.Printf("json encoding failed: %v", err) return } } func (s *Server) issueRefreshToken(ctx context.Context, userID int64, issuedTime time.Time, checkLimit bool) string { refreshExpireTime := issuedTime.Add((24 * time.Hour) * 3) // Generate refresh token rawToken := rand.Text() hash := sha256.Sum256([]byte(rawToken)) hex_token := hex.EncodeToString(hash[:]) // Check the user has not got the session limit. if checkLimit { sessionCount, err := s.store.CountUserRefreshTokens(ctx, userID) if err != nil { log.Printf("failed to count refresh token: %v", err) return "" } // Each user can only have 3 sessions, if they have 3 refresh tokens // remove the oldest one and then create a new refresh token if sessionCount >= int64(MAX_USER_SESSIONS) { err = s.store.DeleteOldestRefreshToken(ctx, userID) if err != nil { log.Printf("failed to delete oldest refresh token: %v", err) return "" } } } _, err := s.store.CreateRefreshToken(ctx, db.CreateRefreshTokenParams{ UserID: userID, TokenHash: hex_token, ExpiresAt: pgtype.Timestamptz{Time: refreshExpireTime, Valid: true}, }) if err != nil { // Log the actual error for yourself, send a generic one to the user log.Printf("failed to create refresh token: %v", err) return "" } return rawToken }