세션 무효화 처리 강화

This commit is contained in:
2024-02-21 12:08:33 +09:00
parent 0c5ddac9f5
commit 0ca764d6be
7 changed files with 101 additions and 49 deletions

View File

@ -14,6 +14,7 @@ import (
type Authorization struct { type Authorization struct {
Account primitive.ObjectID `bson:"a" json:"a"` Account primitive.ObjectID `bson:"a" json:"a"`
invalidated string
// by authorization provider // by authorization provider
Platform string `bson:"p" json:"p"` Platform string `bson:"p" json:"p"`
@ -21,9 +22,34 @@ type Authorization struct {
Email string `bson:"em" json:"em"` Email string `bson:"em" json:"em"`
} }
func (auth *Authorization) ToStrings() []string {
return []string{
"a", auth.Account.Hex(),
"p", auth.Platform,
"u", auth.Uid,
"em", auth.Email,
"inv", auth.invalidated,
}
}
func (auth *Authorization) Invalidated() bool {
return len(auth.invalidated) > 0
}
func MakeAuthrizationFromStringMap(src map[string]string) Authorization {
accid, _ := primitive.ObjectIDFromHex(src["a"])
return Authorization{
Account: accid,
Platform: src["p"],
Uid: src["u"],
Email: src["em"],
invalidated: src["inv"],
}
}
type Provider interface { type Provider interface {
New(*Authorization) (string, error) New(*Authorization) (string, error)
Delete(primitive.ObjectID) error Invalidate(primitive.ObjectID) error
Query(string) (Authorization, error) Query(string) (Authorization, error)
Touch(string) (bool, error) Touch(string) (bool, error)
} }
@ -31,6 +57,7 @@ type Provider interface {
type Consumer interface { type Consumer interface {
Query(string) (Authorization, error) Query(string) (Authorization, error)
Touch(string) (Authorization, error) Touch(string) (Authorization, error)
IsInvalidated(primitive.ObjectID) bool
RegisterOnSessionInvalidated(func(primitive.ObjectID)) RegisterOnSessionInvalidated(func(primitive.ObjectID))
} }

View File

@ -10,13 +10,13 @@ import (
type cache_stage[T any] struct { type cache_stage[T any] struct {
cache map[storagekey]T cache map[storagekey]T
deleted map[storagekey]bool deleted map[storagekey]T
} }
func make_cache_stage[T any]() *cache_stage[T] { func make_cache_stage[T any]() *cache_stage[T] {
return &cache_stage[T]{ return &cache_stage[T]{
cache: make(map[storagekey]T), cache: make(map[storagekey]T),
deleted: make(map[storagekey]bool), deleted: make(map[storagekey]T),
} }
} }
@ -39,14 +39,18 @@ func (c *consumer_common[T]) add_internal(sk storagekey, si T) {
func (c *consumer_common[T]) delete_internal(sk storagekey) (old T) { func (c *consumer_common[T]) delete_internal(sk storagekey) (old T) {
if v, ok := c.stages[0].cache[sk]; ok { if v, ok := c.stages[0].cache[sk]; ok {
old = v old = v
c.stages[0].deleted[sk] = old
c.stages[1].deleted[sk] = old
delete(c.stages[0].cache, sk) delete(c.stages[0].cache, sk)
delete(c.stages[1].cache, sk) delete(c.stages[1].cache, sk)
} else if v, ok = c.stages[1].cache[sk]; ok { } else if v, ok = c.stages[1].cache[sk]; ok {
old = v old = v
c.stages[1].deleted[sk] = old
delete(c.stages[1].cache, sk) delete(c.stages[1].cache, sk)
} }
c.stages[0].deleted[sk] = true
c.stages[1].deleted[sk] = true
return return
} }

View File

@ -64,7 +64,7 @@ func (p *provider_mongo) New(input *Authorization) (string, error) {
return string(storagekey_to_publickey(sk)), err return string(storagekey_to_publickey(sk)), err
} }
func (p *provider_mongo) Delete(acc primitive.ObjectID) error { func (p *provider_mongo) Invalidate(acc primitive.ObjectID) error {
_, err := p.mongoClient.Delete(session_collection_name, bson.M{ _, err := p.mongoClient.Delete(session_collection_name, bson.M{
"_id": acc, "_id": acc,
}) })
@ -338,6 +338,11 @@ func (c *consumer_mongo) Touch(pk string) (Authorization, error) {
return *si.Auth, nil return *si.Auth, nil
} }
func (c *consumer_mongo) IsInvalidated(id primitive.ObjectID) bool {
_, ok := c.ids[id]
return !ok
}
func (c *consumer_mongo) add(sk storagekey, id primitive.ObjectID, si *sessionMongo) { func (c *consumer_mongo) add(sk storagekey, id primitive.ObjectID, si *sessionMongo) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()

View File

@ -2,7 +2,6 @@ package session
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"time" "time"
@ -43,13 +42,8 @@ func newProviderWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio
} }
func (p *provider_redis) New(input *Authorization) (string, error) { func (p *provider_redis) New(input *Authorization) (string, error) {
bt, err := json.Marshal(input)
if err != nil {
return "", err
}
sk := make_storagekey(input.Account) sk := make_storagekey(input.Account)
_, err = p.redisClient.SetEX(p.ctx, string(sk), bt, p.ttl).Result() _, err := p.redisClient.HSet(p.ctx, string(sk), input.ToStrings()).Result()
if err != nil { if err != nil {
return "", err return "", err
} }
@ -58,7 +52,7 @@ func (p *provider_redis) New(input *Authorization) (string, error) {
return string(pk), err return string(pk), err
} }
func (p *provider_redis) Delete(account primitive.ObjectID) error { func (p *provider_redis) Invalidate(account primitive.ObjectID) error {
prefix := account.Hex() prefix := account.Hex()
sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result() sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result()
if err != nil { if err != nil {
@ -67,7 +61,7 @@ func (p *provider_redis) Delete(account primitive.ObjectID) error {
} }
for _, sk := range sks { for _, sk := range sks {
p.redisClient.Del(p.ctx, sk).Result() p.redisClient.HSet(p.ctx, sk, "inv", "true")
p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result() p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result()
} }
@ -76,7 +70,7 @@ func (p *provider_redis) Delete(account primitive.ObjectID) error {
func (p *provider_redis) Query(pk string) (Authorization, error) { func (p *provider_redis) Query(pk string) (Authorization, error) {
sk := publickey_to_storagekey(publickey(pk)) sk := publickey_to_storagekey(publickey(pk))
payload, err := p.redisClient.Get(p.ctx, string(sk)).Result() src, err := p.redisClient.HGetAll(p.ctx, string(sk)).Result()
if err == redis.Nil { if err == redis.Nil {
logger.Println("session provider query :", pk, err) logger.Println("session provider query :", pk, err)
return Authorization{}, nil return Authorization{}, nil
@ -85,12 +79,7 @@ func (p *provider_redis) Query(pk string) (Authorization, error) {
return Authorization{}, err return Authorization{}, err
} }
var auth Authorization auth := MakeAuthrizationFromStringMap(src)
if err := json.Unmarshal([]byte(payload), &auth); err != nil {
logger.Println("session provider query :", pk, err)
return Authorization{}, err
}
return auth, nil return auth, nil
} }
@ -175,13 +164,13 @@ func newConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio
return consumer, nil return consumer, nil
} }
func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, bool, error) { func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, error) {
if _, deleted := c.stages[0].deleted[sk]; deleted { if old, deleted := c.stages[0].deleted[sk]; deleted {
return nil, false, nil return old, nil
} }
if _, deleted := c.stages[1].deleted[sk]; deleted { if old, deleted := c.stages[1].deleted[sk]; deleted {
return nil, false, nil return old, nil
} }
found, ok := c.stages[0].cache[sk] found, ok := c.stages[0].cache[sk]
@ -192,40 +181,41 @@ func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, bool, err
if ok { if ok {
if time.Now().Before(found.expireAt) { if time.Now().Before(found.expireAt) {
// 만료전 세션 // 만료전 세션
return found, false, nil return found, nil
} }
// 다른 Consumer가 Touch했을 수도 있으므로 redis에서 읽어본다. // 다른 Consumer가 Touch했을 수도 있으므로 redis에서 읽어본다.
} }
payload, err := c.redisClient.Get(c.ctx, string(sk)).Result() payload, err := c.redisClient.HGetAll(c.ctx, string(sk)).Result()
if err != nil && err != redis.Nil { if err != nil && err != redis.Nil {
logger.Println("consumer Query :", err) logger.Println("consumer Query :", err)
return nil, false, err return nil, err
} }
if len(payload) == 0 { if len(payload) == 0 {
return nil, false, nil return nil, nil
}
var auth Authorization
if err := json.Unmarshal([]byte(payload), &auth); err != nil {
return nil, false, err
} }
ttl, err := c.redisClient.TTL(c.ctx, string(sk)).Result() ttl, err := c.redisClient.TTL(c.ctx, string(sk)).Result()
if err != nil { if err != nil {
logger.Println("consumer Query :", err) logger.Println("consumer Query :", err)
return nil, false, err return nil, err
} }
auth := MakeAuthrizationFromStringMap(payload)
si := &sessionRedis{ si := &sessionRedis{
Authorization: &auth, Authorization: &auth,
expireAt: time.Now().Add(ttl), expireAt: time.Now().Add(ttl),
} }
c.add_internal(sk, si)
return si, true, nil if auth.Invalidated() {
c.stages[0].deleted[sk] = si
} else {
c.add_internal(sk, si)
}
return si, nil
} }
func (c *consumer_redis) Query(pk string) (Authorization, error) { func (c *consumer_redis) Query(pk string) (Authorization, error) {
@ -233,7 +223,7 @@ func (c *consumer_redis) Query(pk string) (Authorization, error) {
defer c.lock.Unlock() defer c.lock.Unlock()
sk := publickey_to_storagekey(publickey(pk)) sk := publickey_to_storagekey(publickey(pk))
si, _, err := c.query_internal(sk) si, err := c.query_internal(sk)
if err != nil { if err != nil {
logger.Println("session consumer query :", pk, err) logger.Println("session consumer query :", pk, err)
return Authorization{}, err return Authorization{}, err
@ -278,7 +268,7 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) {
if ok { if ok {
// redis에 살아있다. // redis에 살아있다.
si, added, err := c.query_internal(sk) si, err := c.query_internal(sk)
if err != nil { if err != nil {
logger.Println("session consumer touch(ok) :", pk, err) logger.Println("session consumer touch(ok) :", pk, err)
return Authorization{}, err return Authorization{}, err
@ -289,18 +279,29 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) {
return Authorization{}, nil return Authorization{}, nil
} }
if !added {
si.expireAt = time.Now().Add(c.ttl)
// stage 0으로 옮기기 위해 add_internal을 다시 부름
c.add_internal(sk, si)
}
return *si.Authorization, nil return *si.Authorization, nil
} }
return Authorization{}, nil return Authorization{}, nil
} }
func (c *consumer_redis) IsInvalidated(accid primitive.ObjectID) bool {
sk := make_storagekey(accid)
c.lock.Lock()
defer c.lock.Unlock()
if _, deleted := c.stages[0].deleted[sk]; deleted {
return true
}
if _, deleted := c.stages[1].deleted[sk]; deleted {
return true
}
return false
}
func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) { func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) {
c.onSessionInvalidated = append(c.onSessionInvalidated, cb) c.onSessionInvalidated = append(c.onSessionInvalidated, cb)
} }

View File

@ -75,7 +75,7 @@ func TestExpTable(t *testing.T) {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
pv.Delete(au1.Account) pv.Invalidate(au1.Account)
cs.Touch(sk1) cs.Touch(sk1)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)

View File

@ -576,6 +576,11 @@ func (ws *WebsocketHandler) upgrade_nosession(w http.ResponseWriter, r *http.Req
return return
} }
if authinfo.Invalidated() {
w.WriteHeader(http.StatusUnauthorized)
return
}
var upgrader = websocket.Upgrader{} // use default options var upgrader = websocket.Upgrader{} // use default options
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
@ -618,6 +623,11 @@ func (ws *WebsocketHandler) upgrade(w http.ResponseWriter, r *http.Request) {
return return
} }
if authinfo.Invalidated() {
w.WriteHeader(http.StatusUnauthorized)
return
}
var upgrader = websocket.Upgrader{} // use default options var upgrader = websocket.Upgrader{} // use default options
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {

View File

@ -368,6 +368,11 @@ func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Reques
return return
} }
if authinfo.Account.IsZero() || authinfo.Invalidated() {
w.WriteHeader(http.StatusUnauthorized)
return
}
var upgrader = websocket.Upgrader{} // use default options var upgrader = websocket.Upgrader{} // use default options
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {