package wshandler

import (
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"reflect"
	"strings"
	"time"

	"github.com/gorilla/websocket"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"repositories.action2quare.com/ayo/gocommon/logger"
	"repositories.action2quare.com/ayo/gocommon/session"
)

// WebsocketPeerHandler installs websocket upgrade endpoints on an HTTP mux.
type WebsocketPeerHandler interface {
	RegisterHandlers(serveMux *http.ServeMux, prefix string) error
}

// peerCtorChannelValue tells the monitoring goroutine that a connection has
// been established for an account.
type peerCtorChannelValue struct {
	accid primitive.ObjectID
	conn  *websocket.Conn
}

// peerDtorChannelValue tells the monitoring goroutine that an account's
// connection should be dropped.
//
// sk is set when the notice comes from a connection's read loop ending and is
// empty when it comes from a session invalidation. conn identifies the
// connection that ended; it is nil for session invalidations, which apply to
// whatever connection is currently tracked for the account.
type peerDtorChannelValue struct {
	accid primitive.ObjectID
	sk    string
	conn  *websocket.Conn
}

// websocketPeerHandler dispatches websocket messages to methods of a peer
// object of type T, created once per connection through createPeer.
type websocketPeerHandler[T PeerInterface] struct {
	methods         map[string]peerApiFuncType[T]
	createPeer      func(primitive.ObjectID) T
	sessionConsumer session.Consumer
	peerCtorChannel chan peerCtorChannelValue
	peerDtorChannel chan peerDtorChannelValue
}

// PeerInterface is the lifecycle contract every peer type must satisfy.
// ClientConnected is called with the connection before the read loop starts;
// ClientDisconnected is called with the close reason after the loop ends.
type PeerInterface interface {
	ClientDisconnected(string)
	ClientConnected(*websocket.Conn)
}

// peerApiFuncType invokes one peer method, reading its arguments from the
// given reader.
type peerApiFuncType[T PeerInterface] func(T, io.Reader) (any, error)

// websocketPeerApiHandler holds the reflected method table of a peer type.
type websocketPeerApiHandler[T PeerInterface] struct {
	methods              map[string]peerApiFuncType[T]
	originalReceiverName string
}

// call invokes the registered api named funcname on recv, reading its
// arguments from r. A panic raised during the call is logged and returned as
// an error. An unknown funcname yields an "api is not found" error.
func (ws *websocketPeerHandler[T]) call(recv T, funcname string, r io.Reader) (v any, e error) {
	defer func() {
		if p := recover(); p != nil {
			logger.Error(p)
			e = fmt.Errorf("%v", p)
		}
	}()

	if found := ws.methods[funcname]; found != nil {
		return found(recv, r)
	}

	return nil, fmt.Errorf("api is not found : %s", funcname)
}

// makeWebsocketPeerApiHandler reflects over the method set of T and builds a
// wrapper for every method except the one named by ClientDisconnected.
//
// Each wrapper decodes a JSON array from the reader, one element per method
// parameter, calls the method, and maps its results to (value, error):
//
//	no result         -> (nil, nil)
//	(error)           -> (nil, err)
//	(value)           -> (value, nil)
//	(value, error)    -> (value, err)
//
// Methods with any other result list are still registered and still invoked,
// but the wrapper reports an "unsupported return signature" error instead of
// panicking on a missing converter.
func makeWebsocketPeerApiHandler[T PeerInterface]() websocketPeerApiHandler[T] {
	methods := make(map[string]peerApiFuncType[T])

	var archetype T
	tp := reflect.TypeOf(archetype)
	errType := reflect.TypeOf((*error)(nil)).Elem()

	for i := 0; i < tp.NumMethod(); i++ {
		method := tp.Method(i)
		if method.Type.In(0) != tp {
			continue
		}

		if method.Name == ClientDisconnected {
			continue
		}

		// In(0) is the receiver; the remaining inputs are the api arguments.
		intypes := make([]reflect.Type, 0, method.Type.NumIn()-1)
		for j := 1; j < method.Type.NumIn(); j++ {
			intypes = append(intypes, method.Type.In(j))
		}

		var outconv func([]reflect.Value) (any, error)
		switch numOut := method.Type.NumOut(); {
		case numOut == 0:
			outconv = func([]reflect.Value) (any, error) {
				return nil, nil
			}
		case numOut == 1 && method.Type.Out(0).Implements(errType):
			outconv = func(out []reflect.Value) (any, error) {
				if out[0].Interface() == nil {
					return nil, nil
				}
				return nil, out[0].Interface().(error)
			}
		case numOut == 1:
			outconv = func(out []reflect.Value) (any, error) {
				return out[0].Interface(), nil
			}
		case numOut == 2 && method.Type.Out(1).Implements(errType):
			outconv = func(out []reflect.Value) (any, error) {
				if out[1].Interface() == nil {
					return out[0].Interface(), nil
				}
				return out[0].Interface(), out[1].Interface().(error)
			}
		default:
			// Previously the converter stayed nil here and the wrapper
			// panicked after the method had run. Keep invoking the method
			// but report the problem explicitly.
			logger.Println("ws api has unsupported return signature :", method.Name)
			outconv = func([]reflect.Value) (any, error) {
				return nil, fmt.Errorf("unsupported return signature : %s", method.Name)
			}
		}

		methods[method.Name] = func(recv T, r io.Reader) (any, error) {
			// Pre-populate the target slice with pointers to zero values so
			// the JSON decoder fills typed arguments in place.
			argptrs := make([]reflect.Value, len(intypes))
			inargs := make([]any, len(intypes))
			for k, intype := range intypes {
				argptrs[k] = reflect.New(intype)
				inargs[k] = argptrs[k].Interface()
			}

			if err := json.NewDecoder(r).Decode(&inargs); err != nil {
				return nil, err
			}

			// The decoder resizes the slice to the length of the incoming
			// array, so a mismatch means too few or too many arguments.
			if len(inargs) != len(intypes) {
				return nil, fmt.Errorf("argument count mismatch : %s expects %d but got %d",
					method.Name, len(intypes), len(inargs))
			}

			// Read the values through argptrs rather than inargs: a JSON null
			// replaces the slice element with nil and leaves the pointed-to
			// argument at its zero value.
			reflectargs := make([]reflect.Value, 0, len(argptrs)+1)
			reflectargs = append(reflectargs, reflect.ValueOf(recv))
			for _, p := range argptrs {
				reflectargs = append(reflectargs, p.Elem())
			}

			return outconv(method.Func.Call(reflectargs))
		}
	}

	// Elem is only meaningful for pointer receivers; fall back to the type's
	// own name so a non-pointer T does not panic.
	receiverName := tp.Name()
	if tp.Kind() == reflect.Ptr {
		receiverName = tp.Elem().Name()
	}

	return websocketPeerApiHandler[T]{
		methods:              methods,
		originalReceiverName: receiverName,
	}
}

// NewWebsocketPeerHandler builds a handler whose api table is reflected from
// T. creator is called once per accepted connection to build the peer.
// The handler subscribes to consumer's session-invalidated notifications.
func NewWebsocketPeerHandler[T PeerInterface](consumer session.Consumer, creator func(primitive.ObjectID) T) WebsocketPeerHandler {
	methods := make(map[string]peerApiFuncType[T])
	receiver := makeWebsocketPeerApiHandler[T]()
	for k, v := range receiver.methods {
		logger.Printf("ws api registered : %s.%s\n", receiver.originalReceiverName, k)
		methods[k] = v
	}

	wsh := &websocketPeerHandler[T]{
		sessionConsumer: consumer,
		methods:         methods,
		createPeer:      creator,
		peerCtorChannel: make(chan peerCtorChannelValue),
		peerDtorChannel: make(chan peerDtorChannelValue),
	}

	consumer.RegisterOnSessionInvalidated(wsh.onSessionInvalidated)
	return wsh
}

// RegisterHandlers mounts the upgrade endpoint at prefix (the no-auth variant
// when noAuthFlag is set) and starts the session monitoring goroutine.
// It always returns nil.
func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux *http.ServeMux, prefix string) error {
	if *noAuthFlag {
		serveMux.HandleFunc(prefix, ws.upgrade_noauth)
	} else {
		serveMux.HandleFunc(prefix, ws.upgrade)
	}
	go ws.sessionMonitoring()

	return nil
}

// onSessionInvalidated asks the monitoring goroutine to drop the connection
// currently tracked for accid. The send blocks until the monitor receives it.
func (ws *websocketPeerHandler[T]) onSessionInvalidated(accid primitive.ObjectID) {
	ws.peerDtorChannel <- peerDtorChannelValue{
		accid: accid,
	}
}

// sessionMonitoring owns the account -> connection table. It runs forever,
// serializing registrations and removals received on the ctor/dtor channels.
func (ws *websocketPeerHandler[T]) sessionMonitoring() {
	all := make(map[primitive.ObjectID]*websocket.Conn)

	// Close frame payload: two status-code bytes (0x03ec) followed by the
	// reason text.
	unauthdata := []byte{0x03, 0xec}
	unauthdata = append(unauthdata, []byte("unauthorized")...)

	for {
		select {
		case estVal := <-ws.peerCtorChannel:
			all[estVal.accid] = estVal.conn

		case disVal := <-ws.peerDtorChannel:
			// A notice from a read loop carries its own connection. If the
			// account has reconnected since, the tracked connection is a
			// different one and must not be closed on behalf of the old one.
			current := all[disVal.accid]
			if current != nil && (disVal.conn == nil || disVal.conn == current) {
				// Best effort: the connection may already be closed.
				current.WriteControl(websocket.CloseMessage, unauthdata, time.Time{})
				delete(all, disVal.accid)
			}

			if len(disVal.sk) > 0 {
				ws.sessionConsumer.Revoke(disVal.sk)
			}
		}
	}
}

// upgrade_core starts the per-connection goroutine: it creates the peer,
// registers the connection with the monitor, and runs the read loop until
// NextReader fails.
//
// Binary message layout, as parsed below:
//
//	request : 0xff | nonce(4) | nameLen(1) | name | json args
//	notify  : nameLen(1) | name | json args
//
// A request is answered with one binary message: status(1) | nonce(4) |
// payload, where status is 6 (ACK) or 21 (NAK). A notify gets no reply.
// Non-binary messages are ignored.
func (ws *websocketPeerHandler[T]) upgrade_core(conn *websocket.Conn, accid primitive.ObjectID, sk string) {
	go func(c *websocket.Conn, accid primitive.ObjectID, sk string) {
		peer := ws.createPeer(accid)
		var closeReason string

		peer.ClientConnected(c)
		ws.peerCtorChannel <- peerCtorChannelValue{accid: accid, conn: c}

		defer func() {
			ws.peerDtorChannel <- peerDtorChannelValue{accid: accid, sk: sk, conn: c}
			peer.ClientDisconnected(closeReason)
		}()

		response := make([]byte, 255)
		for {
			// Reset to the fixed header: status byte + 4-byte nonce.
			response = response[:5]

			messageType, r, err := c.NextReader()
			if err != nil {
				if ce, ok := err.(*websocket.CloseError); ok {
					closeReason = ce.Text
				}
				c.Close()
				break
			}

			if messageType != websocket.BinaryMessage {
				continue
			}

			// io.ReadFull is used throughout because a single Read may
			// return fewer bytes than requested.
			var flag [1]byte
			if _, err := io.ReadFull(r, flag[:]); err != nil {
				logger.Println("malformed websocket message :", err)
				continue
			}

			if flag[0] != 0xff {
				// Notify: the first byte is the api name length.
				cmd := make([]byte, flag[0])
				if _, err := io.ReadFull(r, cmd); err != nil {
					logger.Println("malformed websocket message :", err)
					continue
				}
				if _, err := ws.call(peer, string(cmd), r); err != nil {
					logger.Println("ws api call failed :", string(cmd), err)
				}
				continue
			}

			// Request: without the nonce there is nothing to correlate a
			// reply with, so a truncated header is dropped.
			if _, err := io.ReadFull(r, response[1:5]); err != nil {
				logger.Println("malformed websocket message :", err)
				continue
			}

			var result any
			var size [1]byte
			_, err = io.ReadFull(r, size[:])
			if err == nil {
				cmd := make([]byte, size[0])
				if _, err = io.ReadFull(r, cmd); err == nil {
					result, err = ws.call(peer, string(cmd), r)
				}
			}

			if err != nil {
				response[0] = 21 // 21 : Negative Ack
				response = append(response, err.Error()...)
			} else {
				response[0] = 6 // 6 : Acknowledgement
				switch result := result.(type) {
				case string:
					response = append(response, result...)

				case int8, int16, int32, int64, uint8, uint16, uint32, uint64:
					response = append(response, fmt.Sprintf("%d", result)...)

				case float32, float64:
					response = append(response, fmt.Sprintf("%f", result)...)

				case []byte:
					response = append(response, result...)

				default:
					j, merr := json.Marshal(result)
					if merr != nil {
						// An unencodable result is reported to the client
						// instead of being sent as an empty ACK.
						response[0] = 21
						j = []byte(merr.Error())
					}
					response = append(response, j...)
				}
			}

			pmsg, err := websocket.NewPreparedMessage(websocket.BinaryMessage, response)
			if err != nil {
				logger.Println("websocket.NewPreparedMessage failed :", err)
				continue
			}
			if err := c.WritePreparedMessage(pmsg); err != nil {
				logger.Println("websocket write failed :", err)
			}
		}
	}(conn, accid, sk)
}

// upgrade_noauth accepts a websocket connection without requiring a valid
// session. The account id is taken from the session named by the
// AS-X-SESSION header when that lookup succeeds; otherwise it is parsed from
// the second field of the Authorization header as a hex-encoded ObjectID.
// Malformed Authorization values are rejected with 400.
func (ws *websocketPeerHandler[T]) upgrade_noauth(w http.ResponseWriter, r *http.Request) {
	// client connects
	defer func() {
		if s := recover(); s != nil {
			logger.Error(s)
		}
		io.Copy(io.Discard, r.Body)
		r.Body.Close()
	}()

	sk := r.Header.Get("AS-X-SESSION")
	var accid primitive.ObjectID
	if len(sk) > 0 {
		authinfo, err := ws.sessionConsumer.Query(sk)
		if err == nil {
			accid = authinfo.Account
		}
	}

	if accid.IsZero() {
		auth := strings.Split(r.Header.Get("Authorization"), " ")
		if len(auth) != 2 {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		temp, err := hex.DecodeString(auth[1])
		if err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		if len(temp) != len(primitive.NilObjectID) {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		raw := (*[12]byte)(temp)
		accid = primitive.ObjectID(*raw)
	}

	var upgrader = websocket.Upgrader{} // use default options
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error, so
		// writing another status here would be a superfluous header write.
		logger.Println("websocket upgrade failed :", err)
		return
	}

	// var alias string
	// if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 {
	// 	vt, _ := base64.StdEncoding.DecodeString(v)
	// 	alias = string(vt)
	// } else {
	// 	alias = accid.Hex()
	// }

	ws.upgrade_core(conn, accid, sk)
}

// upgrade accepts a websocket connection for the session named by the
// AS-X-SESSION header. A failed session query yields 500; a zero account or
// an invalidated session yields 401.
func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Request) {
	// client connects
	defer func() {
		if s := recover(); s != nil {
			logger.Error(s)
		}
		io.Copy(io.Discard, r.Body)
		r.Body.Close()
	}()

	sk := r.Header.Get("AS-X-SESSION")
	authinfo, err := ws.sessionConsumer.Query(sk)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		logger.Error("authorize query failed :", err)
		return
	}

	if authinfo.Account.IsZero() || authinfo.Invalidated() {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	var upgrader = websocket.Upgrader{} // use default options
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error, so
		// writing another status here would be a superfluous header write.
		logger.Println("websocket upgrade failed :", err)
		return
	}

	// var alias string
	// if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 {
	// 	vt, _ := base64.StdEncoding.DecodeString(v)
	// 	alias = string(vt)
	// } else {
	// 	alias = authinfo.Account.Hex()
	// }

	ws.upgrade_core(conn, authinfo.Account, sk)
}