diff --git a/wshandler/wshandler_peer.go b/wshandler/wshandler_peer.go index bc300bc..7d90c03 100644 --- a/wshandler/wshandler_peer.go +++ b/wshandler/wshandler_peer.go @@ -173,9 +173,9 @@ func (ws *websocketPeerHandler[T]) onSessionInvalidated(accid primitive.ObjectID func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux *http.ServeMux, prefix string) error { go ws.sessionMonitoring() - + if *noAuthFlag { - serveMux.HandleFunc(prefix, ws.upgrade_nosession) + serveMux.HandleFunc(prefix, ws.upgrade_noauth) } else { serveMux.HandleFunc(prefix, ws.upgrade) } @@ -281,7 +281,7 @@ func (ws *websocketPeerHandler[T]) upgrade_core(conn *websocket.Conn, accid prim }(conn, accid) } -func (ws *websocketPeerHandler[T]) upgrade_nosession(w http.ResponseWriter, r *http.Request) { +func (ws *websocketPeerHandler[T]) upgrade_noauth(w http.ResponseWriter, r *http.Request) { // 클라이언트 접속 defer func() { s := recover() @@ -292,25 +292,37 @@ func (ws *websocketPeerHandler[T]) upgrade_nosession(w http.ResponseWriter, r *h r.Body.Close() }() - auth := strings.Split(r.Header.Get("Authorization"), " ") - if len(auth) != 2 { - w.WriteHeader(http.StatusBadRequest) - return + sk := r.Header.Get("AS-X-SESSION") + var accid primitive.ObjectID + if len(sk) > 0 { + logger.Println("WebsocketHandler.upgrade sk :", sk) + authinfo, err := ws.sessionConsumer.Query(sk) + if err == nil { + accid = authinfo.Account + } } - temp, err := hex.DecodeString(auth[1]) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } + if accid.IsZero() { + auth := strings.Split(r.Header.Get("Authorization"), " ") + if len(auth) != 2 { + w.WriteHeader(http.StatusBadRequest) + return + } - if len(temp) != len(primitive.NilObjectID) { - w.WriteHeader(http.StatusBadRequest) - return - } + temp, err := hex.DecodeString(auth[1]) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } - raw := (*[12]byte)(temp) - accid := primitive.ObjectID(*raw) + if len(temp) != len(primitive.NilObjectID) { + w.WriteHeader(http.StatusBadRequest) + return + } + + raw := (*[12]byte)(temp) + accid = primitive.ObjectID(*raw) + } var upgrader = websocket.Upgrader{} // use default options conn, err := upgrader.Upgrade(w, r, nil)