grpc 是由谷歌开源的一个高性能、通用的开源 rpc 框架,具体的使用可以参考该文章。本文主要看一下 go 版本的 grpc 的服务端实现。
grpc server
我们先看一下启动一个 grpc server 时的代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 func main () { lis, err := net.Listen("tcp" , ":6060" ) if err != nil { log.Fatal(err) } s := grpc.NewServer() proto.RegisterEchoSvcServer(s, &EchoServer{}) if err := s.Serve(lis); err != nil { log.Fatal(err) } }
可以看到,第6行创建了一个grpc server,第8行将具体的服务注册到server中,然后第10行开始启动服务。
其中,RegisterEchoSvcServer 这个函数是由插件 protoc-gen-go 根据 .proto 文件自动生成的,我们先来看一下其实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 var _EchoSvc_serviceDesc = grpc.ServiceDesc{ ServiceName: "proto.EchoSvc" , HandlerType: (*EchoSvcServer)(nil ), Methods: []grpc.MethodDesc{ { MethodName: "Echo" , Handler: _EchoSvc_Echo_Handler, }, }, Streams: []grpc.StreamDesc{}, Metadata: "echo.proto" , } func RegisterEchoSvcServer (s *grpc.Server, srv EchoSvcServer) { s.RegisterService(&_EchoSvc_serviceDesc, srv) }
上面的 _EchoSvc_serviceDesc 是由插件根据我们的服务声明自动生成的。grpc 中的 rpc 方法主要有两种类型:第一种是常见的普通 rpc 方法(unary rpc),第二种是 stream rpc 方法,分别对应 ServiceDesc 中的 Methods 和 Streams 字段。
可以看到,RegisterEchoSvcServer 实际上调用的是 grpc server 的服务注册方法 RegisterService:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 func (s *Server) RegisterService (sd *ServiceDesc, ss interface {}) { ht := reflect.TypeOf(sd.HandlerType).Elem() st := reflect.TypeOf(ss) if !st.Implements(ht) { grpclog.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v" , st, ht) } s.register(sd, ss) } func (s *Server) register (sd *ServiceDesc, ss interface {}) { s.mu.Lock() defer s.mu.Unlock() s.printf("RegisterService(%q)" , sd.ServiceName) if s.serve { grpclog.Fatalf("grpc: Server.RegisterService after Server.Serve for %q" , sd.ServiceName) } if _, ok := s.m[sd.ServiceName]; ok { grpclog.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q" , sd.ServiceName) } srv := &service{ server: ss, md: make (map [string ]*MethodDesc), sd: make (map [string ]*StreamDesc), mdata: sd.Metadata, } for i := range sd.Methods { d := &sd.Methods[i] srv.md[d.MethodName] = d } for i := range sd.Streams { d := &sd.Streams[i] srv.sd[d.StreamName] = d } s.m[sd.ServiceName] = srv }
可以看到,grpc server 中维护了一个 service 表(s.m):以服务名为 key,每个 service 内部再按方法名分别保存普通方法(md)和流式方法(sd),相当于 http 服务中的路由表。
接下来看一下 grpc server 如何提供服务:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 func (s *Server) Serve (lis net.Listener) error { s.mu.Lock() s.printf("serving" ) s.serve = true s.serveWG.Add(1 ) defer func () { s.serveWG.Done() select { case <-s.quit: <-s.done default : } }() ls := &listenSocket{Listener: lis} s.lis[ls] = true s.mu.Unlock() defer func () { s.mu.Lock() if s.lis != nil && s.lis[ls] { ls.Close() delete (s.lis, ls) } s.mu.Unlock() }() for { rawConn, err := lis.Accept() if err != nil { } s.serveWG.Add(1 ) go func () { s.handleRawConn(rawConn) s.serveWG.Done() }() } } func (s *Server) handleRawConn (rawConn net.Conn) { rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) conn, authInfo, err := s.useTransportAuthenticator(rawConn) if err != nil { return } s.mu.Lock() if s.conns == nil { s.mu.Unlock() conn.Close() return } s.mu.Unlock() st := s.newHTTP2Transport(conn, authInfo) if st == nil { return } rawConn.SetDeadline(time.Time{}) if !s.addConn(st) { return } go func () { s.serveStreams(st) s.removeConn(st) }() } func (s *Server) serveStreams (st transport.ServerTransport) { defer st.Close() var wg sync.WaitGroup st.HandleStreams(func (stream *transport.Stream) { wg.Add(1 ) go func () { defer wg.Done() s.handleStream(st, stream, s.traceInfo(st, stream)) }() }, func (ctx context.Context, method string ) context .Context { if !EnableTracing { return ctx } tr := trace.New("grpc.Recv." +methodFamily(method), method) return trace.NewContext(ctx, tr) }) wg.Wait() }
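可以看到,server 每接受一个连接就启动一个 goroutine 处理它;连接上每到来一个新的 stream(即一次 rpc 请求),serveStreams 又会启动一个 goroutine 调用 handleStream。接下来我们看一下 server 的 handleStream 方法,该方法负责处理 rpc 请求: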
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 func (s *Server) handleStream (t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) { sm := stream.Method() if sm != "" && sm[0 ] == '/' { sm = sm[1 :] } pos := strings.LastIndex(sm, "/" ) if pos == -1 { return } service := sm[:pos] method := sm[pos+1 :] srv, ok := s.m[service] if !ok { if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { s.processStreamingRPC(t, stream, nil , unknownDesc, trInfo) return } return } if md, ok := srv.md[method]; ok { s.processUnaryRPC(t, stream, srv, md, trInfo) return } if sd, ok := srv.sd[method]; ok { s.processStreamingRPC(t, stream, srv, sd, trInfo) return } if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { s.processStreamingRPC(t, stream, nil , unknownDesc, trInfo) return } }
限于篇幅,我们这里主要看一下 processUnaryRPC 方法,processStreamingRPC 方法与之大同小异:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 func (s *Server) processUnaryRPC (t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { var comp, decomp encoding.Compressor var cp Compressor var dc Decompressor if s.opts.cp != nil { cp = s.opts.cp stream.SetSendCompress(cp.Type()) } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { comp = encoding.GetCompressor(rc) if comp != nil { stream.SetSendCompress(rc) } } var payInfo *payloadInfo if sh != nil || binlog != nil { payInfo = &payloadInfo{} } d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) if err != nil { return err } df := func (v interface {}) error { if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v" , err) } return nil } ctx := NewContextWithServerTransportStream(stream.Context(), stream) reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt) if appErr != nil { appStatus, ok := status.FromError(appErr) if !ok { appErr = status.Error(codes.Unknown, appErr.Error()) appStatus, _ = status.FromError(appErr) } if e := t.WriteStatus(stream, appStatus); e != nil { grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status: %v" , e) } return appErr } opts := &transport.Options{Last: true } if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil { return err } err = t.WriteStatus(stream, status.New(codes.OK, "" )) return err }
上面的代码略有删减,主要是删掉了一些和统计、trace 以及日志相关的代码。主要的逻辑就是从 stream 读取并解压请求数据,然后调用 MethodDesc 中的 Handler 方法(请求的反序列化由传入的 df 在 Handler 内部完成),最后把返回的内容序列化后写入 stream,返回给客户端。
我们知道,grpc 是基于 http2 协议的,因此也存在 header。grpc 和 http 一样,可以设置和获取请求的 header(在 grpc 中称为 metadata)。在服务端,主要有两个操作:获取客户端传递过来的 header,以及向客户端传递 header。
我们先看上面出现的 NewContextWithServerTransportStream 方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 type ServerTransportStream interface { Method() string SetHeader(md metadata.MD) error SendHeader(md metadata.MD) error SetTrailer(md metadata.MD) error } func NewContextWithServerTransportStream (ctx context.Context, stream ServerTransportStream) context .Context { return context.WithValue(ctx, streamKey{}, stream) }
在 processUnaryRPC 方法中,对该方法的调用如下:
1 2 3 4 5 6 func (s *Server) processUnaryRPC (t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { ctx := NewContextWithServerTransportStream(stream.Context(), stream) reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt) }
可以看到,传入的是当前请求 stream 的 context。接下来看一下 stream 的 context 是如何创建的:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 func (t *http2Server) operateHeaders (frame *http2.MetaHeadersFrame, handle func (*Stream) , traceCtx func (context.Context, string ) context .Context ) (fatal bool ) { streamID := frame.Header().StreamID state := decodeState{serverSide: true } if err := state.decodeHeader(frame); err != nil { return false } buf := newRecvBuffer() s := &Stream{ id: streamID, st: t, buf: buf, fc: &inFlow{limit: uint32 (t.initialWindowSize)}, recvCompress: state.encoding, method: state.method, contentSubtype: state.contentSubtype, } if len (state.mdata) > 0 { s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata) } handle(s) return false } func NewIncomingContext (ctx context.Context, md MD) context .Context { return context.WithValue(ctx, mdIncomingKey{}, md) }
当收到一个 Header 帧时,就表明有新的 rpc 请求到来。这时会解析 header 帧并创建 stream,在创建 stream 的时候,会把用户自定义的 header 字段(metadata)保存到 stream 的 context 中。
在实际编码时,可以通过 metadata 包来读取客户端传递过来的 header:
1 2 3 4 5 6 7 8 9 10 func (EchoServer) Echo (ctx context.Context, req *proto.EchoReq) (resp *proto.EchoResp, err error) { md, ok := metadata.FromIncomingContext(ctx) if ok { log.Printf("%s: %v" , md.Get("key" )) } return &proto.EchoResp{ Msg: VERSION, }, err }
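而要向客户端返回 header,可以调用 grpc.SetHeader。它会借助 ctx 中保存的 ServerTransportStream 来设置 header,这正是前面 NewContextWithServerTransportStream 把 stream 放入 ctx 的用途: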
1 2 3 4 5 6 7 8 func (EchoServer) Echo (ctx context.Context, req *proto.EchoReq) (resp *proto.EchoResp, err error) { grpc.SetHeader(ctx, metadata.Pairs("key1" , "val1" )) return &proto.EchoResp{ Msg: VERSION, }, err }
接下来看一下把响应内容写回给客户端的逻辑:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 func (s *Server) sendResponse (t transport.ServerTransport, stream *transport.Stream, msg interface {}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { data, err := encode(s.getCodec(stream.ContentSubtype()), msg) if err != nil { grpclog.Errorln("grpc: server failed to encode response: " , err) return err } compData, err := compress(data, cp, comp) if err != nil { grpclog.Errorln("grpc: server failed to compress response: " , err) return err } hdr, payload := msgHeader(data, compData) if len (payload) > s.opts.maxSendMessageSize { return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)" , len (payload), s.opts.maxSendMessageSize) } err = t.Write(stream, hdr, payload, opts) if err == nil && s.opts.statsHandler != nil { s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false , msg, data, payload, time.Now())) } return err } func (t *http2Server) Write (s *Stream, hdr []byte , data []byte , opts *Options) error { if !s.isHeaderSent() { if err := t.WriteHeader(s, nil ); err != nil { return status.Errorf(codes.Internal, "transport: %v" , err) } } else { } emptyLen := http2MaxFrameLen - len (hdr) if emptyLen > len (data) { emptyLen = len (data) } hdr = append (hdr, data[:emptyLen]...) data = data[emptyLen:] df := &dataFrame{ streamID: s.id, h: hdr, d: data, onEachWrite: func () { atomic.StoreUint32(&t.resetPingStrikes, 1 ) }, } if err := s.wq.get(int32 (len (hdr) + len (data))); err != nil { select { case <-t.ctx.Done(): return ErrConnClosing default : } return ContextErr(s.ctx.Err()) } return t.controlBuf.put(df) }
最后看一下 MethodDesc 中的 Handler,它是由插件自动生成的包装方法,在 processUnaryRPC 中被调用:
1 2 3 4 5 func (s *Server) processUnaryRPC (t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt) }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 func _EchoSvc_Echo_Handler (srv interface {}, ctx context.Context, dec func (interface {}) error , interceptor grpc .UnaryServerInterceptor ) (interface {}, error) { in := new (EchoReq) if err := dec(in); err != nil { return nil , err } if interceptor == nil { return srv.(EchoSvcServer).Echo(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, FullMethod: "/proto.EchoSvc/Echo" , } handler := func (ctx context.Context, req interface {}) (interface {}, error) { return srv.(EchoSvcServer).Echo(ctx, req.(*EchoReq)) } return interceptor(ctx, in, info, handler) }
可以看到,如果用户在创建 server 时设置了 interceptor 选项(例如 grpc.UnaryInterceptor,对应 processUnaryRPC 中传入的 s.opts.unaryInt),那么在执行具体的服务方法前,会先执行用户设置的 interceptor。其声明如下:
1 type UnaryServerInterceptor func (ctx context.Context, req interface {}, info *UnaryServerInfo, handler UnaryHandler) (resp interface {}, err error)
在 interceptor 中可以做一些通用处理,比如日志记录、异常处理或者请求拦截等。