Adds auth middleware for all rpcs with tests.
This commit is contained in:
parent
fd53846238
commit
64e6faecae
4 changed files with 146 additions and 4 deletions
|
|
@ -4,17 +4,41 @@ import (
|
|||
"context"
|
||||
|
||||
api "github.com/AYM1607/proglog/api/v1"
|
||||
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
objectWildCard = "*"
|
||||
produceAction = "produce"
|
||||
consumeAction = "consume"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
CommitLog CommitLog
|
||||
CommitLog CommitLog
|
||||
Authorizer Authorizer
|
||||
}
|
||||
|
||||
// This comes from the book, why is this needed?
|
||||
var _ api.LogServer = (*grpcServer)(nil)
|
||||
|
||||
func NewGRPCServer(config *Config, opts ...grpc.ServerOption) (*grpc.Server, error) {
|
||||
opts = append(opts,
|
||||
// Streaming interceptors.
|
||||
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
|
||||
grpc_auth.StreamServerInterceptor(authenticate),
|
||||
)),
|
||||
// Unary interceptors.
|
||||
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
|
||||
grpc_auth.UnaryServerInterceptor(authenticate),
|
||||
)),
|
||||
)
|
||||
gsrv := grpc.NewServer(opts...)
|
||||
srv, err := newgrpcServer(config)
|
||||
if err != nil {
|
||||
|
|
@ -38,6 +62,13 @@ func newgrpcServer(config *Config) (srv *grpcServer, err error) {
|
|||
|
||||
func (s *grpcServer) Produce(ctx context.Context, req *api.ProduceRequest) (
|
||||
*api.ProduceResponse, error) {
|
||||
if err := s.Authorizer.Authorize(
|
||||
subject(ctx),
|
||||
objectWildCard,
|
||||
produceAction,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset, err := s.CommitLog.Append(req.Record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -47,6 +78,13 @@ func (s *grpcServer) Produce(ctx context.Context, req *api.ProduceRequest) (
|
|||
|
||||
func (s *grpcServer) Consume(ctx context.Context, req *api.ConsumeRequest) (
|
||||
*api.ConsumeResponse, error) {
|
||||
if err := s.Authorizer.Authorize(
|
||||
subject(ctx),
|
||||
objectWildCard,
|
||||
consumeAction,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record, err := s.CommitLog.Read(req.Offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -99,7 +137,35 @@ func (s *grpcServer) ConsumeStream(
|
|||
}
|
||||
}
|
||||
|
||||
func authenticate(ctx context.Context) (context.Context, error) {
|
||||
peer, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return ctx, status.New(
|
||||
codes.Unknown,
|
||||
"could not find peer info",
|
||||
).Err()
|
||||
}
|
||||
if peer.AuthInfo == nil {
|
||||
return context.WithValue(ctx, subjectContextKey{}, ""), nil
|
||||
}
|
||||
tlsInfo := peer.AuthInfo.(credentials.TLSInfo)
|
||||
subject := tlsInfo.State.VerifiedChains[0][0].Subject.CommonName
|
||||
ctx = context.WithValue(ctx, subjectContextKey{}, subject)
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func subject(ctx context.Context) string {
|
||||
return ctx.Value(subjectContextKey{}).(string)
|
||||
}
|
||||
|
||||
type subjectContextKey struct{}
|
||||
|
||||
type CommitLog interface {
|
||||
Append(*api.Record) (uint64, error)
|
||||
Read(uint64) (*api.Record, error)
|
||||
}
|
||||
|
||||
type Authorizer interface {
|
||||
Authorize(subject, object, action string) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,10 +7,12 @@ import (
|
|||
"testing"
|
||||
|
||||
api "github.com/AYM1607/proglog/api/v1"
|
||||
"github.com/AYM1607/proglog/internal/auth"
|
||||
"github.com/AYM1607/proglog/internal/config"
|
||||
"github.com/AYM1607/proglog/internal/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
|
@ -25,6 +27,7 @@ func TestServer(t *testing.T) {
|
|||
"produce/consume a message to/from the log succeeds": testProduceConsume,
|
||||
"produce/consume stream succeeds": testProduceConsumeStream,
|
||||
"consume past a log boundary fails": testConsumePastBoundary,
|
||||
"unauthorized fails": testUnauthorized,
|
||||
} {
|
||||
t.Run(scenario, func(t *testing.T) {
|
||||
rootClient,
|
||||
|
|
@ -96,8 +99,10 @@ func setupTest(t *testing.T, fn func(*Config)) (
|
|||
clog, err := log.NewLog(dir, log.Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
authorizer := auth.New(config.ACLModelFile, config.ACLPolicyFile)
|
||||
cfg = &Config{
|
||||
CommitLog: clog,
|
||||
CommitLog: clog,
|
||||
Authorizer: authorizer,
|
||||
}
|
||||
if fn != nil {
|
||||
fn(cfg)
|
||||
|
|
@ -219,3 +224,32 @@ func testProduceConsumeStream(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testUnauthorized(
|
||||
t *testing.T,
|
||||
_,
|
||||
client api.LogClient,
|
||||
config *Config,
|
||||
) {
|
||||
ctx := context.Background()
|
||||
|
||||
produce, err := client.Produce(ctx,
|
||||
&api.ProduceRequest{
|
||||
Record: &api.Record{
|
||||
Value: []byte("hello world"),
|
||||
},
|
||||
},
|
||||
)
|
||||
require.Nil(t, produce, "produce response should be nil")
|
||||
gotCode, wantCode := status.Code(err), codes.PermissionDenied
|
||||
require.Equal(t, wantCode, gotCode,
|
||||
"produce error code when client is unauthorized should be permission denied")
|
||||
|
||||
consume, err := client.Consume(ctx, &api.ConsumeRequest{
|
||||
Offset: 0,
|
||||
})
|
||||
require.Nil(t, consume, "consume response should be nil")
|
||||
gotCode, wantCode = status.Code(err), codes.PermissionDenied
|
||||
require.Equal(t, wantCode, gotCode,
|
||||
"consume error code when client is unauthorized should be permission denied")
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue