proglog/internal/server/server.go

204 lines
4.7 KiB
Go

package server
import (
	"context"
	"io"
	"time"

	api "github.com/AYM1607/proglog/api/v1"
	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
	grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
	grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
	grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
	"go.opencensus.io/plugin/ocgrpc"
	"go.opencensus.io/stats/view"
	"go.opencensus.io/trace"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
)
// Object and action names the handlers pass to Authorizer.Authorize.
const (
// objectWildCard is the object every request in this file is authorized against.
objectWildCard = "*"
// produceAction is the action checked by Produce.
produceAction = "produce"
// consumeAction is the action checked by Consume.
consumeAction = "consume"
)
// Config holds the dependencies the gRPC log server is built with.
type Config struct {
// CommitLog is the log that Produce appends to and Consume reads from.
CommitLog CommitLog
// Authorizer is consulted before every produce and consume operation.
Authorizer Authorizer
}
// Compile-time assertion that *grpcServer satisfies api.LogServer. The typed
// nil pointer is assigned to the blank identifier, so nothing is allocated or
// kept at run time; the build simply fails on this line if *grpcServer ever
// stops implementing the interface.
var _ api.LogServer = (*grpcServer)(nil)
// NewGRPCServer creates a *grpc.Server with the log service registered on it.
//
// The caller's opts are kept and extended with stream and unary interceptor
// chains (context tags, zap request logging, authentication, in that order)
// plus an OpenCensus stats handler. An error is returned if the default
// OpenCensus server views cannot be registered or the service cannot be built.
func NewGRPCServer(config *Config, opts ...grpc.ServerOption) (*grpc.Server, error) {
	// Logging: a named child of the global zap logger, reporting request
	// durations in nanoseconds under "grpc.time_ns".
	srvLogger := zap.L().Named("server")
	durationNs := func(d time.Duration) zapcore.Field {
		return zap.Int64("grpc.time_ns", d.Nanoseconds())
	}
	logOpts := []grpc_zap.Option{grpc_zap.WithDurationField(durationNs)}

	// Tracing: every request is sampled. This should not be used in
	// production; a better sampler should be written.
	trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()})
	if err := view.Register(ocgrpc.DefaultServerViews...); err != nil {
		return nil, err
	}

	streamChain := grpc_middleware.ChainStreamServer(
		grpc_ctxtags.StreamServerInterceptor(),
		grpc_zap.StreamServerInterceptor(srvLogger, logOpts...),
		grpc_auth.StreamServerInterceptor(authenticate),
	)
	unaryChain := grpc_middleware.ChainUnaryServer(
		grpc_ctxtags.UnaryServerInterceptor(),
		grpc_zap.UnaryServerInterceptor(srvLogger, logOpts...),
		grpc_auth.UnaryServerInterceptor(authenticate),
	)
	opts = append(opts,
		grpc.StreamInterceptor(streamChain),
		grpc.UnaryInterceptor(unaryChain),
		grpc.StatsHandler(&ocgrpc.ServerHandler{}),
	)

	server := grpc.NewServer(opts...)
	logService, err := newgrpcServer(config)
	if err != nil {
		return nil, err
	}
	api.RegisterLogServer(server, logService)
	return server, nil
}
// grpcServer is the api.LogServer implementation registered by NewGRPCServer.
type grpcServer struct {
// Embedded generated type; presumably supplies default implementations for
// RPCs this type does not override (confirm against the generated api code).
api.UnimplementedLogServer
// Embedded so handlers can reach s.CommitLog and s.Authorizer directly.
*Config
}
// newgrpcServer wraps config in a grpcServer. The returned error is always
// nil in the current implementation.
func newgrpcServer(config *Config) (srv *grpcServer, err error) {
	return &grpcServer{Config: config}, nil
}
// Produce appends the request's record to the commit log and returns the
// offset reported by the log.
//
// The subject stored in ctx must be authorized for produceAction on the
// wildcard object; authorization and append errors are returned unchanged.
func (s *grpcServer) Produce(ctx context.Context, req *api.ProduceRequest) (*api.ProduceResponse, error) {
	authErr := s.Authorizer.Authorize(subject(ctx), objectWildCard, produceAction)
	if authErr != nil {
		return nil, authErr
	}

	off, appendErr := s.CommitLog.Append(req.Record)
	if appendErr != nil {
		return nil, appendErr
	}

	resp := &api.ProduceResponse{Offset: off}
	return resp, nil
}
// Consume reads the record stored at the requested offset from the commit log.
//
// The subject stored in ctx must be authorized for consumeAction on the
// wildcard object; authorization and read errors are returned unchanged.
func (s *grpcServer) Consume(ctx context.Context, req *api.ConsumeRequest) (*api.ConsumeResponse, error) {
	authErr := s.Authorizer.Authorize(subject(ctx), objectWildCard, consumeAction)
	if authErr != nil {
		return nil, authErr
	}

	rec, readErr := s.CommitLog.Read(req.Offset)
	if readErr != nil {
		return nil, readErr
	}

	resp := &api.ConsumeResponse{Record: rec}
	return resp, nil
}
// ProduceStream handles a bidirectional stream: each request received is
// passed to Produce and the resulting response is sent back on the same
// stream before the next request is read.
//
// It returns nil once the client has finished sending (Recv reports io.EOF),
// and otherwise the first error from Recv, Produce or Send.
func (s *grpcServer) ProduceStream(
	stream api.Log_ProduceStreamServer,
) error {
	for {
		req, err := stream.Recv()
		if err == io.EOF {
			// The client closed its send side cleanly; finish the RPC
			// successfully instead of surfacing EOF as a handler error.
			return nil
		}
		if err != nil {
			return err
		}
		res, err := s.Produce(stream.Context(), req)
		if err != nil {
			return err
		}
		if err = stream.Send(res); err != nil {
			return err
		}
	}
}
// ConsumeStream sends the record at req.Offset and every following record to
// the client, advancing req.Offset by one after each successful send.
//
// When the log reports api.ErrOffsetOutOfRange the handler waits briefly and
// retries, so the stream keeps following the log as records are appended. It
// returns nil when the stream's context is done, and otherwise the first
// error from Consume (other than out-of-range) or Send.
func (s *grpcServer) ConsumeStream(
	req *api.ConsumeRequest,
	stream api.Log_ConsumeStreamServer,
) error {
	// pollInterval bounds how often an out-of-range offset is retried, so a
	// consumer that has caught up does not spin a CPU core.
	const pollInterval = 10 * time.Millisecond

	ctx := stream.Context()
	for {
		// Stop promptly once the client goes away or the RPC is cancelled.
		select {
		case <-ctx.Done():
			return nil
		default:
		}

		res, err := s.Consume(ctx, req)
		switch err.(type) {
		case nil:
		case api.ErrOffsetOutOfRange:
			// Nothing at this offset yet: hold off until more data may have
			// been appended, then try again.
			// NOTE(review): the log could also return this error for records
			// that have been deleted, in which case the stream polls until
			// the client cancels.
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(pollInterval):
			}
			continue
		default:
			return err
		}

		if err = stream.Send(res); err != nil {
			return err
		}
		req.Offset++
	}
}
// authenticate is the callback given to the grpc_auth interceptors. It
// derives the caller's identity from the connection's peer information and
// stores it in the returned context under subjectContextKey, where subject
// reads it back.
//
// The subject is the common name of the leaf certificate in the first
// verified chain. It is the empty string when the connection has no auth
// info or the TLS peer presented no verified client certificate. An error
// status is returned when no peer information is available or the
// connection uses credentials other than TLS.
func authenticate(ctx context.Context) (context.Context, error) {
	p, ok := peer.FromContext(ctx)
	if !ok {
		return ctx, status.New(
			codes.Unknown,
			"could not find peer info",
		).Err()
	}
	// No transport credentials at all: the caller is anonymous.
	if p.AuthInfo == nil {
		return context.WithValue(ctx, subjectContextKey{}, ""), nil
	}
	// Checked assertion: an unexpected credentials type must fail this
	// request, not panic the handler.
	tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo)
	if !ok {
		return ctx, status.New(
			codes.Unauthenticated,
			"unexpected peer transport credentials",
		).Err()
	}
	// VerifiedChains is empty when the client did not present a certificate
	// that was verified; treat that caller as anonymous too.
	chains := tlsInfo.State.VerifiedChains
	if len(chains) == 0 || len(chains[0]) == 0 {
		return context.WithValue(ctx, subjectContextKey{}, ""), nil
	}
	commonName := chains[0][0].Subject.CommonName
	return context.WithValue(ctx, subjectContextKey{}, commonName), nil
}
// subject returns the caller identity stored in ctx under subjectContextKey.
// It returns the empty string when ctx carries no subject (for example when
// a handler is invoked without the authentication interceptor) or when the
// stored value is not a string, rather than panicking.
func subject(ctx context.Context) string {
	name, ok := ctx.Value(subjectContextKey{}).(string)
	if !ok {
		return ""
	}
	return name
}

// subjectContextKey is the context key under which the authenticated subject
// is stored. Being an unexported type, it cannot collide with keys set by
// other packages.
type subjectContextKey struct{}
// CommitLog is the storage the server needs from a log implementation.
type CommitLog interface {
// Append stores a record; Produce returns the resulting uint64 to the
// client as the record's offset.
Append(*api.Record) (uint64, error)
// Read returns the record for the given offset; Consume passes the
// request's offset here.
Read(uint64) (*api.Record, error)
}
// Authorizer decides whether a subject may perform an action on an object.
type Authorizer interface {
// Authorize is called by the handlers before touching the log; a non-nil
// error aborts the request and is returned to the caller unchanged.
Authorize(subject, object, action string) error
}