From 45883436c568b01049ea12f71383c09a3411977b Mon Sep 17 00:00:00 2001 From: mountain Date: Thu, 28 Dec 2023 10:57:42 +0900 Subject: [PATCH] =?UTF-8?q?=EC=84=B8=EC=85=98=20=EB=AC=B4=ED=9A=A8?= =?UTF-8?q?=ED=99=94=EC=8B=9C=20=EC=A0=91=EC=86=8D=20=EC=A2=85=EB=A3=8C=20?= =?UTF-8?q?=EC=B2=98=EB=A6=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- session/impl_redis.go | 9 ++++++++ wshandler/wshandler.go | 37 ++++++++++++++++++++++++++++-- wshandler/wshandler_peer.go | 45 +++++++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 2 deletions(-) diff --git a/session/impl_redis.go b/session/impl_redis.go index 2d8ad9c..c6c4baf 100644 --- a/session/impl_redis.go +++ b/session/impl_redis.go @@ -257,6 +257,15 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) { defer c.lock.Unlock() sk := publickey_to_storagekey(publickey(pk)) + + if _, deleted := c.stages[0].deleted[sk]; deleted { + return Authorization{}, nil + } + + if _, deleted := c.stages[1].deleted[sk]; deleted { + return Authorization{}, nil + } + ok, err := c.redisClient.Expire(c.ctx, string(sk), c.ttl).Result() if err == redis.Nil { logger.Println("session consumer touch :", pk, err) diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index f9c47b4..d3a765c 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -93,6 +93,7 @@ type websocketHandlerBase struct { connInOutChan chan *wsconn deliveryChan chan any localDeliveryChan chan any + forceCloseChan chan primitive.ObjectID sendMsgChan chan send_msg_queue_elem connWaitGroup sync.WaitGroup @@ -154,7 +155,7 @@ func NewWebsocketHandler(consumer session.Consumer, redisUrl string) (*Websocket } }() - return &WebsocketHandler{ + ws := &WebsocketHandler{ websocketHandlerBase: websocketHandlerBase{ redisMsgChanName: fmt.Sprintf("_wsh_msg_%d", redisSync.Options().DB), redisCmdChanName: fmt.Sprintf("_wsh_cmd_%d", redisSync.Options().DB), @@ -162,10 +163,13 @@ func NewWebsocketHandler(consumer session.Consumer, redisUrl string) (*Websocket connInOutChan: make(chan *wsconn), deliveryChan: make(chan any, 1000), localDeliveryChan: make(chan any, 100), + forceCloseChan: make(chan primitive.ObjectID), sendMsgChan: sendchan, sessionConsumer: consumer, }, - }, nil + } + consumer.RegisterOnSessionInvalidated(ws.onSessionInvalidated) + return ws, nil } func (ws *WebsocketHandler) Start(ctx context.Context) { @@ -206,6 +210,10 @@ func (ws *WebsocketHandler) LeaveRoom(room string, accid primitive.ObjectID) { } } +func (ws *WebsocketHandler) onSessionInvalidated(accid primitive.ObjectID) { + ws.forceCloseChan <- accid +} + func (ws *WebsocketHandler) mainLoop(ctx context.Context) { defer func() { ws.connWaitGroup.Done() @@ -358,6 +366,9 @@ func (ws *WebsocketHandler) mainLoop(ctx context.Context) { } // 유저에게서 온 메세지, 소켓 연결/해체 처리 + unauthdata := []byte{0x03, 0xec} + unauthdata = append(unauthdata, []byte("unauthorized")...) + for { buffer := bytes.NewBuffer(make([]byte, 0, 1024)) buffer.Reset() @@ -442,6 +453,11 @@ func (ws *WebsocketHandler) mainLoop(ctx context.Context) { logger.Println("ClientConnected :", c.sender.Alias) go ws.ClientConnected(c) } + + case accid := <-ws.forceCloseChan: + if conn := entireConns[accid.Hex()]; conn != nil { + conn.WriteControl(websocket.CloseMessage, unauthdata, time.Time{}) + } } } } @@ -520,6 +536,18 @@ func (ws *WebsocketHandler) upgrade_nosession(w http.ResponseWriter, r *http.Req raw := (*[12]byte)(temp) accid := primitive.ObjectID(*raw) + sk := r.Header.Get("AS-X-SESSION") + authinfo, err := ws.sessionConsumer.Query(sk) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + if authinfo.Account != accid { + w.WriteHeader(http.StatusUnauthorized) + return + } + var upgrader = websocket.Upgrader{} // use default options conn, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -558,6 +586,11 @@ func (ws *WebsocketHandler) upgrade(w http.ResponseWriter, r *http.Request) { return } + if authinfo.Account.IsZero() { + w.WriteHeader(http.StatusUnauthorized) + return + } + var upgrader = websocket.Upgrader{} // use default options conn, err := upgrader.Upgrade(w, r, nil) if err != nil { diff --git a/wshandler/wshandler_peer.go b/wshandler/wshandler_peer.go index 1cecf45..276f886 100644 --- a/wshandler/wshandler_peer.go +++ b/wshandler/wshandler_peer.go @@ -22,10 +22,23 @@ type WebsocketPeerHandler interface { RegisterHandlers(serveMux *http.ServeMux, prefix string) error } +type peerCtorChannelValue struct { + accid primitive.ObjectID + conn *websocket.Conn +} + +type peerDtorChannelValue struct { + accid primitive.ObjectID + closed bool +} + type websocketPeerHandler[T PeerInterface] struct { methods map[string]peerApiFuncType[T] createPeer func(primitive.ObjectID) T sessionConsumer session.Consumer + + peerCtorChannel chan peerCtorChannelValue + peerDtorChannel chan peerDtorChannelValue } type PeerInterface interface { @@ -144,8 +157,11 @@ func NewWebsocketPeerHandler[T PeerInterface](consumer session.Consumer, creator sessionConsumer: consumer, methods: methods, createPeer: creator, + peerCtorChannel: make(chan peerCtorChannelValue), + peerDtorChannel: make(chan peerDtorChannelValue), } + consumer.RegisterOnSessionInvalidated(wsh.onSessionInvalidated) return wsh } @@ -155,18 +171,47 @@ func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux *http.ServeMux, pre } else { serveMux.HandleFunc(prefix, ws.upgrade) } + go ws.sessionMonitoring() return nil } +func (ws *websocketPeerHandler[T]) onSessionInvalidated(accid primitive.ObjectID) { + ws.peerDtorChannel <- peerDtorChannelValue{ + accid: accid, + closed: false, + } +} + +func (ws *websocketPeerHandler[T]) sessionMonitoring() { + all := make(map[primitive.ObjectID]*websocket.Conn) + unauthdata := []byte{0x03, 0xec} + unauthdata = append(unauthdata, []byte("unauthorized")...) + for { + select { + case estVal := <-ws.peerCtorChannel: + all[estVal.accid] = estVal.conn + case disVal := <-ws.peerDtorChannel: + if disVal.closed { + delete(all, disVal.accid) + } else if c := all[disVal.accid]; c != nil { + c.WriteControl(websocket.CloseMessage, unauthdata, time.Time{}) + delete(all, disVal.accid) + } + } + } +} + func (ws *websocketPeerHandler[T]) upgrade_core(conn *websocket.Conn, accid primitive.ObjectID, nonce uint32) { go func(c *websocket.Conn, accid primitive.ObjectID) { peer := ws.createPeer(accid) var closeReason string peer.ClientConnected(conn) + ws.peerCtorChannel <- peerCtorChannelValue{accid: accid, conn: conn} defer func() { + ws.peerDtorChannel <- peerDtorChannelValue{accid: accid, closed: true} peer.ClientDisconnected(closeReason) }()