세션 무효화 처리 강화
This commit is contained in:
@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
type Authorization struct {
|
type Authorization struct {
|
||||||
Account primitive.ObjectID `bson:"a" json:"a"`
|
Account primitive.ObjectID `bson:"a" json:"a"`
|
||||||
|
invalidated string
|
||||||
|
|
||||||
// by authorization provider
|
// by authorization provider
|
||||||
Platform string `bson:"p" json:"p"`
|
Platform string `bson:"p" json:"p"`
|
||||||
@ -21,9 +22,34 @@ type Authorization struct {
|
|||||||
Email string `bson:"em" json:"em"`
|
Email string `bson:"em" json:"em"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *Authorization) ToStrings() []string {
|
||||||
|
return []string{
|
||||||
|
"a", auth.Account.Hex(),
|
||||||
|
"p", auth.Platform,
|
||||||
|
"u", auth.Uid,
|
||||||
|
"em", auth.Email,
|
||||||
|
"inv", auth.invalidated,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *Authorization) Invalidated() bool {
|
||||||
|
return len(auth.invalidated) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeAuthrizationFromStringMap(src map[string]string) Authorization {
|
||||||
|
accid, _ := primitive.ObjectIDFromHex(src["a"])
|
||||||
|
return Authorization{
|
||||||
|
Account: accid,
|
||||||
|
Platform: src["p"],
|
||||||
|
Uid: src["u"],
|
||||||
|
Email: src["em"],
|
||||||
|
invalidated: src["inv"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
New(*Authorization) (string, error)
|
New(*Authorization) (string, error)
|
||||||
Delete(primitive.ObjectID) error
|
Invalidate(primitive.ObjectID) error
|
||||||
Query(string) (Authorization, error)
|
Query(string) (Authorization, error)
|
||||||
Touch(string) (bool, error)
|
Touch(string) (bool, error)
|
||||||
}
|
}
|
||||||
@ -31,6 +57,7 @@ type Provider interface {
|
|||||||
type Consumer interface {
|
type Consumer interface {
|
||||||
Query(string) (Authorization, error)
|
Query(string) (Authorization, error)
|
||||||
Touch(string) (Authorization, error)
|
Touch(string) (Authorization, error)
|
||||||
|
IsInvalidated(primitive.ObjectID) bool
|
||||||
RegisterOnSessionInvalidated(func(primitive.ObjectID))
|
RegisterOnSessionInvalidated(func(primitive.ObjectID))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -10,13 +10,13 @@ import (
|
|||||||
|
|
||||||
type cache_stage[T any] struct {
|
type cache_stage[T any] struct {
|
||||||
cache map[storagekey]T
|
cache map[storagekey]T
|
||||||
deleted map[storagekey]bool
|
deleted map[storagekey]T
|
||||||
}
|
}
|
||||||
|
|
||||||
func make_cache_stage[T any]() *cache_stage[T] {
|
func make_cache_stage[T any]() *cache_stage[T] {
|
||||||
return &cache_stage[T]{
|
return &cache_stage[T]{
|
||||||
cache: make(map[storagekey]T),
|
cache: make(map[storagekey]T),
|
||||||
deleted: make(map[storagekey]bool),
|
deleted: make(map[storagekey]T),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,14 +39,18 @@ func (c *consumer_common[T]) add_internal(sk storagekey, si T) {
|
|||||||
func (c *consumer_common[T]) delete_internal(sk storagekey) (old T) {
|
func (c *consumer_common[T]) delete_internal(sk storagekey) (old T) {
|
||||||
if v, ok := c.stages[0].cache[sk]; ok {
|
if v, ok := c.stages[0].cache[sk]; ok {
|
||||||
old = v
|
old = v
|
||||||
|
c.stages[0].deleted[sk] = old
|
||||||
|
c.stages[1].deleted[sk] = old
|
||||||
|
|
||||||
delete(c.stages[0].cache, sk)
|
delete(c.stages[0].cache, sk)
|
||||||
delete(c.stages[1].cache, sk)
|
delete(c.stages[1].cache, sk)
|
||||||
} else if v, ok = c.stages[1].cache[sk]; ok {
|
} else if v, ok = c.stages[1].cache[sk]; ok {
|
||||||
old = v
|
old = v
|
||||||
|
c.stages[1].deleted[sk] = old
|
||||||
|
|
||||||
delete(c.stages[1].cache, sk)
|
delete(c.stages[1].cache, sk)
|
||||||
}
|
}
|
||||||
c.stages[0].deleted[sk] = true
|
|
||||||
c.stages[1].deleted[sk] = true
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -64,7 +64,7 @@ func (p *provider_mongo) New(input *Authorization) (string, error) {
|
|||||||
return string(storagekey_to_publickey(sk)), err
|
return string(storagekey_to_publickey(sk)), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *provider_mongo) Delete(acc primitive.ObjectID) error {
|
func (p *provider_mongo) Invalidate(acc primitive.ObjectID) error {
|
||||||
_, err := p.mongoClient.Delete(session_collection_name, bson.M{
|
_, err := p.mongoClient.Delete(session_collection_name, bson.M{
|
||||||
"_id": acc,
|
"_id": acc,
|
||||||
})
|
})
|
||||||
@ -338,6 +338,11 @@ func (c *consumer_mongo) Touch(pk string) (Authorization, error) {
|
|||||||
return *si.Auth, nil
|
return *si.Auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *consumer_mongo) IsInvalidated(id primitive.ObjectID) bool {
|
||||||
|
_, ok := c.ids[id]
|
||||||
|
return !ok
|
||||||
|
}
|
||||||
|
|
||||||
func (c *consumer_mongo) add(sk storagekey, id primitive.ObjectID, si *sessionMongo) {
|
func (c *consumer_mongo) add(sk storagekey, id primitive.ObjectID, si *sessionMongo) {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
|
|||||||
@ -2,7 +2,6 @@ package session
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -43,13 +42,8 @@ func newProviderWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *provider_redis) New(input *Authorization) (string, error) {
|
func (p *provider_redis) New(input *Authorization) (string, error) {
|
||||||
bt, err := json.Marshal(input)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
sk := make_storagekey(input.Account)
|
sk := make_storagekey(input.Account)
|
||||||
_, err = p.redisClient.SetEX(p.ctx, string(sk), bt, p.ttl).Result()
|
_, err := p.redisClient.HSet(p.ctx, string(sk), input.ToStrings()).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -58,7 +52,7 @@ func (p *provider_redis) New(input *Authorization) (string, error) {
|
|||||||
return string(pk), err
|
return string(pk), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *provider_redis) Delete(account primitive.ObjectID) error {
|
func (p *provider_redis) Invalidate(account primitive.ObjectID) error {
|
||||||
prefix := account.Hex()
|
prefix := account.Hex()
|
||||||
sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result()
|
sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -67,7 +61,7 @@ func (p *provider_redis) Delete(account primitive.ObjectID) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, sk := range sks {
|
for _, sk := range sks {
|
||||||
p.redisClient.Del(p.ctx, sk).Result()
|
p.redisClient.HSet(p.ctx, sk, "inv", "true")
|
||||||
p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result()
|
p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,7 +70,7 @@ func (p *provider_redis) Delete(account primitive.ObjectID) error {
|
|||||||
|
|
||||||
func (p *provider_redis) Query(pk string) (Authorization, error) {
|
func (p *provider_redis) Query(pk string) (Authorization, error) {
|
||||||
sk := publickey_to_storagekey(publickey(pk))
|
sk := publickey_to_storagekey(publickey(pk))
|
||||||
payload, err := p.redisClient.Get(p.ctx, string(sk)).Result()
|
src, err := p.redisClient.HGetAll(p.ctx, string(sk)).Result()
|
||||||
if err == redis.Nil {
|
if err == redis.Nil {
|
||||||
logger.Println("session provider query :", pk, err)
|
logger.Println("session provider query :", pk, err)
|
||||||
return Authorization{}, nil
|
return Authorization{}, nil
|
||||||
@ -85,12 +79,7 @@ func (p *provider_redis) Query(pk string) (Authorization, error) {
|
|||||||
return Authorization{}, err
|
return Authorization{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var auth Authorization
|
auth := MakeAuthrizationFromStringMap(src)
|
||||||
if err := json.Unmarshal([]byte(payload), &auth); err != nil {
|
|
||||||
logger.Println("session provider query :", pk, err)
|
|
||||||
return Authorization{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -175,13 +164,13 @@ func newConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio
|
|||||||
return consumer, nil
|
return consumer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, bool, error) {
|
func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, error) {
|
||||||
if _, deleted := c.stages[0].deleted[sk]; deleted {
|
if old, deleted := c.stages[0].deleted[sk]; deleted {
|
||||||
return nil, false, nil
|
return old, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, deleted := c.stages[1].deleted[sk]; deleted {
|
if old, deleted := c.stages[1].deleted[sk]; deleted {
|
||||||
return nil, false, nil
|
return old, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
found, ok := c.stages[0].cache[sk]
|
found, ok := c.stages[0].cache[sk]
|
||||||
@ -192,40 +181,41 @@ func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, bool, err
|
|||||||
if ok {
|
if ok {
|
||||||
if time.Now().Before(found.expireAt) {
|
if time.Now().Before(found.expireAt) {
|
||||||
// 만료전 세션
|
// 만료전 세션
|
||||||
return found, false, nil
|
return found, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 다른 Consumer가 Touch했을 수도 있으므로 redis에서 읽어본다.
|
// 다른 Consumer가 Touch했을 수도 있으므로 redis에서 읽어본다.
|
||||||
}
|
}
|
||||||
|
|
||||||
payload, err := c.redisClient.Get(c.ctx, string(sk)).Result()
|
payload, err := c.redisClient.HGetAll(c.ctx, string(sk)).Result()
|
||||||
if err != nil && err != redis.Nil {
|
if err != nil && err != redis.Nil {
|
||||||
logger.Println("consumer Query :", err)
|
logger.Println("consumer Query :", err)
|
||||||
return nil, false, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(payload) == 0 {
|
if len(payload) == 0 {
|
||||||
return nil, false, nil
|
return nil, nil
|
||||||
}
|
|
||||||
|
|
||||||
var auth Authorization
|
|
||||||
if err := json.Unmarshal([]byte(payload), &auth); err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ttl, err := c.redisClient.TTL(c.ctx, string(sk)).Result()
|
ttl, err := c.redisClient.TTL(c.ctx, string(sk)).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Println("consumer Query :", err)
|
logger.Println("consumer Query :", err)
|
||||||
return nil, false, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auth := MakeAuthrizationFromStringMap(payload)
|
||||||
si := &sessionRedis{
|
si := &sessionRedis{
|
||||||
Authorization: &auth,
|
Authorization: &auth,
|
||||||
expireAt: time.Now().Add(ttl),
|
expireAt: time.Now().Add(ttl),
|
||||||
}
|
}
|
||||||
c.add_internal(sk, si)
|
|
||||||
|
|
||||||
return si, true, nil
|
if auth.Invalidated() {
|
||||||
|
c.stages[0].deleted[sk] = si
|
||||||
|
} else {
|
||||||
|
c.add_internal(sk, si)
|
||||||
|
}
|
||||||
|
|
||||||
|
return si, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *consumer_redis) Query(pk string) (Authorization, error) {
|
func (c *consumer_redis) Query(pk string) (Authorization, error) {
|
||||||
@ -233,7 +223,7 @@ func (c *consumer_redis) Query(pk string) (Authorization, error) {
|
|||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
sk := publickey_to_storagekey(publickey(pk))
|
sk := publickey_to_storagekey(publickey(pk))
|
||||||
si, _, err := c.query_internal(sk)
|
si, err := c.query_internal(sk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Println("session consumer query :", pk, err)
|
logger.Println("session consumer query :", pk, err)
|
||||||
return Authorization{}, err
|
return Authorization{}, err
|
||||||
@ -278,7 +268,7 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) {
|
|||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
// redis에 살아있다.
|
// redis에 살아있다.
|
||||||
si, added, err := c.query_internal(sk)
|
si, err := c.query_internal(sk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Println("session consumer touch(ok) :", pk, err)
|
logger.Println("session consumer touch(ok) :", pk, err)
|
||||||
return Authorization{}, err
|
return Authorization{}, err
|
||||||
@ -289,18 +279,29 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) {
|
|||||||
return Authorization{}, nil
|
return Authorization{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !added {
|
|
||||||
si.expireAt = time.Now().Add(c.ttl)
|
|
||||||
// stage 0으로 옮기기 위해 add_internal을 다시 부름
|
|
||||||
c.add_internal(sk, si)
|
|
||||||
}
|
|
||||||
|
|
||||||
return *si.Authorization, nil
|
return *si.Authorization, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return Authorization{}, nil
|
return Authorization{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *consumer_redis) IsInvalidated(accid primitive.ObjectID) bool {
|
||||||
|
sk := make_storagekey(accid)
|
||||||
|
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
|
if _, deleted := c.stages[0].deleted[sk]; deleted {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, deleted := c.stages[1].deleted[sk]; deleted {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) {
|
func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) {
|
||||||
c.onSessionInvalidated = append(c.onSessionInvalidated, cb)
|
c.onSessionInvalidated = append(c.onSessionInvalidated, cb)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -75,7 +75,7 @@ func TestExpTable(t *testing.T) {
|
|||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
pv.Delete(au1.Account)
|
pv.Invalidate(au1.Account)
|
||||||
|
|
||||||
cs.Touch(sk1)
|
cs.Touch(sk1)
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
|
|||||||
@ -576,6 +576,11 @@ func (ws *WebsocketHandler) upgrade_nosession(w http.ResponseWriter, r *http.Req
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if authinfo.Invalidated() {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{} // use default options
|
var upgrader = websocket.Upgrader{} // use default options
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -618,6 +623,11 @@ func (ws *WebsocketHandler) upgrade(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if authinfo.Invalidated() {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{} // use default options
|
var upgrader = websocket.Upgrader{} // use default options
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -368,6 +368,11 @@ func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Reques
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if authinfo.Account.IsZero() || authinfo.Invalidated() {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{} // use default options
|
var upgrader = websocket.Upgrader{} // use default options
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user