diff --git a/rpc/rpc.go b/rpc/rpc.go new file mode 100644 index 0000000..69e3515 --- /dev/null +++ b/rpc/rpc.go @@ -0,0 +1,190 @@ +package rpc + +import ( + "bytes" + "context" + "encoding/gob" + "errors" + "fmt" + "reflect" + "runtime" + "strings" + "time" + + "github.com/go-redis/redis/v8" + "go.mongodb.org/mongo-driver/bson/primitive" + "repositories.action2quare.com/ayo/gocommon/logger" +) + +type Receiver interface { + TargetExists(primitive.ObjectID) bool +} + +type receiverManifest struct { + r Receiver + methods map[string]reflect.Method +} + +type rpcEngine struct { + receivers map[string]receiverManifest + publish func([]byte) error +} + +var engine = rpcEngine{ + receivers: make(map[string]receiverManifest), +} + +func RegistReceiver(ptr Receiver) { + rname := reflect.TypeOf(ptr).Elem().Name() + rname = fmt.Sprintf("(*%s)", rname) + + methods := make(map[string]reflect.Method) + for i := 0; i < reflect.TypeOf(ptr).NumMethod(); i++ { + method := reflect.TypeOf(ptr).Method(i) + methods[method.Name] = method + } + engine.receivers[rname] = receiverManifest{ + r: ptr, + methods: methods, + } +} + +func Start(ctx context.Context, redisClient *redis.Client) { + if engine.publish != nil { + return + } + + pubsubName := primitive.NewObjectID().Hex()[6:] + engine.publish = func(s []byte) error { + _, err := redisClient.Publish(ctx, pubsubName, s).Result() + return err + } + + go engine.loop(ctx, redisClient, pubsubName) +} + +func (re *rpcEngine) callFromMessage(msg *redis.Message) { + defer func() { + r := recover() + if r != nil { + logger.Error(r) + } + }() + + encoded := []byte(msg.Payload) + var target primitive.ObjectID + copy(target[:], encoded[:12]) + + encoded = encoded[12:] + for i, c := range encoded { + if c == ')' { + if manifest, ok := re.receivers[string(encoded[:i+1])]; ok { + // 리시버 찾음 + if manifest.r.TargetExists(target) { + // 이 리시버가 타겟을 가지고 있음 + encoded = encoded[i+1:] + decoder := gob.NewDecoder(bytes.NewBuffer(encoded)) + var params []any + if decoder.Decode(¶ms) == nil { + method := manifest.methods[params[0].(string)] + args := []reflect.Value{ + reflect.ValueOf(manifest.r), + } + for _, arg := range params[1:] { + args = append(args, reflect.ValueOf(arg)) + } + method.Func.Call(args) + } + } + } + } + } +} + +func (re *rpcEngine) loop(ctx context.Context, redisClient *redis.Client, chanName string) { + defer func() { + r := recover() + if r != nil { + logger.Error(r) + } + }() + + pubsub := redisClient.Subscribe(ctx, chanName) + for { + if ctx.Err() != nil { + return + } + + if pubsub == nil { + pubsub = redisClient.Subscribe(ctx, chanName) + } + + msg, err := pubsub.ReceiveMessage(ctx) + + if err != nil { + if err == redis.ErrClosed { + time.Sleep(time.Second) + } + pubsub = nil + } else { + re.callFromMessage(msg) + } + } +} + +var errNoReceiver = errors.New("no receiver") + +type RpcCallContext struct { + r Receiver + t primitive.ObjectID +} + +var ErrCanExecuteHere = errors.New("go ahead") + +func MakeCallContext(r Receiver) RpcCallContext { return RpcCallContext{r: r} } +func (c *RpcCallContext) Target(t primitive.ObjectID) { c.t = t } +func (c *RpcCallContext) Call(args ...any) error { + if c.r.TargetExists(c.t) { + // 여기 있네? + return ErrCanExecuteHere + } + + pc := make([]uintptr, 1) + n := runtime.Callers(2, pc[:]) + if n < 1 { + return errNoReceiver + } + + frame, _ := runtime.CallersFrames(pc).Next() + prf := strings.Split(frame.Function, ".") + rname := prf[1] + funcname := prf[2] + + serialized, err := encode(c.t, rname, funcname, args...) + if err != nil { + return err + } + + return engine.publish(serialized) +} + +func encode(target primitive.ObjectID, receiver string, funcname string, args ...any) ([]byte, error) { + buff := new(bytes.Buffer) + + // 타겟을 가장 먼저 기록 + buff.Write(target[:]) + + // receiver + buff.Write([]byte(receiver)) + + // 다음 call context 기록 + m := append([]any{funcname}, args...) + encoder := gob.NewEncoder(buff) + err := encoder.Encode(m) + if err != nil { + logger.Error("rpcCallContext.send err :", err) + return nil, err + } + + return buff.Bytes(), nil +} diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index cd270a5..23c9550 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -381,6 +381,7 @@ func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID } if messageType == websocket.CloseMessage { + sh.callReceiver(accid, c.alias, CloseMessage, r) break }