diff --git a/go.mod b/go.mod index fe1d182..4503b46 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ replace repositories.action2quare.com/ayo/gocommon => ./ require ( github.com/go-redis/redis/v8 v8.11.5 github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/gorilla/websocket v1.5.0 github.com/pires/go-proxyproto v0.7.0 go.mongodb.org/mongo-driver v1.11.6 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d diff --git a/go.sum b/go.sum index c8217c5..73bf0d2 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go new file mode 100644 index 0000000..baed5ac --- /dev/null +++ b/wshandler/wshandler.go @@ -0,0 +1,617 @@ +package wshandler + +import ( + "context" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "sync" + + common "repositories.action2quare.com/ayo/gocommon" + "repositories.action2quare.com/ayo/gocommon/logger" + + "github.com/go-redis/redis/v8" + "github.com/gorilla/websocket" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +const ( + connStateCachePrefix = "conn_state_" + connStateScript = ` + local hosts = redis.call('keys',KEYS[1]) + for index, key in ipairs(hosts) do + local ok = redis.call('hexists', key, KEYS[2]) + if ok == 1 then + return redis.call('hget', key, KEYS[2]) + end + end + return "" + ` +) + +var ConnStateCacheKey = func() string { + hn, _ := os.Hostname() + return connStateCachePrefix + hn +}() + +type Richconn struct { + *websocket.Conn + lock sync.Mutex + alias primitive.ObjectID + tags []string + onClose map[string]func() +} + +func (rc *Richconn) AddTag(name, val string) { + rc.lock.Lock() + defer rc.lock.Unlock() + + prefix := name + "=" + for i, tag := range rc.tags { + if strings.HasPrefix(tag, prefix) { + rc.tags[i] = prefix + val + return + } + } + rc.tags = append(rc.tags, prefix+val) +} + +func (rc *Richconn) GetTag(name string) string { + rc.lock.Lock() + defer rc.lock.Unlock() + + prefix := name + "=" + for _, tag := range rc.tags { + if strings.HasPrefix(tag, prefix) { + return tag[len(prefix):] + } + } + return "" +} + +func (rc *Richconn) RemoveTag(name string, val string) { + rc.lock.Lock() + defer rc.lock.Unlock() + + whole := fmt.Sprintf("%s=%s", name, val) + for i, tag := range rc.tags { + if tag == whole { + if i == 0 && len(rc.tags) == 1 { + rc.tags = nil + } else { + lastidx := len(rc.tags) - 1 + if i < lastidx { + rc.tags[i] = rc.tags[lastidx] + } + rc.tags = rc.tags[:lastidx] + } + return + } + } +} + +func (rc *Richconn) RegistOnCloseFunc(name string, f func()) { + rc.lock.Lock() + defer rc.lock.Unlock() + + if rc.onClose == nil { + f() + return + } + rc.onClose[name] = f +} + +func (rc *Richconn) HasOnCloseFunc(name string) bool { + rc.lock.Lock() + defer rc.lock.Unlock() + + if rc.onClose == nil { + return false + } + + _, ok := rc.onClose[name] + return ok +} + +func (rc *Richconn) UnregistOnCloseFunc(name string) (out func()) { + rc.lock.Lock() + defer rc.lock.Unlock() + + if rc.onClose == nil { + return + } + out = rc.onClose[name] + delete(rc.onClose, name) + return +} + +func (rc *Richconn) WriteBytes(data []byte) error { + rc.lock.Lock() + defer rc.lock.Unlock() + return rc.WriteMessage(websocket.TextMessage, data) +} + +type DeliveryMessage struct { + Alias primitive.ObjectID + Body []byte + Command string + Conn *Richconn +} + +func (dm *DeliveryMessage) Parse(out any) error { + return json.Unmarshal(dm.Body, out) +} + +func (dm *DeliveryMessage) MarshalBinary() (data []byte, err error) { + return append(dm.Alias[:], dm.Body...), nil +} + +func (dm *DeliveryMessage) UnmarshalBinary(data []byte) error { + copy(dm.Alias[:], data[:12]) + dm.Body = data[12:] + return nil +} + +type tagconn struct { + rc *Richconn + state string +} +type tagconnsmap = map[primitive.ObjectID]*tagconn +type tagconns struct { + sync.Mutex + tagconnsmap +} + +type subhandler struct { + sync.Mutex + authCache *common.AuthCollection + conns map[primitive.ObjectID]*Richconn + aliases map[primitive.ObjectID]primitive.ObjectID + tags map[primitive.ObjectID]*tagconns + deliveryChan chan DeliveryMessage + url string + redisSync *redis.Client +} + +// WebsocketHandler : +type WebsocketHandler struct { + authCaches map[string]*subhandler + RedisSync *redis.Client +} + +type wsConfig struct { + SyncPipeline string `json:"ws_sync_pipeline"` +} + +func NewWebsocketHandler(authglobal common.AuthCollectionGlobal) (wsh *WebsocketHandler) { + authCaches := make(map[string]*subhandler) + for _, region := range authglobal.Regions() { + sh := &subhandler{ + authCache: authglobal.Get(region), + conns: make(map[primitive.ObjectID]*Richconn), + aliases: make(map[primitive.ObjectID]primitive.ObjectID), + tags: make(map[primitive.ObjectID]*tagconns), + deliveryChan: make(chan DeliveryMessage, 1000), + } + + authCaches[region] = sh + } + + var config wsConfig + common.LoadConfig(&config) + + redisSync, err := common.NewRedisClient(config.SyncPipeline, 0) + if err != nil { + panic(err) + } + + return &WebsocketHandler{ + authCaches: authCaches, + RedisSync: redisSync, + } +} + +func (ws *WebsocketHandler) Destructor() { + if ws.RedisSync != nil { + ws.RedisSync.Del(context.Background(), ConnStateCacheKey) + } +} + +func (ws *WebsocketHandler) DeliveryChannel(region string) <-chan DeliveryMessage { + return ws.authCaches[region].deliveryChan +} + +func (ws *WebsocketHandler) Conn(region string, alias primitive.ObjectID) *Richconn { + if sh := ws.authCaches[region]; sh != nil { + return sh.conns[alias] + } + return nil +} + +func (ws *WebsocketHandler) JoinTag(region string, tag primitive.ObjectID, tid primitive.ObjectID, rc *Richconn, hint string) error { + if sh := ws.authCaches[region]; sh != nil { + sh.joinTag(tag, tid, rc, hint) + } + return nil +} + +func (ws *WebsocketHandler) LeaveTag(region string, tag primitive.ObjectID, tid primitive.ObjectID) error { + if sh := ws.authCaches[region]; sh != nil { + sh.leaveTag(tag, tid) + } + return nil +} + +func (ws *WebsocketHandler) SetStateInTag(region string, tag primitive.ObjectID, tid primitive.ObjectID, state string, hint string) error { + if sh := ws.authCaches[region]; sh != nil { + sh.setStateInTag(tag, tid, state, hint) + } + return nil +} + +func (ws *WebsocketHandler) BroadcastRaw(region string, tag primitive.ObjectID, raw []byte) { + if sh := ws.authCaches[region]; sh != nil { + if cs := sh.cloneTag(tag); len(cs) > 0 { + go func(raw []byte) { + for _, c := range cs { + if c != nil { + c.WriteBytes(raw) + } + } + }(raw) + } + } +} + +func (ws *WebsocketHandler) Broadcast(region string, tag primitive.ObjectID, doc bson.M) { + raw, _ := json.Marshal(doc) + ws.BroadcastRaw(region, tag, raw) +} + +var onlineQueryScriptHash string + +func (ws *WebsocketHandler) RegisterHandlers(ctx context.Context, serveMux *http.ServeMux, prefix string) error { + ws.RedisSync.Del(context.Background(), ConnStateCacheKey) + + scriptHash, err := ws.RedisSync.ScriptLoad(context.Background(), connStateScript).Result() + if err != nil { + return err + } + onlineQueryScriptHash = scriptHash + for region, sh := range ws.authCaches { + if region == "default" { + region = "" + } + sh.url = common.MakeHttpHandlerPattern(prefix, region, "ws") + sh.redisSync = ws.RedisSync + if *common.NoSessionFlag { + serveMux.HandleFunc(sh.url, sh.upgrade_nosession) + } else { + serveMux.HandleFunc(sh.url, sh.upgrade) + } + } + + return nil +} + +func (sh *subhandler) cloneTag(tag primitive.ObjectID) (out []*Richconn) { + sh.Lock() + cs := sh.tags[tag] + sh.Unlock() + + if cs == nil { + return nil + } + + cs.Lock() + defer cs.Unlock() + + out = make([]*Richconn, 0, len(cs.tagconnsmap)) + for _, c := range cs.tagconnsmap { + out = append(out, c.rc) + } + return +} + +func (sh *subhandler) joinTag(tag primitive.ObjectID, tid primitive.ObjectID, rc *Richconn, hint string) { + sh.Lock() + cs := sh.tags[tag] + if cs == nil { + cs = &tagconns{ + tagconnsmap: make(map[primitive.ObjectID]*tagconn), + } + } + sh.Unlock() + + cs.Lock() + states := make([]bson.M, 0, len(cs.tagconnsmap)) + for tid, conn := range cs.tagconnsmap { + states = append(states, bson.M{ + "_id": tid, + "_hint": hint, + "state": conn.state, + }) + } + + cs.tagconnsmap[tid] = &tagconn{rc: rc} + cs.Unlock() + + sh.Lock() + sh.tags[tag] = cs + sh.Unlock() + + if len(states) > 0 { + s, _ := json.Marshal(states) + rc.WriteBytes(s) + } +} + +func (sh *subhandler) leaveTag(tag primitive.ObjectID, tid primitive.ObjectID) { + sh.Lock() + defer sh.Unlock() + + cs := sh.tags[tag] + if cs == nil { + return + } + + delete(cs.tagconnsmap, tid) + if len(cs.tagconnsmap) == 0 { + delete(sh.tags, tag) + } else { + sh.tags[tag] = cs + } +} + +func (sh *subhandler) setStateInTag(tag primitive.ObjectID, tid primitive.ObjectID, state string, hint string) { + sh.Lock() + cs := sh.tags[tag] + sh.Unlock() + + if cs == nil { + return + } + + cs.Lock() + defer cs.Unlock() + + if tagconn := cs.tagconnsmap[tid]; tagconn != nil { + tagconn.state = state + + var clone []*Richconn + for _, c := range cs.tagconnsmap { + clone = append(clone, c.rc) + } + raw, _ := json.Marshal(map[string]any{ + "_id": tid, + "_hint": hint, + "state": state, + }) + go func(raw []byte) { + for _, c := range clone { + c.WriteBytes(raw) + } + }(raw) + } +} + +func (wsh *WebsocketHandler) GetState(alias primitive.ObjectID) (string, error) { + state, err := wsh.RedisSync.EvalSha(context.Background(), onlineQueryScriptHash, []string{ + connStateCachePrefix + "*", alias.Hex(), + }).Result() + + if err != nil { + return "", err + } + + return state.(string), nil +} + +func (wsh *WebsocketHandler) IsOnline(alias primitive.ObjectID) (bool, error) { + state, err := wsh.GetState(alias) + if err != nil { + logger.Error("IsOnline failed. err :", err) + return false, err + } + return len(state) > 0, nil +} + +func (sh *subhandler) closeConn(accid primitive.ObjectID) { + sh.Lock() + defer sh.Unlock() + + if alias, ok := sh.aliases[accid]; ok { + if old := sh.conns[alias]; old != nil { + old.Close() + } + } +} + +func (sh *subhandler) addConn(conn *Richconn, accid primitive.ObjectID) { + sh.Lock() + defer sh.Unlock() + + sh.conns[conn.alias] = conn + sh.aliases[accid] = conn.alias +} + +func upgrade_core(sh *subhandler, conn *websocket.Conn, initState string, accid primitive.ObjectID, alias primitive.ObjectID) { + sh.closeConn(accid) + + newconn := sh.makeRichConn(alias, conn) + sh.addConn(newconn, accid) + sh.redisSync.HSet(context.Background(), ConnStateCacheKey, alias.Hex(), initState).Result() + + go func(c *Richconn, accid primitive.ObjectID, deliveryChan chan<- DeliveryMessage) { + for { + mt, p, err := c.ReadMessage() + + if err != nil { + c.Close() + break + } + + switch mt { + case websocket.BinaryMessage: + msg := DeliveryMessage{ + Alias: c.alias, + Body: p, + Conn: c, + } + deliveryChan <- msg + + case websocket.TextMessage: + msg := string(p) + opcodes := strings.Split(msg, ";") + for _, opcode := range opcodes { + if strings.HasPrefix(opcode, "ps:") { + sh.redisSync.HSet(context.Background(), ConnStateCacheKey, alias.Hex(), opcode[3:]).Result() + } else if strings.HasPrefix(opcode, "cmd:") { + cmd := opcode[4:] + msg := DeliveryMessage{ + Alias: c.alias, + Command: cmd, + Conn: c, + } + deliveryChan <- msg + } + } + + } + } + + sh.redisSync.HDel(context.Background(), ConnStateCacheKey, c.alias.Hex()).Result() + + sh.Lock() + delete(sh.conns, c.alias) + delete(sh.aliases, accid) + sh.Unlock() + + var funcs []func() + c.lock.Lock() + for _, f := range c.onClose { + funcs = append(funcs, f) + } + c.onClose = nil + c.lock.Unlock() + + for _, f := range funcs { + f() + } + }(newconn, accid, sh.deliveryChan) +} + +func (sh *subhandler) upgrade_nosession(w http.ResponseWriter, r *http.Request) { + // 클라이언트 접속 + defer func() { + s := recover() + if s != nil { + logger.Error(s) + } + io.Copy(io.Discard, r.Body) + r.Body.Close() + }() + + auth := strings.Split(r.Header.Get("Authorization"), " ") + if len(auth) != 2 { + w.WriteHeader(http.StatusBadRequest) + return + } + + if auth[0] != "Editor" { + w.WriteHeader(http.StatusBadRequest) + return + } + temp, err := hex.DecodeString(auth[1]) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + if len(temp) != len(primitive.NilObjectID) { + w.WriteHeader(http.StatusBadRequest) + return + } + + raw := (*[12]byte)(temp) + accid := primitive.ObjectID(*raw) + + var upgrader = websocket.Upgrader{} // use default options + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + alias := accid + if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 { + alias = common.ParseObjectID(v) + } + + initState := r.Header.Get("As-X-Tavern-InitialState") + if len(initState) == 0 { + initState = "online" + } + + upgrade_core(sh, conn, initState, accid, alias) +} + +func (sh *subhandler) upgrade(w http.ResponseWriter, r *http.Request) { + // 클라이언트 접속 + defer func() { + s := recover() + if s != nil { + logger.Error(s) + } + io.Copy(io.Discard, r.Body) + r.Body.Close() + }() + + sk := r.Header.Get("AS-X-SESSION") + auth := strings.Split(r.Header.Get("Authorization"), " ") + if len(auth) != 2 { + //TODO : 클라이언트는 BadRequest를 받으면 로그인 화면으로 돌아가야 한다. + w.WriteHeader(http.StatusBadRequest) + return + } + authtoken := auth[1] + + accid, success := sh.authCache.IsValid(sk, authtoken) + if !success { + w.WriteHeader(http.StatusUnauthorized) + return + } + + var upgrader = websocket.Upgrader{} // use default options + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + alias := accid + if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 { + alias = common.ParseObjectID(v) + } + + initState := r.Header.Get("As-X-Tavern-InitialState") + if len(initState) == 0 { + initState = "online" + } + + upgrade_core(sh, conn, initState, accid, alias) +} + +func (sh *subhandler) makeRichConn(alias primitive.ObjectID, conn *websocket.Conn) *Richconn { + rc := Richconn{ + Conn: conn, + alias: alias, + onClose: make(map[string]func()), + } + return &rc +}