package server import ( "context" api "github.com/AYM1607/proglog/api/v1" "google.golang.org/grpc" ) type Config struct { CommitLog CommitLog } // This comes from the book, why is this needed? var _ api.LogServer = (*grpcServer)(nil) func NewGRPCServer(config *Config, opts ...grpc.ServerOption) (*grpc.Server, error) { gsrv := grpc.NewServer(opts...) srv, err := newgrpcServer(config) if err != nil { return nil, err } api.RegisterLogServer(gsrv, srv) return gsrv, nil } type grpcServer struct { api.UnimplementedLogServer *Config } func newgrpcServer(config *Config) (srv *grpcServer, err error) { srv = &grpcServer{ Config: config, } return srv, nil } func (s *grpcServer) Produce(ctx context.Context, req *api.ProduceRequest) ( *api.ProduceResponse, error) { offset, err := s.CommitLog.Append(req.Record) if err != nil { return nil, err } return &api.ProduceResponse{Offset: offset}, nil } func (s *grpcServer) Consume(ctx context.Context, req *api.ConsumeRequest) ( *api.ConsumeResponse, error) { record, err := s.CommitLog.Read(req.Offset) if err != nil { return nil, err } return &api.ConsumeResponse{Record: record}, nil } func (s *grpcServer) ProduceStream( stream api.Log_ProduceStreamServer, ) error { for { req, err := stream.Recv() if err != nil { return err } res, err := s.Produce(stream.Context(), req) if err != nil { return err } if err = stream.Send(res); err != nil { return err } } } func (s *grpcServer) ConsumeStream( req *api.ConsumeRequest, stream api.Log_ConsumeStreamServer, ) error { for { select { case <-stream.Context().Done(): return nil default: res, err := s.Consume(stream.Context(), req) switch err.(type) { case nil: case api.ErrOffsetOutOfRange: // This is supposed to hold off until there's more data appended to the log. // The code could return this error for records that have been deleted and it'd be stuck forever. continue default: return err } if err = stream.Send(res); err != nil { return err } req.Offset += 1 } } } type CommitLog interface { Append(*api.Record) (uint64, error) Read(uint64) (*api.Record, error) }