diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go index 08b1ac7..51c78c0 100644 --- a/pkg/ast/ast.go +++ b/pkg/ast/ast.go @@ -158,3 +158,16 @@ func (ie *InfixExpression) TokenLiteral() string { func (ie *InfixExpression) String() string { return "(" + ie.Left.String() + " " + ie.Operator + " " + ie.Right.String() + ")" } + +type Boolean struct { + Token token.Token + Value bool +} + +func (bl *Boolean) expressionNode() {} +func (bl *Boolean) TokenLiteral() string { + return bl.Token.Literal +} +func (bl *Boolean) String() string { + return bl.Token.Literal +} diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 2fec255..1ba76c2 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -35,6 +35,8 @@ func New(l *lexer.Lexer) *Parser { p.registerPrefix(token.INT, p.parseIntegerLiteral) p.registerPrefix(token.MINUS, p.parsePrefixExpression) p.registerPrefix(token.BANG, p.parsePrefixExpression) + p.registerPrefix(token.TRUE, p.parseBoolean) + p.registerPrefix(token.FALSE, p.parseBoolean) // Infix registrations p.registerInfix(token.PLUS, p.parseInfixExpression) p.registerInfix(token.MINUS, p.parseInfixExpression) @@ -148,6 +150,10 @@ func (p *Parser) parseIntegerLiteral() ast.Expression { return exp } +func (p *Parser) parseBoolean() ast.Expression { + return &ast.Boolean{Token: p.curToken, Value: p.curTokenIs(token.TRUE)} +} + func (p *Parser) parsePrefixExpression() ast.Expression { exp := &ast.PrefixExpression{ Token: p.curToken, diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go index 48002f3..b96dfbe 100644 --- a/pkg/parser/parser_test.go +++ b/pkg/parser/parser_test.go @@ -161,12 +161,14 @@ func TestIntegerLiteralExpression(t *testing.T) { func TestParsingPrefixExpressions(t *testing.T) { prefixTests := []struct { - input string - operator string - integerValue int64 + input string + operator string + value any }{ {"!5;", "!", 5}, {"-15;", "-", 15}, + {"!true;", "!", true}, + {"!false;", "!", false}, } for _, tt := range prefixTests { @@ -194,7 +196,7 @@ func TestParsingPrefixExpressions(t *testing.T) { t.Fatalf("exp.Operator is not '%s'. got=%s", tt.operator, exp.Operator) } - if !testIntegerLiteral(t, exp.Right, tt.integerValue) { + if !testLiteralExpression(t, exp.Right, tt.value) { return } } @@ -205,9 +207,9 @@ func TestParsingPrefixExpressions(t *testing.T) { func TestParsingInfixExpressions(t *testing.T) { infixTests := []struct { input string - leftValue int64 + leftValue any operator string - rightValue int64 + rightValue any }{ {"5 + 5;", 5, "+", 5}, {"5 - 5;", 5, "-", 5}, @@ -217,6 +219,9 @@ func TestParsingInfixExpressions(t *testing.T) { {"5 < 5;", 5, "<", 5}, {"5 == 5;", 5, "==", 5}, {"5 != 5;", 5, "!=", 5}, + {"true == true", true, "==", true}, + {"true != false", true, "!=", false}, + {"false == false", false, "==", false}, } for _, tt := range infixTests { @@ -236,21 +241,7 @@ func TestParsingInfixExpressions(t *testing.T) { program.Statements[0]) } - exp, ok := stmt.Expression.(*ast.InfixExpression) - if !ok { - t.Fatalf("exp is not ast.InfixExpression. got=%T", stmt.Expression) - } - - if !testIntegerLiteral(t, exp.Left, tt.leftValue) { - return - } - - if exp.Operator != tt.operator { - t.Fatalf("exp.Operator is not '%s'. got=%s", - tt.operator, exp.Operator) - } - - if !testIntegerLiteral(t, exp.Right, tt.rightValue) { + if !testInfixExpression(t, stmt.Expression, tt.leftValue, tt.operator, tt.rightValue) { return } } @@ -309,6 +300,22 @@ func TestOperatorPrecedenceParsing(t *testing.T) { "3 + 4 * 5 == 3 * 1 + 4 * 5", "((3 + (4 * 5)) == ((3 * 1) + (4 * 5)))", }, + { + "true", + "true", + }, + { + "false", + "false", + }, + { + "3 > 5 == false", + "((3 > 5) == false)", + }, + { + "3 < 5 == true", + "((3 < 5) == true)", + }, } for _, tt := range tests { @@ -324,6 +331,129 @@ func TestOperatorPrecedenceParsing(t *testing.T) { } } +func TestBooleanExpression(t *testing.T) { + tests := []struct { + input string + expectedBoolean bool + }{ + {"true;", true}, + {"false;", false}, + } + + for _, tt := range tests { + l := lexer.New(tt.input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program has not enough statements. got=%d", + len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T", + program.Statements[0]) + } + + boolean, ok := stmt.Expression.(*ast.Boolean) + if !ok { + t.Fatalf("exp not *ast.Boolean. got=%T", stmt.Expression) + } + if boolean.Value != tt.expectedBoolean { + t.Errorf("boolean.Value not %t. got=%t", tt.expectedBoolean, + boolean.Value) + } + } +} + +func testIdentifier(t *testing.T, exp ast.Expression, value string) bool { + ident, ok := exp.(*ast.Identifier) + if !ok { + t.Errorf("exp not *ast.Identifier. got=%T", exp) + return false + } + + if ident.Value != value { + t.Errorf("ident.Value not %s. got=%s", value, ident.Value) + return false + } + + if ident.TokenLiteral() != value { + t.Errorf("ident.TokenLiteral not %s. got=%s", value, + ident.TokenLiteral()) + return false + } + + return true +} + +func testBooleanLiteral(t *testing.T, exp ast.Expression, value bool) bool { + bo, ok := exp.(*ast.Boolean) + if !ok { + t.Errorf("exp not *ast.Boolean. got=%T", exp) + return false + } + + if bo.Value != value { + t.Errorf("bo.Value not %t. got=%t", value, bo.Value) + return false + } + + if bo.TokenLiteral() != fmt.Sprintf("%t", value) { + t.Errorf("bo.TokenLiteral not %t. got=%s", + value, bo.TokenLiteral()) + return false + } + + return true +} + +func testLiteralExpression( + t *testing.T, + exp ast.Expression, + expected any, +) bool { + switch v := expected.(type) { + case int: + return testIntegerLiteral(t, exp, int64(v)) + case int64: + return testIntegerLiteral(t, exp, v) + case string: + return testIdentifier(t, exp, v) + case bool: + return testBooleanLiteral(t, exp, v) + } + t.Errorf("type of exp not handled. got=%T", exp) + return false +} + +func testInfixExpression(t *testing.T, exp ast.Expression, left any, + operator string, right any) bool { + + opExp, ok := exp.(*ast.InfixExpression) + if !ok { + t.Errorf("exp is not ast.InfixExpression. got=%T(%s)", exp, exp) + return false + } + + if !testLiteralExpression(t, opExp.Left, left) { + return false + } + + if opExp.Operator != operator { + t.Errorf("exp.Operator is not '%s'. got=%q", operator, opExp.Operator) + return false + } + + if !testLiteralExpression(t, opExp.Right, right) { + return false + } + + return true +} + func testIntegerLiteral(t *testing.T, il ast.Expression, value int64) bool { integ, ok := il.(*ast.IntegerLiteral) if !ok {