Modifies the tests to create 2 clients, one with each cert.

This commit is contained in:
Mariano Uvalle 2021-08-11 19:39:55 -05:00
parent 4634968521
commit fab55720e8

View file

@ -18,7 +18,8 @@ import (
func TestServer(t *testing.T) {
for scenario, fn := range map[string]func(
t *testing.T,
client api.LogClient,
rootClient api.LogClient,
nobodyClient api.LogClient,
config *Config,
){
"produce/consume a message to/from the log succeeds": testProduceConsume,
@ -26,15 +27,19 @@ func TestServer(t *testing.T) {
"consume past a log boundary fails": testConsumePastBoundary,
} {
t.Run(scenario, func(t *testing.T) {
client, config, teardown := setupTest(t, nil)
rootClient,
nobodyClient,
config,
teardown := setupTest(t, nil)
defer teardown()
fn(t, client, config)
fn(t, rootClient, nobodyClient, config)
})
}
}
func setupTest(t *testing.T, fn func(*Config)) (
client api.LogClient,
rootClient api.LogClient,
nobodyClient api.LogClient,
cfg *Config,
teardown func(),
) {
@ -43,21 +48,36 @@ func setupTest(t *testing.T, fn func(*Config)) (
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
// Client config.
clientTLSConfig, err := config.SetupTLSConfig(config.TLSConfig{
CertFile: config.ClientCertFile,
KeyFile: config.ClientKeyFile,
CAFile: config.CAFile,
})
require.NoError(t, err)
newClient := func(crtPath, keyPath string) (
*grpc.ClientConn,
api.LogClient,
[]grpc.DialOption,
) {
tlsConfig, err := config.SetupTLSConfig(config.TLSConfig{
CertFile: crtPath,
KeyFile: keyPath,
CAFile: config.CAFile,
Server: false,
})
require.NoError(t, err)
clientCreds := credentials.NewTLS(clientTLSConfig)
cc, err := grpc.Dial(
l.Addr().String(),
grpc.WithTransportCredentials(clientCreds),
tlsCreds := credentials.NewTLS(tlsConfig)
opts := []grpc.DialOption{grpc.WithTransportCredentials(tlsCreds)}
conn, err := grpc.Dial(l.Addr().String(), opts...)
require.NoError(t, err)
client := api.NewLogClient(conn)
return conn, client, opts
}
rootConn, rootClient, _ := newClient(
config.RootClientCertFile,
config.RootClientKeyFile,
)
nobodyConn, nobodyClient, _ := newClient(
config.NobodyClientCertFile,
config.NobodyClientKeyFile,
)
require.NoError(t, err)
client = api.NewLogClient(cc)
// Server config.
serverTLSConfig, err := config.SetupTLSConfig(config.TLSConfig{
@ -89,15 +109,16 @@ func setupTest(t *testing.T, fn func(*Config)) (
server.Serve(l)
}()
return client, cfg, func() {
return rootClient, nobodyClient, cfg, func() {
server.Stop()
cc.Close()
rootConn.Close()
nobodyConn.Close()
l.Close()
clog.Remove()
}
}
func testProduceConsume(t *testing.T, client api.LogClient, config *Config) {
func testProduceConsume(t *testing.T, client, _ api.LogClient, config *Config) {
ctx := context.Background()
want := &api.Record{
@ -123,7 +144,7 @@ func testProduceConsume(t *testing.T, client api.LogClient, config *Config) {
func testConsumePastBoundary(
t *testing.T,
client api.LogClient,
client, _ api.LogClient,
config *Config,
) {
ctx := context.Background()
@ -146,7 +167,7 @@ func testConsumePastBoundary(
func testProduceConsumeStream(
t *testing.T,
client api.LogClient,
client, _ api.LogClient,
config *Config,
) {
ctx := context.Background()