세션 해제 콜백 추가

This commit is contained in:
2023-12-25 22:06:57 +09:00
parent 08802176cb
commit 46f7d358ed
6 changed files with 110 additions and 27 deletions

View File

@ -31,6 +31,7 @@ type Provider interface {
type Consumer interface { type Consumer interface {
Query(string) (Authorization, error) Query(string) (Authorization, error)
Touch(string) (Authorization, error) Touch(string) (Authorization, error)
RegisterOnSessionInvalidated(func(primitive.ObjectID))
} }
type storagekey string type storagekey string

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"sync" "sync"
"time" "time"
"go.mongodb.org/mongo-driver/bson/primitive"
) )
type cache_stage[T any] struct { type cache_stage[T any] struct {
@ -19,11 +21,12 @@ func make_cache_stage[T any]() *cache_stage[T] {
} }
type consumer_common[T any] struct { type consumer_common[T any] struct {
lock sync.Mutex lock sync.Mutex
ttl time.Duration ttl time.Duration
ctx context.Context ctx context.Context
stages [2]*cache_stage[T] stages [2]*cache_stage[T]
startTime time.Time startTime time.Time
onSessionInvalidated []func(primitive.ObjectID)
} }
func (c *consumer_common[T]) add_internal(sk storagekey, si T) { func (c *consumer_common[T]) add_internal(sk storagekey, si T) {
@ -33,18 +36,25 @@ func (c *consumer_common[T]) add_internal(sk storagekey, si T) {
delete(c.stages[1].deleted, sk) delete(c.stages[1].deleted, sk)
} }
func (c *consumer_common[T]) delete_internal(sk storagekey) { func (c *consumer_common[T]) delete_internal(sk storagekey) (old T) {
delete(c.stages[0].cache, sk) if v, ok := c.stages[0].cache[sk]; ok {
old = v
delete(c.stages[0].cache, sk)
delete(c.stages[1].cache, sk)
} else if v, ok = c.stages[1].cache[sk]; ok {
old = v
delete(c.stages[1].cache, sk)
}
c.stages[0].deleted[sk] = true c.stages[0].deleted[sk] = true
delete(c.stages[1].cache, sk)
c.stages[1].deleted[sk] = true c.stages[1].deleted[sk] = true
return
} }
func (c *consumer_common[T]) delete(sk storagekey) { func (c *consumer_common[T]) delete(sk storagekey) T {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
c.delete_internal(sk) return c.delete_internal(sk)
} }
func (c *consumer_common[T]) changeStage() { func (c *consumer_common[T]) changeStage() {

View File

@ -182,12 +182,20 @@ func newConsumerWithMongo(ctx context.Context, mongoUrl string, ttl time.Duratio
consumer.add(data.Session.Key, data.DocumentKey.Id, data.Session) consumer.add(data.Session.Key, data.DocumentKey.Id, data.Session)
case "update": case "update":
if data.Session == nil { if data.Session == nil {
consumer.deleteById(data.DocumentKey.Id) if old := consumer.deleteById(data.DocumentKey.Id); old != nil {
for _, f := range consumer.onSessionInvalidated {
f(old.Auth.Account)
}
}
} else { } else {
consumer.add(data.Session.Key, data.DocumentKey.Id, data.Session) consumer.add(data.Session.Key, data.DocumentKey.Id, data.Session)
} }
case "delete": case "delete":
consumer.deleteById(data.DocumentKey.Id) if old := consumer.deleteById(data.DocumentKey.Id); old != nil {
for _, f := range consumer.onSessionInvalidated {
f(old.Auth.Account)
}
}
} }
} else { } else {
logger.Error("watchAuthCollection stream.Decode failed :", err) logger.Error("watchAuthCollection stream.Decode failed :", err)
@ -338,12 +346,17 @@ func (c *consumer_mongo) add(sk storagekey, id primitive.ObjectID, si *sessionMo
c.ids[id] = sk c.ids[id] = sk
} }
func (c *consumer_mongo) deleteById(id primitive.ObjectID) { func (c *consumer_mongo) deleteById(id primitive.ObjectID) (old *sessionMongo) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
if sk, ok := c.ids[id]; ok { if sk, ok := c.ids[id]; ok {
c.consumer_common.delete_internal(sk) old = c.consumer_common.delete_internal(sk)
delete(c.ids, id) delete(c.ids, id)
} }
return
}
func (c *consumer_mongo) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) {
c.onSessionInvalidated = append(c.onSessionInvalidated, cb)
} }

View File

@ -161,7 +161,12 @@ func newConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio
switch msg.Channel { switch msg.Channel {
case deleteChannel: case deleteChannel:
sk := storagekey(msg.Payload) sk := storagekey(msg.Payload)
consumer.delete(sk) old := consumer.delete(sk)
if old != nil {
for _, f := range consumer.onSessionInvalidated {
f(old.Account)
}
}
} }
} }
} }
@ -286,3 +291,7 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) {
return Authorization{}, nil return Authorization{}, nil
} }
func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) {
c.onSessionInvalidated = append(c.onSessionInvalidated, cb)
}

View File

@ -86,9 +86,7 @@ type send_msg_queue_elem struct {
msg []byte msg []byte
} }
type WebsocketHandler struct { type websocketHandlerBase struct {
WebsocketApiBroker
redisMsgChanName string redisMsgChanName string
redisCmdChanName string redisCmdChanName string
redisSync *redis.Client redisSync *redis.Client
@ -101,6 +99,11 @@ type WebsocketHandler struct {
sessionConsumer session.Consumer sessionConsumer session.Consumer
} }
type WebsocketHandler struct {
WebsocketApiBroker
websocketHandlerBase
}
type wsConfig struct { type wsConfig struct {
gocommon.StorageAddr `json:"storage"` gocommon.StorageAddr `json:"storage"`
} }
@ -152,14 +155,16 @@ func NewWebsocketHandler(consumer session.Consumer, redisUrl string) (*Websocket
}() }()
return &WebsocketHandler{ return &WebsocketHandler{
redisMsgChanName: fmt.Sprintf("_wsh_msg_%d", redisSync.Options().DB), websocketHandlerBase: websocketHandlerBase{
redisCmdChanName: fmt.Sprintf("_wsh_cmd_%d", redisSync.Options().DB), redisMsgChanName: fmt.Sprintf("_wsh_msg_%d", redisSync.Options().DB),
redisSync: redisSync, redisCmdChanName: fmt.Sprintf("_wsh_cmd_%d", redisSync.Options().DB),
connInOutChan: make(chan *wsconn), redisSync: redisSync,
deliveryChan: make(chan any, 1000), connInOutChan: make(chan *wsconn),
localDeliveryChan: make(chan any, 100), deliveryChan: make(chan any, 1000),
sendMsgChan: sendchan, localDeliveryChan: make(chan any, 100),
sessionConsumer: consumer, sendMsgChan: sendchan,
sessionConsumer: consumer,
},
}, nil }, nil
} }

View File

@ -22,10 +22,22 @@ type WebsocketPeerHandler interface {
RegisterHandlers(serveMux *http.ServeMux, prefix string) error RegisterHandlers(serveMux *http.ServeMux, prefix string) error
} }
type connEstChannelValue struct {
accid primitive.ObjectID
conn *websocket.Conn
}
type connDisChannelValue struct {
accid primitive.ObjectID
closed bool
}
type websocketPeerHandler[T PeerInterface] struct { type websocketPeerHandler[T PeerInterface] struct {
methods map[string]peerApiFuncType[T] methods map[string]peerApiFuncType[T]
createPeer func(primitive.ObjectID) T createPeer func(primitive.ObjectID) T
sessionConsumer session.Consumer sessionConsumer session.Consumer
connEstChannel chan connEstChannelValue
connDisChannel chan connDisChannelValue
} }
type PeerInterface interface { type PeerInterface interface {
@ -140,14 +152,28 @@ func NewWebsocketPeerHandler[T PeerInterface](consumer session.Consumer, creator
methods[k] = v methods[k] = v
} }
return &websocketPeerHandler[T]{ wsh := &websocketPeerHandler[T]{
sessionConsumer: consumer, sessionConsumer: consumer,
methods: methods, methods: methods,
createPeer: creator, createPeer: creator,
connEstChannel: make(chan connEstChannelValue),
connDisChannel: make(chan connDisChannelValue),
}
consumer.RegisterOnSessionInvalidated(wsh.onSessionInvalidated)
return wsh
}
func (ws *websocketPeerHandler[T]) onSessionInvalidated(accid primitive.ObjectID) {
ws.connDisChannel <- connDisChannelValue{
accid: accid,
closed: false,
} }
} }
func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux *http.ServeMux, prefix string) error { func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux *http.ServeMux, prefix string) error {
go ws.sessionMonitoring()
if *noAuthFlag { if *noAuthFlag {
serveMux.HandleFunc(prefix, ws.upgrade_nosession) serveMux.HandleFunc(prefix, ws.upgrade_nosession)
} else { } else {
@ -157,14 +183,33 @@ func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux *http.ServeMux, pre
return nil return nil
} }
func (ws *websocketPeerHandler[T]) sessionMonitoring() {
all := make(map[primitive.ObjectID]*websocket.Conn)
for {
select {
case estVal := <-ws.connEstChannel:
all[estVal.accid] = estVal.conn
case disVal := <-ws.connDisChannel:
if disVal.closed {
delete(all, disVal.accid)
} else if c := all[disVal.accid]; c != nil {
c.Close()
delete(all, disVal.accid)
}
}
}
}
func (ws *websocketPeerHandler[T]) upgrade_core(conn *websocket.Conn, accid primitive.ObjectID, nonce uint32) { func (ws *websocketPeerHandler[T]) upgrade_core(conn *websocket.Conn, accid primitive.ObjectID, nonce uint32) {
go func(c *websocket.Conn, accid primitive.ObjectID) { go func(c *websocket.Conn, accid primitive.ObjectID) {
peer := ws.createPeer(accid) peer := ws.createPeer(accid)
var closeReason string var closeReason string
peer.ClientConnected(conn) peer.ClientConnected(conn)
ws.connEstChannel <- connEstChannelValue{accid: accid, conn: conn}
defer func() { defer func() {
ws.connDisChannel <- connDisChannelValue{accid: accid, closed: true}
peer.ClientDisconnected(closeReason) peer.ClientDisconnected(closeReason)
}() }()