proglog/internal/server/server_test.go

307 lines
7.2 KiB
Go
Raw Normal View History

package server
import (
"context"
2021-08-13 14:46:01 -05:00
"flag"
"io/ioutil"
"net"
2021-08-13 14:46:01 -05:00
"os"
"testing"
2021-08-13 14:46:01 -05:00
"time"
api "github.com/AYM1607/proglog/api/v1"
"github.com/AYM1607/proglog/internal/auth"
"github.com/AYM1607/proglog/internal/config"
"github.com/AYM1607/proglog/internal/log"
"github.com/stretchr/testify/require"
2021-08-13 14:46:01 -05:00
"go.opencensus.io/examples/exporter"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
)
2021-08-13 14:46:01 -05:00
// debug turns on verbose observability (a development zap logger and, in
// setupTest, an OpenCensus log exporter) when the tests run with -debug.
var debug = flag.Bool("debug", false, "Enable observability for debugging.")

// TestMain parses the test flags and, when -debug is set, installs a zap
// development logger as the global logger before running the package's tests.
// It panics if that logger cannot be built.
func TestMain(m *testing.M) {
	flag.Parse()
	if *debug {
		devLogger, buildErr := zap.NewDevelopment()
		if buildErr != nil {
			panic(buildErr)
		}
		zap.ReplaceGlobals(devLogger)
	}
	exitCode := m.Run()
	os.Exit(exitCode)
}
// TestServer runs each server scenario as a subtest. Every scenario gets its
// own freshly set-up server (via setupTest) and receives the root client, the
// nobody client and the server config; the server is torn down when the
// subtest ends.
func TestServer(t *testing.T) {
	scenarios := map[string]func(
		t *testing.T,
		rootClient api.LogClient,
		nobodyClient api.LogClient,
		config *Config,
	){
		"produce/consume a message to/from the log succeeds": testProduceConsume,
		"produce/consume stream succeeds":                    testProduceConsumeStream,
		"consume past a log boundary fails":                  testConsumePastBoundary,
		"unauthorized fails":                                 testUnauthorized,
	}
	for name, scenario := range scenarios {
		t.Run(name, func(t *testing.T) {
			// Named cfg (not config) so the config package is not shadowed.
			root, nobody, cfg, teardown := setupTest(t, nil)
			defer teardown()
			scenario(t, root, nobody, cfg)
		})
	}
}
func setupTest(t *testing.T, fn func(*Config)) (
rootClient api.LogClient,
nobodyClient api.LogClient,
cfg *Config,
teardown func(),
) {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
newClient := func(crtPath, keyPath string) (
*grpc.ClientConn,
api.LogClient,
[]grpc.DialOption,
) {
tlsConfig, err := config.SetupTLSConfig(config.TLSConfig{
CertFile: crtPath,
KeyFile: keyPath,
CAFile: config.CAFile,
Server: false,
})
require.NoError(t, err)
tlsCreds := credentials.NewTLS(tlsConfig)
opts := []grpc.DialOption{grpc.WithTransportCredentials(tlsCreds)}
conn, err := grpc.Dial(l.Addr().String(), opts...)
require.NoError(t, err)
client := api.NewLogClient(conn)
return conn, client, opts
}
2021-08-13 14:46:01 -05:00
// TODO: research relation of closures and shorthand variable declaration.
// If the connection is on created with var, traces don't work.
var rootConn *grpc.ClientConn
rootConn, rootClient, _ = newClient(
config.RootClientCertFile,
config.RootClientKeyFile,
)
2021-08-13 14:46:01 -05:00
// If the connection is on created with var, traces don't work.
var nobodyConn *grpc.ClientConn
nobodyConn, nobodyClient, _ = newClient(
config.NobodyClientCertFile,
config.NobodyClientKeyFile,
)
// Server config.
serverTLSConfig, err := config.SetupTLSConfig(config.TLSConfig{
2021-08-13 14:46:01 -05:00
CertFile: config.ServerCertFile,
KeyFile: config.ServerKeyFile,
CAFile: config.CAFile,
Server: true,
})
require.NoError(t, err)
serverCreds := credentials.NewTLS(serverTLSConfig)
dir, err := ioutil.TempDir("", "server-test")
require.NoError(t, err)
2021-08-13 14:46:01 -05:00
defer os.RemoveAll(dir)
clog, err := log.NewLog(dir, log.Config{})
require.NoError(t, err)
authorizer := auth.New(config.ACLModelFile, config.ACLPolicyFile)
2021-08-13 14:46:01 -05:00
var telemetryExporter *exporter.LogExporter
if *debug {
metricsLogFile, err := ioutil.TempFile("", "metrics-*.log")
require.NoError(t, err)
t.Logf("metrics log file: %s", metricsLogFile.Name())
tracesLogFile, err := ioutil.TempFile("", "traces-*.log")
require.NoError(t, err)
t.Logf("traces log file: %s", tracesLogFile.Name())
telemetryExporter, err = exporter.NewLogExporter(exporter.Options{
MetricsLogFile: metricsLogFile.Name(),
TracesLogFile: tracesLogFile.Name(),
ReportingInterval: time.Second,
})
require.NoError(t, err)
err = telemetryExporter.Start()
require.NoError(t, err)
}
cfg = &Config{
CommitLog: clog,
Authorizer: authorizer,
}
if fn != nil {
fn(cfg)
}
server, err := NewGRPCServer(cfg, grpc.Creds(serverCreds))
require.NoError(t, err)
go func() {
server.Serve(l)
}()
return rootClient, nobodyClient, cfg, func() {
server.Stop()
rootConn.Close()
nobodyConn.Close()
l.Close()
clog.Remove()
2021-08-13 14:46:01 -05:00
if telemetryExporter != nil {
time.Sleep(2000 * time.Millisecond)
telemetryExporter.Stop()
telemetryExporter.Close()
}
}
}
// testProduceConsume produces a single record through client, then consumes
// the offset the server reported. It checks that both the offset and the
// value round-trip unchanged.
func testProduceConsume(t *testing.T, client, _ api.LogClient, config *Config) {
	ctx := context.Background()
	want := &api.Record{Value: []byte("hello world")}

	produceReq := &api.ProduceRequest{Record: want}
	produced, err := client.Produce(ctx, produceReq)
	require.NoError(t, err)

	consumeReq := &api.ConsumeRequest{Offset: produced.Offset}
	consumed, err := client.Consume(ctx, consumeReq)
	require.NoError(t, err)

	require.Equal(t, produced.Offset, consumed.Record.Offset)
	require.Equal(t, want.Value, consumed.Record.Value)
}
// testConsumePastBoundary produces one record and then consumes the offset
// just past it. It expects no response and an error whose gRPC status code
// equals that of api.ErrOffsetOutOfRange.
func testConsumePastBoundary(
	t *testing.T,
	client, _ api.LogClient,
	config *Config,
) {
	ctx := context.Background()

	produceReq := &api.ProduceRequest{
		Record: &api.Record{Value: []byte("hello world")},
	}
	produced, err := client.Produce(ctx, produceReq)
	require.NoError(t, err)

	pastEnd := produced.Offset + 1
	consumed, err := client.Consume(ctx, &api.ConsumeRequest{Offset: pastEnd})
	require.Nil(t, consumed, "consume should be nil")

	gotCode := status.Code(err)
	wantCode := status.Code(api.ErrOffsetOutOfRange{}.GRPCStatus().Err())
	require.Equal(t, wantCode, gotCode)
}
// testProduceConsumeStream produces two records over ProduceStream and checks
// the offset returned for each. It then reads them back from offset 0 over
// ConsumeStream and checks that each received record matches what was sent.
func testProduceConsumeStream(
	t *testing.T,
	client, _ api.LogClient,
	config *Config,
) {
	// Cancelling on return closes the consume stream. With a bare
	// context.Background(), nothing on the client side would ever end it.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	records := []*api.Record{{
		Value:  []byte("first message"),
		Offset: 0,
	}, {
		Value:  []byte("second message"),
		Offset: 1,
	}}

	// Each phase gets its own block because both declare `stream` (with a
	// different type) and `err` using :=. In a single shared scope the second
	// declaration would not compile as written.

	// Produce stream.
	{
		stream, err := client.ProduceStream(ctx)
		require.NoError(t, err)
		// The log is empty, so each record's slice index is also the offset
		// the server should assign to it.
		for offset, record := range records {
			err = stream.Send(&api.ProduceRequest{
				Record: record,
			})
			require.NoError(t, err)
			res, err := stream.Recv()
			require.NoError(t, err)
			require.Equal(t, uint64(offset), res.Offset)
		}
		// Half-close the stream to tell the server no more records follow.
		require.NoError(t, stream.CloseSend())
	}

	// Consume stream.
	{
		stream, err := client.ConsumeStream(
			ctx,
			&api.ConsumeRequest{Offset: 0},
		)
		require.NoError(t, err)
		for _, record := range records {
			res, err := stream.Recv()
			require.NoError(t, err)
			// Compare against a fresh literal rather than the sent record
			// itself. The original author observed that a direct comparison
			// fails, presumably because the sent messages carry internal
			// protobuf state (e.g. cached sizes from marshalling).
			require.Equal(t, &api.Record{
				Value:  record.Value,
				Offset: record.Offset,
			}, res.Record)
		}
	}
}
// testUnauthorized uses the second client handed in by TestServer (the nobody
// client). It expects both Produce and Consume to return no response and an
// error with code codes.PermissionDenied.
func testUnauthorized(
	t *testing.T,
	_,
	client api.LogClient,
	config *Config,
) {
	ctx := context.Background()
	const wantCode = codes.PermissionDenied

	produced, err := client.Produce(ctx, &api.ProduceRequest{
		Record: &api.Record{Value: []byte("hello world")},
	})
	require.Nil(t, produced, "produce response should be nil")
	require.Equal(t, wantCode, status.Code(err),
		"produce error code when client is unauthorized should be permission denied")

	consumed, err := client.Consume(ctx, &api.ConsumeRequest{Offset: 0})
	require.Nil(t, consumed, "consume response should be nil")
	require.Equal(t, wantCode, status.Code(err),
		"consume error code when client is unauthorized should be permission denied")
}