From fab55720e8638f5487edf5885b3812fbd2e86a6e Mon Sep 17 00:00:00 2001 From: AYM1607 Date: Wed, 11 Aug 2021 19:39:55 -0500 Subject: [PATCH] Modifies the tests to create 2 clients, one with each cert. --- internal/server/server_test.go | 65 ++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 6bf041a..d7cadde 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -18,7 +18,8 @@ import ( func TestServer(t *testing.T) { for scenario, fn := range map[string]func( t *testing.T, - client api.LogClient, + rootClient api.LogClient, + nobodyClient api.LogClient, config *Config, ){ "produce/consume a message to/from the log succeeds": testProduceConsume, @@ -26,15 +27,19 @@ func TestServer(t *testing.T) { "consume past a log boundary fails": testConsumePastBoundary, } { t.Run(scenario, func(t *testing.T) { - client, config, teardown := setupTest(t, nil) + rootClient, + nobodyClient, + config, + teardown := setupTest(t, nil) defer teardown() - fn(t, client, config) + fn(t, rootClient, nobodyClient, config) }) } } func setupTest(t *testing.T, fn func(*Config)) ( - client api.LogClient, + rootClient api.LogClient, + nobodyClient api.LogClient, cfg *Config, teardown func(), ) { @@ -43,21 +48,36 @@ func setupTest(t *testing.T, fn func(*Config)) ( l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) - // Client config. - clientTLSConfig, err := config.SetupTLSConfig(config.TLSConfig{ - CertFile: config.ClientCertFile, - KeyFile: config.ClientKeyFile, - CAFile: config.CAFile, - }) - require.NoError(t, err) + newClient := func(crtPath, keyPath string) ( + *grpc.ClientConn, + api.LogClient, + []grpc.DialOption, + ) { + tlsConfig, err := config.SetupTLSConfig(config.TLSConfig{ + CertFile: crtPath, + KeyFile: keyPath, + CAFile: config.CAFile, + Server: false, + }) + require.NoError(t, err) - clientCreds := credentials.NewTLS(clientTLSConfig) - cc, err := grpc.Dial( - l.Addr().String(), - grpc.WithTransportCredentials(clientCreds), + tlsCreds := credentials.NewTLS(tlsConfig) + opts := []grpc.DialOption{grpc.WithTransportCredentials(tlsCreds)} + conn, err := grpc.Dial(l.Addr().String(), opts...) + require.NoError(t, err) + client := api.NewLogClient(conn) + return conn, client, opts + } + + rootConn, rootClient, _ := newClient( + config.RootClientCertFile, + config.RootClientKeyFile, + ) + + nobodyConn, nobodyClient, _ := newClient( + config.NobodyClientCertFile, + config.NobodyClientKeyFile, ) - require.NoError(t, err) - client = api.NewLogClient(cc) // Server config. serverTLSConfig, err := config.SetupTLSConfig(config.TLSConfig{ @@ -89,15 +109,16 @@ func setupTest(t *testing.T, fn func(*Config)) ( server.Serve(l) }() - return client, cfg, func() { + return rootClient, nobodyClient, cfg, func() { server.Stop() - cc.Close() + rootConn.Close() + nobodyConn.Close() l.Close() clog.Remove() } } -func testProduceConsume(t *testing.T, client api.LogClient, config *Config) { +func testProduceConsume(t *testing.T, client, _ api.LogClient, config *Config) { ctx := context.Background() want := &api.Record{ @@ -123,7 +144,7 @@ func testProduceConsume(t *testing.T, client api.LogClient, config *Config) { func testConsumePastBoundary( t *testing.T, - client api.LogClient, + client, _ api.LogClient, config *Config, ) { ctx := context.Background() @@ -146,7 +167,7 @@ func testConsumePastBoundary( func testProduceConsumeStream( t *testing.T, - client api.LogClient, + client, _ api.LogClient, config *Config, ) { ctx := context.Background()