From 46f7d358ed3658de78874512a5fca6792a9b9b43 Mon Sep 17 00:00:00 2001 From: mountain Date: Mon, 25 Dec 2023 22:06:57 +0900 Subject: [PATCH] =?UTF-8?q?=EC=84=B8=EC=85=98=20=ED=95=B4=EC=A0=9C=20?= =?UTF-8?q?=EC=BD=9C=EB=B0=B1=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- session/common.go | 1 + session/consumer_common.go | 30 +++++++++++++++-------- session/impl_mongo.go | 21 +++++++++++++---- session/impl_redis.go | 11 ++++++++- wshandler/wshandler.go | 27 ++++++++++++--------- wshandler/wshandler_peer.go | 47 ++++++++++++++++++++++++++++++++++++- 6 files changed, 110 insertions(+), 27 deletions(-) diff --git a/session/common.go b/session/common.go index 8693428..8ab18c6 100644 --- a/session/common.go +++ b/session/common.go @@ -31,6 +31,7 @@ type Provider interface { type Consumer interface { Query(string) (Authorization, error) Touch(string) (Authorization, error) + RegisterOnSessionInvalidated(func(primitive.ObjectID)) } type storagekey string diff --git a/session/consumer_common.go b/session/consumer_common.go index 3fc261a..770f45c 100644 --- a/session/consumer_common.go +++ b/session/consumer_common.go @@ -4,6 +4,8 @@ import ( "context" "sync" "time" + + "go.mongodb.org/mongo-driver/bson/primitive" ) type cache_stage[T any] struct { @@ -19,11 +21,12 @@ func make_cache_stage[T any]() *cache_stage[T] { } type consumer_common[T any] struct { - lock sync.Mutex - ttl time.Duration - ctx context.Context - stages [2]*cache_stage[T] - startTime time.Time + lock sync.Mutex + ttl time.Duration + ctx context.Context + stages [2]*cache_stage[T] + startTime time.Time + onSessionInvalidated []func(primitive.ObjectID) } func (c *consumer_common[T]) add_internal(sk storagekey, si T) { @@ -33,18 +36,25 @@ func (c *consumer_common[T]) add_internal(sk storagekey, si T) { delete(c.stages[1].deleted, sk) } -func (c *consumer_common[T]) delete_internal(sk storagekey) { - delete(c.stages[0].cache, sk) +func (c *consumer_common[T]) delete_internal(sk storagekey) (old T) { + if v, ok := c.stages[0].cache[sk]; ok { + old = v + delete(c.stages[0].cache, sk) + delete(c.stages[1].cache, sk) + } else if v, ok = c.stages[1].cache[sk]; ok { + old = v + delete(c.stages[1].cache, sk) + } c.stages[0].deleted[sk] = true - delete(c.stages[1].cache, sk) c.stages[1].deleted[sk] = true + return } -func (c *consumer_common[T]) delete(sk storagekey) { +func (c *consumer_common[T]) delete(sk storagekey) T { c.lock.Lock() defer c.lock.Unlock() - c.delete_internal(sk) + return c.delete_internal(sk) } func (c *consumer_common[T]) changeStage() { diff --git a/session/impl_mongo.go b/session/impl_mongo.go index f9713c0..7ced1a0 100644 --- a/session/impl_mongo.go +++ b/session/impl_mongo.go @@ -182,12 +182,20 @@ func newConsumerWithMongo(ctx context.Context, mongoUrl string, ttl time.Duratio consumer.add(data.Session.Key, data.DocumentKey.Id, data.Session) case "update": if data.Session == nil { - consumer.deleteById(data.DocumentKey.Id) + if old := consumer.deleteById(data.DocumentKey.Id); old != nil { + for _, f := range consumer.onSessionInvalidated { + f(old.Auth.Account) + } + } } else { consumer.add(data.Session.Key, data.DocumentKey.Id, data.Session) } case "delete": - consumer.deleteById(data.DocumentKey.Id) + if old := consumer.deleteById(data.DocumentKey.Id); old != nil { + for _, f := range consumer.onSessionInvalidated { + f(old.Auth.Account) + } + } } } else { logger.Error("watchAuthCollection stream.Decode failed :", err) @@ -338,12 +346,17 @@ func (c *consumer_mongo) add(sk storagekey, id primitive.ObjectID, si *sessionMo c.ids[id] = sk } -func (c *consumer_mongo) deleteById(id primitive.ObjectID) { +func (c *consumer_mongo) deleteById(id primitive.ObjectID) (old *sessionMongo) { c.lock.Lock() defer c.lock.Unlock() if sk, ok := c.ids[id]; ok { - c.consumer_common.delete_internal(sk) + old = c.consumer_common.delete_internal(sk) delete(c.ids, id) } + return +} + +func (c *consumer_mongo) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) { + c.onSessionInvalidated = append(c.onSessionInvalidated, cb) } diff --git a/session/impl_redis.go b/session/impl_redis.go index e3169cc..e38bc4b 100644 --- a/session/impl_redis.go +++ b/session/impl_redis.go @@ -161,7 +161,12 @@ func newConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio switch msg.Channel { case deleteChannel: sk := storagekey(msg.Payload) - consumer.delete(sk) + old := consumer.delete(sk) + if old != nil { + for _, f := range consumer.onSessionInvalidated { + f(old.Account) + } + } } } } @@ -286,3 +291,7 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) { return Authorization{}, nil } + +func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) { + c.onSessionInvalidated = append(c.onSessionInvalidated, cb) +} diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index fcaf7bd..f9c47b4 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -86,9 +86,7 @@ type send_msg_queue_elem struct { msg []byte } -type WebsocketHandler struct { - WebsocketApiBroker - +type websocketHandlerBase struct { redisMsgChanName string redisCmdChanName string redisSync *redis.Client @@ -101,6 +99,11 @@ type WebsocketHandler struct { sessionConsumer session.Consumer } +type WebsocketHandler struct { + WebsocketApiBroker + websocketHandlerBase +} + type wsConfig struct { gocommon.StorageAddr `json:"storage"` } @@ -152,14 +155,16 @@ func NewWebsocketHandler(consumer session.Consumer, redisUrl string) (*Websocket }() return &WebsocketHandler{ - redisMsgChanName: fmt.Sprintf("_wsh_msg_%d", redisSync.Options().DB), - redisCmdChanName: fmt.Sprintf("_wsh_cmd_%d", redisSync.Options().DB), - redisSync: redisSync, - connInOutChan: make(chan *wsconn), - deliveryChan: make(chan any, 1000), - localDeliveryChan: make(chan any, 100), - sendMsgChan: sendchan, - sessionConsumer: consumer, + websocketHandlerBase: websocketHandlerBase{ + redisMsgChanName: fmt.Sprintf("_wsh_msg_%d", redisSync.Options().DB), + redisCmdChanName: fmt.Sprintf("_wsh_cmd_%d", redisSync.Options().DB), + redisSync: redisSync, + connInOutChan: make(chan *wsconn), + deliveryChan: make(chan any, 1000), + localDeliveryChan: make(chan any, 100), + sendMsgChan: sendchan, + sessionConsumer: consumer, + }, }, nil } diff --git a/wshandler/wshandler_peer.go b/wshandler/wshandler_peer.go index ee2b6f2..bc300bc 100644 --- a/wshandler/wshandler_peer.go +++ b/wshandler/wshandler_peer.go @@ -22,10 +22,22 @@ type WebsocketPeerHandler interface { RegisterHandlers(serveMux *http.ServeMux, prefix string) error } +type connEstChannelValue struct { + accid primitive.ObjectID + conn *websocket.Conn +} + +type connDisChannelValue struct { + accid primitive.ObjectID + closed bool +} + type websocketPeerHandler[T PeerInterface] struct { methods map[string]peerApiFuncType[T] createPeer func(primitive.ObjectID) T sessionConsumer session.Consumer + connEstChannel chan connEstChannelValue + connDisChannel chan connDisChannelValue } type PeerInterface interface { @@ -140,14 +152,28 @@ func NewWebsocketPeerHandler[T PeerInterface](consumer session.Consumer, creator methods[k] = v } - return &websocketPeerHandler[T]{ + wsh := &websocketPeerHandler[T]{ sessionConsumer: consumer, methods: methods, createPeer: creator, + connEstChannel: make(chan connEstChannelValue), + connDisChannel: make(chan connDisChannelValue), + } + + consumer.RegisterOnSessionInvalidated(wsh.onSessionInvalidated) + return wsh +} + +func (ws *websocketPeerHandler[T]) onSessionInvalidated(accid primitive.ObjectID) { + ws.connDisChannel <- connDisChannelValue{ + accid: accid, + closed: false, } } func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux *http.ServeMux, prefix string) error { + go ws.sessionMonitoring() + if *noAuthFlag { serveMux.HandleFunc(prefix, ws.upgrade_nosession) } else { @@ -157,14 +183,33 @@ func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux *http.ServeMux, pre return nil } +func (ws *websocketPeerHandler[T]) sessionMonitoring() { + all := make(map[primitive.ObjectID]*websocket.Conn) + for { + select { + case estVal := <-ws.connEstChannel: + all[estVal.accid] = estVal.conn + case disVal := <-ws.connDisChannel: + if disVal.closed { + delete(all, disVal.accid) + } else if c := all[disVal.accid]; c != nil { + c.Close() + delete(all, disVal.accid) + } + } + } +} + func (ws *websocketPeerHandler[T]) upgrade_core(conn *websocket.Conn, accid primitive.ObjectID, nonce uint32) { go func(c *websocket.Conn, accid primitive.ObjectID) { peer := ws.createPeer(accid) var closeReason string peer.ClientConnected(conn) + ws.connEstChannel <- connEstChannelValue{accid: accid, conn: conn} defer func() { + ws.connDisChannel <- connDisChannelValue{accid: accid, closed: true} peer.ClientDisconnected(closeReason) }()