diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index 7bbf1d8..1893831 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -364,14 +364,15 @@ func (ws *WebsocketHandler) mainLoop(ctx context.Context) { buffer := bytes.NewBuffer([]byte(raw.Payload)) dec := gob.NewDecoder(buffer) - if raw.Channel == ws.redisMsgChanName { + switch raw.Channel { + case ws.redisMsgChanName: var msg UpstreamMessage if err := dec.Decode(&msg); err == nil { ws.deliveryChan <- &msg } else { logger.Println("decode UpstreamMessage failed :", err) } - } else if raw.Channel == ws.redisCmdChanName { + case ws.redisCmdChanName: var cmd commandMessage if err := dec.Decode(&cmd); err == nil { ws.deliveryChan <- &cmd @@ -587,9 +588,17 @@ func (ws *WebsocketHandler) mainLoop(ctx context.Context) { if c.Conn == nil { delete(entireConns, c.sender.Accid.Hex()) go ws.ClientDisconnected(c) + } else if ws.sessionConsumer.IsRevoked(c.sender.Accid) { + c.Conn.MakeWriter().WriteControl(websocket.CloseMessage, unauthdata, time.Time{}) } else { - entireConns[c.sender.Accid.Hex()] = c - go ws.ClientConnected(c) + sk := session.AccountToSessionKey(c.sender.Accid) + auth, _ := ws.sessionConsumer.Query(sk) + if auth.Account != c.sender.Accid { + c.Conn.MakeWriter().WriteControl(websocket.CloseMessage, unauthdata, time.Time{}) + } else { + entireConns[c.sender.Accid.Hex()] = c + go ws.ClientConnected(c) + } } case accid := <-ws.forceCloseChan: