diff --git a/session/common.go b/session/common.go index 8ab18c6..d0ca336 100644 --- a/session/common.go +++ b/session/common.go @@ -13,7 +13,8 @@ import ( ) type Authorization struct { - Account primitive.ObjectID `bson:"a" json:"a"` + Account primitive.ObjectID `bson:"a" json:"a"` + invalidated string // by authorization provider Platform string `bson:"p" json:"p"` @@ -21,9 +22,34 @@ type Authorization struct { Email string `bson:"em" json:"em"` } +func (auth *Authorization) ToStrings() []string { + return []string{ + "a", auth.Account.Hex(), + "p", auth.Platform, + "u", auth.Uid, + "em", auth.Email, + "inv", auth.invalidated, + } +} + +func (auth *Authorization) Invalidated() bool { + return len(auth.invalidated) > 0 +} + +func MakeAuthrizationFromStringMap(src map[string]string) Authorization { + accid, _ := primitive.ObjectIDFromHex(src["a"]) + return Authorization{ + Account: accid, + Platform: src["p"], + Uid: src["u"], + Email: src["em"], + invalidated: src["inv"], + } +} + type Provider interface { New(*Authorization) (string, error) - Delete(primitive.ObjectID) error + Invalidate(primitive.ObjectID) error Query(string) (Authorization, error) Touch(string) (bool, error) } @@ -31,6 +57,7 @@ type Provider interface { type Consumer interface { Query(string) (Authorization, error) Touch(string) (Authorization, error) + IsInvalidated(primitive.ObjectID) bool RegisterOnSessionInvalidated(func(primitive.ObjectID)) } diff --git a/session/consumer_common.go b/session/consumer_common.go index 770f45c..1932cfe 100644 --- a/session/consumer_common.go +++ b/session/consumer_common.go @@ -10,13 +10,13 @@ import ( type cache_stage[T any] struct { cache map[storagekey]T - deleted map[storagekey]bool + deleted map[storagekey]T } func make_cache_stage[T any]() *cache_stage[T] { return &cache_stage[T]{ cache: make(map[storagekey]T), - deleted: make(map[storagekey]bool), + deleted: make(map[storagekey]T), } } @@ -39,14 +39,18 @@ func (c *consumer_common[T]) add_internal(sk storagekey, si T) { func (c *consumer_common[T]) delete_internal(sk storagekey) (old T) { if v, ok := c.stages[0].cache[sk]; ok { old = v + c.stages[0].deleted[sk] = old + c.stages[1].deleted[sk] = old + delete(c.stages[0].cache, sk) delete(c.stages[1].cache, sk) } else if v, ok = c.stages[1].cache[sk]; ok { old = v + c.stages[1].deleted[sk] = old + delete(c.stages[1].cache, sk) } - c.stages[0].deleted[sk] = true - c.stages[1].deleted[sk] = true + return } diff --git a/session/impl_mongo.go b/session/impl_mongo.go index 7ced1a0..6823a62 100644 --- a/session/impl_mongo.go +++ b/session/impl_mongo.go @@ -64,7 +64,7 @@ func (p *provider_mongo) New(input *Authorization) (string, error) { return string(storagekey_to_publickey(sk)), err } -func (p *provider_mongo) Delete(acc primitive.ObjectID) error { +func (p *provider_mongo) Invalidate(acc primitive.ObjectID) error { _, err := p.mongoClient.Delete(session_collection_name, bson.M{ "_id": acc, }) @@ -338,6 +338,11 @@ func (c *consumer_mongo) Touch(pk string) (Authorization, error) { return *si.Auth, nil } +func (c *consumer_mongo) IsInvalidated(id primitive.ObjectID) bool { + _, ok := c.ids[id] + return !ok +} + func (c *consumer_mongo) add(sk storagekey, id primitive.ObjectID, si *sessionMongo) { c.lock.Lock() defer c.lock.Unlock() diff --git a/session/impl_redis.go b/session/impl_redis.go index 8bbe328..cc0b9af 100644 --- a/session/impl_redis.go +++ b/session/impl_redis.go @@ -2,7 +2,6 @@ package session import ( "context" - "encoding/json" "fmt" "time" @@ -43,13 +42,8 @@ func newProviderWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio } func (p *provider_redis) New(input *Authorization) (string, error) { - bt, err := json.Marshal(input) - if err != nil { - return "", err - } - sk := make_storagekey(input.Account) - _, err = p.redisClient.SetEX(p.ctx, string(sk), bt, p.ttl).Result() + _, err := p.redisClient.HSet(p.ctx, string(sk), input.ToStrings()).Result() if err != nil { return "", err } @@ -58,7 +52,7 @@ func (p *provider_redis) New(input *Authorization) (string, error) { return string(pk), err } -func (p *provider_redis) Delete(account primitive.ObjectID) error { +func (p *provider_redis) Invalidate(account primitive.ObjectID) error { prefix := account.Hex() sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result() if err != nil { @@ -67,7 +61,7 @@ func (p *provider_redis) Delete(account primitive.ObjectID) error { } for _, sk := range sks { - p.redisClient.Del(p.ctx, sk).Result() + p.redisClient.HSet(p.ctx, sk, "inv", "true") p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result() } @@ -76,7 +70,7 @@ func (p *provider_redis) Delete(account primitive.ObjectID) error { func (p *provider_redis) Query(pk string) (Authorization, error) { sk := publickey_to_storagekey(publickey(pk)) - payload, err := p.redisClient.Get(p.ctx, string(sk)).Result() + src, err := p.redisClient.HGetAll(p.ctx, string(sk)).Result() if err == redis.Nil { logger.Println("session provider query :", pk, err) return Authorization{}, nil @@ -85,12 +79,7 @@ func (p *provider_redis) Query(pk string) (Authorization, error) { return Authorization{}, err } - var auth Authorization - if err := json.Unmarshal([]byte(payload), &auth); err != nil { - logger.Println("session provider query :", pk, err) - return Authorization{}, err - } - + auth := MakeAuthrizationFromStringMap(src) return auth, nil } @@ -175,13 +164,13 @@ func newConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio return consumer, nil } -func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, bool, error) { - if _, deleted := c.stages[0].deleted[sk]; deleted { - return nil, false, nil +func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, error) { + if old, deleted := c.stages[0].deleted[sk]; deleted { + return old, nil } - if _, deleted := c.stages[1].deleted[sk]; deleted { - return nil, false, nil + if old, deleted := c.stages[1].deleted[sk]; deleted { + return old, nil } found, ok := c.stages[0].cache[sk] @@ -192,40 +181,41 @@ func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, bool, err if ok { if time.Now().Before(found.expireAt) { // 만료전 세션 - return found, false, nil + return found, nil } // 다른 Consumer가 Touch했을 수도 있으므로 redis에서 읽어본다. } - payload, err := c.redisClient.Get(c.ctx, string(sk)).Result() + payload, err := c.redisClient.HGetAll(c.ctx, string(sk)).Result() if err != nil && err != redis.Nil { logger.Println("consumer Query :", err) - return nil, false, err + return nil, err } if len(payload) == 0 { - return nil, false, nil - } - - var auth Authorization - if err := json.Unmarshal([]byte(payload), &auth); err != nil { - return nil, false, err + return nil, nil } ttl, err := c.redisClient.TTL(c.ctx, string(sk)).Result() if err != nil { logger.Println("consumer Query :", err) - return nil, false, err + return nil, err } + auth := MakeAuthrizationFromStringMap(payload) si := &sessionRedis{ Authorization: &auth, expireAt: time.Now().Add(ttl), } - c.add_internal(sk, si) - return si, true, nil + if auth.Invalidated() { + c.stages[0].deleted[sk] = si + } else { + c.add_internal(sk, si) + } + + return si, nil } func (c *consumer_redis) Query(pk string) (Authorization, error) { @@ -233,7 +223,7 @@ func (c *consumer_redis) Query(pk string) (Authorization, error) { defer c.lock.Unlock() sk := publickey_to_storagekey(publickey(pk)) - si, _, err := c.query_internal(sk) + si, err := c.query_internal(sk) if err != nil { logger.Println("session consumer query :", pk, err) return Authorization{}, err @@ -278,7 +268,7 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) { if ok { // redis에 살아있다. - si, added, err := c.query_internal(sk) + si, err := c.query_internal(sk) if err != nil { logger.Println("session consumer touch(ok) :", pk, err) return Authorization{}, err @@ -289,18 +279,29 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) { return Authorization{}, nil } - if !added { - si.expireAt = time.Now().Add(c.ttl) - // stage 0으로 옮기기 위해 add_internal을 다시 부름 - c.add_internal(sk, si) - } - return *si.Authorization, nil } return Authorization{}, nil } +func (c *consumer_redis) IsInvalidated(accid primitive.ObjectID) bool { + sk := make_storagekey(accid) + + c.lock.Lock() + defer c.lock.Unlock() + + if _, deleted := c.stages[0].deleted[sk]; deleted { + return true + } + + if _, deleted := c.stages[1].deleted[sk]; deleted { + return true + } + + return false +} + func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) { c.onSessionInvalidated = append(c.onSessionInvalidated, cb) } diff --git a/session/session_test.go b/session/session_test.go index 23987de..3cc6867 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -75,7 +75,7 @@ func TestExpTable(t *testing.T) { time.Sleep(2 * time.Second) time.Sleep(2 * time.Second) - pv.Delete(au1.Account) + pv.Invalidate(au1.Account) cs.Touch(sk1) time.Sleep(2 * time.Second) diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index 887daf0..30e9080 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -576,6 +576,11 @@ func (ws *WebsocketHandler) upgrade_nosession(w http.ResponseWriter, r *http.Req return } + if authinfo.Invalidated() { + w.WriteHeader(http.StatusUnauthorized) + return + } + var upgrader = websocket.Upgrader{} // use default options conn, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -618,6 +623,11 @@ func (ws *WebsocketHandler) upgrade(w http.ResponseWriter, r *http.Request) { return } + if authinfo.Invalidated() { + w.WriteHeader(http.StatusUnauthorized) + return + } + var upgrader = websocket.Upgrader{} // use default options conn, err := upgrader.Upgrade(w, r, nil) if err != nil { diff --git a/wshandler/wshandler_peer.go b/wshandler/wshandler_peer.go index 111a829..942f5ae 100644 --- a/wshandler/wshandler_peer.go +++ b/wshandler/wshandler_peer.go @@ -368,6 +368,11 @@ func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Reques return } + if authinfo.Account.IsZero() || authinfo.Invalidated() { + w.WriteHeader(http.StatusUnauthorized) + return + } + var upgrader = websocket.Upgrader{} // use default options conn, err := upgrader.Upgrade(w, r, nil) if err != nil {