session consumer query함수 리턴 값의 애매함을 제거

This commit is contained in:
2025-08-15 23:55:50 +09:00
parent c0ab2afcf4
commit c24d387761
6 changed files with 40 additions and 56 deletions

View File

@ -32,8 +32,8 @@ func (auth *Authorization) ToStrings() []string {
} }
} }
func (auth *Authorization) Invalidated() bool { func (auth *Authorization) Valid() bool {
return len(auth.invalidated) > 0 return len(auth.invalidated) == 0 && !auth.Account.IsZero()
} }
func MakeAuthrizationFromStringMap(src map[string]string) Authorization { func MakeAuthrizationFromStringMap(src map[string]string) Authorization {
@ -55,7 +55,7 @@ type Provider interface {
} }
type Consumer interface { type Consumer interface {
Query(string) (Authorization, error) Query(string) Authorization
Touch(string) (Authorization, error) Touch(string) (Authorization, error)
IsRevoked(primitive.ObjectID) bool IsRevoked(primitive.ObjectID) bool
Revoke(string) Revoke(string)

View File

@ -263,25 +263,25 @@ func (c *consumer_mongo) query_internal(sk storagekey) (*sessionMongo, bool, err
return nil, false, nil return nil, false, nil
} }
func (c *consumer_mongo) Query(pk string) (Authorization, error) { func (c *consumer_mongo) Query(pk string) Authorization {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
sk := publickey_to_storagekey(publickey(pk)) sk := publickey_to_storagekey(publickey(pk))
si, _, err := c.query_internal(sk) si, _, err := c.query_internal(sk)
if err != nil { if err != nil {
return Authorization{}, err return Authorization{}
} }
if si == nil { if si == nil {
return Authorization{}, nil return Authorization{}
} }
if time.Now().After(si.Ts.Time().Add(c.ttl)) { if time.Now().After(si.Ts.Time().Add(c.ttl)) {
return Authorization{}, nil return Authorization{}
} }
return *si.Auth, nil return *si.Auth
} }
func (c *consumer_mongo) Touch(pk string) (Authorization, error) { func (c *consumer_mongo) Touch(pk string) (Authorization, error) {

View File

@ -2,6 +2,7 @@ package session
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
@ -243,46 +244,49 @@ func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, error) {
expireAt: time.Now().Add(ttl), expireAt: time.Now().Add(ttl),
} }
if auth.Invalidated() { if auth.Valid() {
c.stages[0].deleted[sk] = si
} else {
c.add_internal(sk, si) c.add_internal(sk, si)
} else {
c.stages[0].deleted[sk] = si
} }
return si, nil return si, nil
} }
func (c *consumer_redis) Query(pk string) (Authorization, error) { var errRevoked = errors.New("session revoked")
var errExpired = errors.New("session expired")
func (c *consumer_redis) Query(pk string) Authorization {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
sk := publickey_to_storagekey(publickey(pk)) sk := publickey_to_storagekey(publickey(pk))
if _, deleted := c.stages[0].deleted[sk]; deleted { if _, deleted := c.stages[0].deleted[sk]; deleted {
return Authorization{}, nil return Authorization{}
} }
if _, deleted := c.stages[1].deleted[sk]; deleted { if _, deleted := c.stages[1].deleted[sk]; deleted {
return Authorization{}, nil return Authorization{}
} }
si, err := c.query_internal(sk) si, err := c.query_internal(sk)
if err != nil { if err != nil {
logger.Println("session consumer query :", pk, err) logger.Println("session consumer query :", pk, err)
return Authorization{}, err return Authorization{}
} }
if si == nil { if si == nil {
logger.Println("session consumer query(si nil) :", pk, nil) logger.Println("session consumer query(si nil) :", pk, nil)
return Authorization{}, nil return Authorization{}
} }
if time.Now().After(si.expireAt) { if time.Now().After(si.expireAt) {
logger.Println("session consumer query(expired):", pk, nil) logger.Println("session consumer query(expired):", pk, nil)
return Authorization{}, nil return Authorization{}
} }
return *si.Authorization, nil return *si.Authorization
} }
func (c *consumer_redis) Touch(pk string) (Authorization, error) { func (c *consumer_redis) Touch(pk string) (Authorization, error) {

View File

@ -60,11 +60,11 @@ func TestExpTable(t *testing.T) {
go func() { go func() {
for { for {
q1, err := cs.Query(sk1) q1 := cs.Query(sk1)
logger.Println("query :", q1, err) logger.Println("query :", q1)
q2, err := cs.Query(sk2) q2 := cs.Query(sk2)
logger.Println("query :", q2, err) logger.Println("query :", q2)
time.Sleep(time.Second) time.Sleep(time.Second)
} }
}() }()
@ -87,7 +87,7 @@ func TestExpTable(t *testing.T) {
t.Error(err) t.Error(err)
} }
q2, err := cs2.Query(sk2) q2 := cs2.Query(sk2)
logger.Println("queryf :", q2, err) logger.Println("queryf :", q2)
time.Sleep(20 * time.Second) time.Sleep(20 * time.Second)
} }

View File

@ -683,18 +683,13 @@ func (ws *WebsocketHandler) upgrade_nosession(w http.ResponseWriter, r *http.Req
accid := primitive.ObjectID(*raw) accid := primitive.ObjectID(*raw)
sk := r.Header.Get("AS-X-SESSION") sk := r.Header.Get("AS-X-SESSION")
authinfo, err := ws.sessionConsumer.Query(sk) authinfo := ws.sessionConsumer.Query(sk)
if err != nil { if !authinfo.Valid() {
w.WriteHeader(http.StatusBadRequest)
return
}
if authinfo.Account != accid {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
if authinfo.Invalidated() { if authinfo.Account != accid {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
@ -737,19 +732,8 @@ func (ws *WebsocketHandler) upgrade(w http.ResponseWriter, r *http.Request) {
return return
} }
authinfo, err := ws.sessionConsumer.Query(sk) authinfo := ws.sessionConsumer.Query(sk)
if err != nil { if !authinfo.Valid() {
w.WriteHeader(http.StatusInternalServerError)
logger.Error("authorize query failed :", err)
return
}
if authinfo.Account.IsZero() {
w.WriteHeader(http.StatusUnauthorized)
return
}
if authinfo.Invalidated() {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }

View File

@ -305,10 +305,12 @@ func (ws *websocketPeerHandler[T]) upgrade_noauth(w http.ResponseWriter, r *http
sk := r.Header.Get("AS-X-SESSION") sk := r.Header.Get("AS-X-SESSION")
var accid primitive.ObjectID var accid primitive.ObjectID
if len(sk) > 0 { if len(sk) > 0 {
authinfo, err := ws.sessionConsumer.Query(sk) authinfo := ws.sessionConsumer.Query(sk)
if err == nil { if !authinfo.Valid() {
accid = authinfo.Account w.WriteHeader(http.StatusUnauthorized)
return
} }
accid = authinfo.Account
} }
if accid.IsZero() { if accid.IsZero() {
@ -363,14 +365,8 @@ func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Reques
}() }()
sk := r.Header.Get("AS-X-SESSION") sk := r.Header.Get("AS-X-SESSION")
authinfo, err := ws.sessionConsumer.Query(sk) authinfo := ws.sessionConsumer.Query(sk)
if err != nil { if !authinfo.Valid() {
w.WriteHeader(http.StatusInternalServerError)
logger.Error("authorize query failed :", err)
return
}
if authinfo.Account.IsZero() || authinfo.Invalidated() {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }