• golang 实现带令牌限流的JWT demo


    demo里提供了三个接口,认证取token,刷新token,获取信息,token过期前也会在header里写上新token(便于客户端更换)

    1. package main
    2. import (
    3. "fmt"
    4. "net/http"
    5. "sync"
    6. "time"
    7. "github.com/gin-gonic/gin"
    8. "github.com/golang-jwt/jwt/v5"
    9. )
    10. const (
    11. TOKEN_SECRET_KEY = "secret" // 密钥
    12. TOKEN_EXPIRE_TIME = 2 * time.Hour // 2小时过期
    13. TOKEN_REFRESH_TIME = 10 * time.Minute // 接近过期时会在header里面加上新token,客户端可以识别也可以自行拉取新
    14. )
    15. var tb *TokenBucket
    16. func init() {
    17. tb = NewTokenBucket(100, 5000)
    18. }
    19. func main() {
    20. // 创建一个Gin引擎
    21. r := gin.Default()
    22. // 限流
    23. r.Use(rateMiddWare(tb))
    24. // 登录接口
    25. r.POST("/login", loginHandler)
    26. // 受保护接口
    27. v1 := r.Group("/v1").Use(authReqMiddWare())
    28. {
    29. v1.GET("/user", userHandler)
    30. v1.GET("/refresh-token", refreshTokenHandler)
    31. }
    32. // 监听并在8080端口上启动服务
    33. r.Run(":8080")
    34. }
    35. /**
    36. * 登陆
    37. * curl 127.0.0.1:8080/login -X POST
    38. * return {"token":"eyJhbGciO..."}
    39. */
    40. func loginHandler(c *gin.Context) {
    41. // TODO 验证账户密码
    42. // account + passwd 需要从db中拉取信息校验
    43. // 获取用户信息
    44. userId := "123"
    45. userName := "test 123"
    46. // 签名JWT
    47. tokenString, err := generateJWTToken(userId, userName)
    48. if err != nil {
    49. c.JSON(http.StatusInternalServerError, gin.H{"message": "Failed to generate token"})
    50. return
    51. }
    52. // 返回JWT给客户端
    53. c.JSON(http.StatusOK, gin.H{"token": tokenString})
    54. }
    55. /**
    56. * 用户信息
    57. * curl 127.0.0.1:8080/v1/user -H "token:eyJhbGciO..."
    58. * return {"UserId":"123","UserName":"test 123","exp":1694741333,"nbf":1694734133,"iat":1694734133}
    59. */
    60. func userHandler(c *gin.Context) {
    61. claims, bool := c.Get("claims")
    62. if bool {
    63. // TODO 其他用户信息可以用UID查 缓存 和 数据库
    64. // findbyId()
    65. c.JSON(http.StatusOK, claims)
    66. return
    67. }
    68. c.JSON(http.StatusOK, gin.H{"message": "not found"})
    69. }
    70. /**
    71. * 刷新token
    72. * curl 127.0.0.1:8080/v1/refresh-token -H "token:eyJhbGciO..."
    73. * return {"token":"eyJhbGciO..."}
    74. */
    75. func refreshTokenHandler(c *gin.Context) {
    76. claims, bool := c.Get("claims")
    77. if !bool {
    78. c.JSON(http.StatusOK, gin.H{"message": "not found claims"})
    79. return
    80. }
    81. fmt.Println(claims)
    82. val, ok := claims.(*jwtClaims)
    83. if !ok {
    84. c.JSON(http.StatusOK, gin.H{"message": "not found"})
    85. return
    86. }
    87. tokenString, err := generateJWTToken(val.UserId, val.UserName)
    88. if err != nil {
    89. c.JSON(http.StatusInternalServerError, gin.H{"message": "Failed to generate token"})
    90. return
    91. }
    92. c.JSON(http.StatusOK, gin.H{"token": tokenString})
    93. return
    94. }
    95. //jwt
    96. type jwtClaims struct {
    97. UserId string
    98. UserName string
    99. jwt.RegisteredClaims // jwt中标准格式
    100. }
    101. /**
    102. * 校验token
    103. * 如果想从服务端控制发出的token,可以通过redis标记也能达到让指定token提前过期的目的
    104. */
    105. func authReqMiddWare() gin.HandlerFunc {
    106. return func(c *gin.Context) {
    107. // 读取TOKEN
    108. tokenStr := c.GetHeader("token")
    109. if tokenStr == "" {
    110. c.JSON(http.StatusForbidden, gin.H{"message": "Token not exist"})
    111. c.Abort()
    112. return
    113. }
    114. // 解析token
    115. token, err := jwt.ParseWithClaims(tokenStr, &jwtClaims{}, func(token *jwt.Token) (interface{}, error) {
    116. return []byte(TOKEN_SECRET_KEY), nil
    117. })
    118. if err != nil {
    119. c.JSON(http.StatusForbidden, gin.H{"message": err.Error()})
    120. c.Abort()
    121. return
    122. }
    123. claims, ok := token.Claims.(*jwtClaims)
    124. // 这里默认会检查ExpiresAt是否过期
    125. if ok && token.Valid {
    126. now := time.Now()
    127. // 检查过期时间,对快要过期的添加http header `refresh-token`
    128. if t := claims.ExpiresAt.Time.Add(-TOKEN_REFRESH_TIME); t.Before(now) {
    129. tokenString, err := generateJWTToken(claims.UserId, claims.UserName)
    130. if err != nil {
    131. c.JSON(http.StatusInternalServerError, gin.H{"message": "Failed to generate token"})
    132. c.Abort()
    133. return
    134. }
    135. c.Header("refresh-token", tokenString) //
    136. }
    137. c.Set("claims", claims)
    138. }
    139. }
    140. }
    141. // 生成JWT token
    142. func generateJWTToken(userId, userName string) (string, error) {
    143. now := time.Now()
    144. claims := jwtClaims{
    145. UserId: userId,
    146. UserName: userName,
    147. RegisteredClaims: jwt.RegisteredClaims{
    148. ExpiresAt: &jwt.NumericDate{Time: time.Now().Add(TOKEN_EXPIRE_TIME)}, // 过期时间
    149. IssuedAt: jwt.NewNumericDate(now), // 签发时间
    150. NotBefore: jwt.NewNumericDate(now), // 生效时间
    151. },
    152. }
    153. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
    154. return token.SignedString([]byte(TOKEN_SECRET_KEY))
    155. }
    156. // 限流
    157. func rateMiddWare(tb *TokenBucket) gin.HandlerFunc {
    158. return func(c *gin.Context) {
    159. if !tb.AllowRequest() {
    160. c.JSON(http.StatusTooManyRequests, gin.H{"message": http.StatusText(http.StatusTooManyRequests)})
    161. c.Abort()
    162. return
    163. }
    164. }
    165. }
    166. // 令牌
    167. type TokenBucket struct {
    168. cap int // 桶容量
    169. rate float64 // 每秒生产个数
    170. tokenNum int // 当前计数
    171. lastTime time.Time // 上一个产生时间
    172. mu sync.Mutex
    173. }
    174. func NewTokenBucket(cap int, rate float64) *TokenBucket {
    175. return &TokenBucket{
    176. cap: cap,
    177. rate: rate,
    178. tokenNum: cap,
    179. lastTime: time.Now(),
    180. }
    181. }
    182. // 拿令牌
    183. func (tb *TokenBucket) AllowRequest() bool {
    184. tb.mu.Lock()
    185. defer tb.mu.Unlock()
    186. now := time.Now()
    187. second := now.Sub(tb.lastTime).Seconds() // 计算经过多少秒
    188. newTokens := int(second * tb.rate) // 计算产生的令牌数量
    189. if newTokens > 0 {
    190. tb.tokenNum = tb.tokenNum + newTokens
    191. if tb.tokenNum > tb.cap { // 不能超过容量
    192. tb.tokenNum = tb.cap
    193. }
    194. tb.lastTime = now
    195. }
    196. if tb.tokenNum > 0 {
    197. tb.tokenNum--
    198. return true
    199. }
    200. return false
    201. }

  • 相关阅读:
    【C++高阶(二)】熟悉STL中的map和set --了解KV模型和pair结构
    两化融合贯标是指什么
    制造企业为什么要部署数字化工厂系统
    ATmega128定时器设计的音乐门铃(标签-算法|关键词-工作模式)
    css定位
    第十八章 ObjectScript - 使用例程
    Rsync远程同步
    Java 20 新功能介绍
    springboot整合阿里大于并结合mq发送短信
    工业控制系统协议相关的安全问题
  • 原文地址:https://blog.csdn.net/lucifer_qiao/article/details/132942319