diff --git a/pkg/ast/call.go b/pkg/ast/call.go new file mode 100644 index 0000000..5d87790 --- /dev/null +++ b/pkg/ast/call.go @@ -0,0 +1,31 @@ +package ast + +import ( + "bytes" + "strings" + + "code.jmug.me/jmug/interpreter-in-go/pkg/token" +) + +type CallExpression struct { + Token token.Token // The ( token + Function Expression + Arguments []Expression +} + +func (ce *CallExpression) expressionNode() {} +func (ce *CallExpression) TokenLiteral() string { + return ce.Token.Literal +} +func (ce *CallExpression) String() string { + var out bytes.Buffer + out.WriteString(ce.Function.String()) + out.WriteString("(") + args := []string{} + for _, arg := range ce.Arguments { + args = append(args, arg.String()) + } + out.WriteString(strings.Join(args, ", ")) + out.WriteString(")") + return out.String() +} diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index f7a985e..4fa3f85 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -49,6 +49,7 @@ func New(l *lexer.Lexer) *Parser { p.registerInfix(token.LT, p.parseInfixExpression) p.registerInfix(token.EQ, p.parseInfixExpression) p.registerInfix(token.NOT_EQ, p.parseInfixExpression) + p.registerInfix(token.LPAREN, p.parseCallExpression) // TODO: figure out why this can't be done from `parseProgram` p.nextToken() p.nextToken() @@ -90,7 +91,7 @@ func (p *Parser) parseBlockStatement() *ast.BlockStatement { if stmt != nil { block.Statements = append(block.Statements, stmt) } - // Consume the semicolon. + // Consume the last token in the statement. p.nextToken() } return block @@ -105,8 +106,10 @@ func (p *Parser) parseLetStatement() ast.Statement { if !p.nextTokenIfPeekIs(token.ASSIGN) { return nil } - // TODO: Skipping until we find the semicolon to avoid parsing the expression. - for !p.curTokenIs(token.SEMICOLON) { + // Consume the assign. + p.nextToken() + stmt.Value = p.parseExpression(LOWEST) + if p.peekTokenIs(token.SEMICOLON) { p.nextToken() } return stmt @@ -115,8 +118,8 @@ func (p *Parser) parseLetStatement() ast.Statement { func (p *Parser) parseReturnStatement() ast.Statement { stmt := &ast.ReturnStatement{Token: p.curToken} p.nextToken() - // TODO: Skipping until we find the semicolon to avoid parsing the expression. - for !p.curTokenIs(token.SEMICOLON) { + stmt.ReturnValue = p.parseExpression(LOWEST) + if p.peekTokenIs(token.SEMICOLON) { p.nextToken() } return stmt @@ -272,6 +275,35 @@ func (p *Parser) parseFunctionParameters() []*ast.Identifier { return params } +func (p *Parser) parseCallExpression(function ast.Expression) ast.Expression { + call := &ast.CallExpression{Token: p.curToken, Function: function} + call.Arguments = p.parseCallArguments() + return call +} + +func (p *Parser) parseCallArguments() []ast.Expression { + args := []ast.Expression{} + if p.peekTokenIs(token.RPAREN) { + p.nextToken() + return args + } + // Consume the LPAREN + p.nextToken() + args = append(args, p.parseExpression(LOWEST)) + for p.peekTokenIs(token.COMMA) { + // Consume last token of the previous expression. + p.nextToken() + // Consume the comma. + p.nextToken() + args = append(args, p.parseExpression(LOWEST)) + } + if !p.nextTokenIfPeekIs(token.RPAREN) { + // TODO: Would be good to emit an error here. + return nil + } + return args +} + func (p *Parser) curTokenIs(typ token.TokenType) bool { return p.curToken.Type == typ } diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go index 46094cc..989115a 100644 --- a/pkg/parser/parser_test.go +++ b/pkg/parser/parser_test.go @@ -9,37 +9,36 @@ import ( ) func TestLetStatements(t *testing.T) { - input := ` -let x = 5; -let y = 10; -let foobar = 838383; - ` - l := lexer.New(input) - p := New(l) - - program := p.ParseProgram() - checkParserErrors(t, p) - if program == nil { - t.Fatalf("ParseProgram() returned nil") - } - if len(program.Statements) != 3 { - t.Fatalf("program.Statements does not contain 3 statements. got=%d", - len(program.Statements)) - } - tests := []struct { + input string expectedIdentifier string + expectedValue any }{ - {"x"}, - {"y"}, - {"foobar"}, + {"let x = 5;", "x", 5}, + {"let y = true;", "y", true}, + {"let foobar = y;", "foobar", "y"}, } - for i, tt := range tests { - stmt := program.Statements[i] + for _, tt := range tests { + l := lexer.New(tt.input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain 1 statements. got=%d", + len(program.Statements)) + } + + stmt := program.Statements[0] if !testLetStatement(t, stmt, tt.expectedIdentifier) { return } + + val := stmt.(*ast.LetStatement).Value + if !testLiteralExpression(t, val, tt.expectedValue) { + return + } } } @@ -70,32 +69,38 @@ func testLetStatement(t *testing.T, s ast.Statement, name string) bool { } func TestReturnStatements(t *testing.T) { - input := ` -return 5; -return 10; -return 993322; -` - l := lexer.New(input) - p := New(l) - - program := p.ParseProgram() - checkParserErrors(t, p) - - if len(program.Statements) != 3 { - t.Fatalf("program.Statements does not contain 3 statements. got=%d", - len(program.Statements)) + tests := []struct { + input string + expectedValue interface{} + }{ + {"return 5;", 5}, + {"return true;", true}, + {"return foobar;", "foobar"}, } - for _, stmt := range program.Statements { + for _, tt := range tests { + l := lexer.New(tt.input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain 1 statements. got=%d", + len(program.Statements)) + } + + stmt := program.Statements[0] returnStmt, ok := stmt.(*ast.ReturnStatement) if !ok { - t.Errorf("stmt not *ast.ReturnStatement. got=%T", stmt) - continue + t.Fatalf("stmt not *ast.ReturnStatement. got=%T", stmt) } if returnStmt.TokenLiteral() != "return" { - t.Errorf("returnStmt.TokenLiteral not 'return', got %q", + t.Fatalf("returnStmt.TokenLiteral not 'return', got %q", returnStmt.TokenLiteral()) } + if testLiteralExpression(t, returnStmt.ReturnValue, tt.expectedValue) { + return + } } } @@ -336,6 +341,18 @@ func TestOperatorPrecedenceParsing(t *testing.T) { "!(true == true)", "(!(true == true))", }, + { + "a + add(b * c) + d", + "((a + add((b * c))) + d)", + }, + { + "add(a, b, 1, 2 * 3, 4 + 5, add(6, 7 * 8))", + "add(a, b, 1, (2 * 3), (4 + 5), add(6, (7 * 8)))", + }, + { + "add(a + b + c * d / f + g)", + "add((((a + b) + ((c * d) / f)) + g))", + }, } for _, tt := range tests { @@ -573,6 +590,44 @@ func TestFunctionParameterParsing(t *testing.T) { } } +func TestCallExpressionParsing(t *testing.T) { + input := "add(1, 2 * 3, 4 + 5);" + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain %d statements. got=%d\n", + 1, len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("stmt is not ast.ExpressionStatement. got=%T", + program.Statements[0]) + } + + exp, ok := stmt.Expression.(*ast.CallExpression) + if !ok { + t.Fatalf("stmt.Expression is not ast.CallExpression. got=%T", + stmt.Expression) + } + + if !testIdentifier(t, exp.Function, "add") { + return + } + + if len(exp.Arguments) != 3 { + t.Fatalf("wrong length of arguments. got=%d", len(exp.Arguments)) + } + + testLiteralExpression(t, exp.Arguments[0], 1) + testInfixExpression(t, exp.Arguments[1], 2, "*", 3) + testInfixExpression(t, exp.Arguments[2], 4, "+", 5) +} + func testIdentifier(t *testing.T, exp ast.Expression, value string) bool { ident, ok := exp.(*ast.Identifier) if !ok { diff --git a/pkg/parser/precedence.go b/pkg/parser/precedence.go index 595bfb5..cd93b22 100644 --- a/pkg/parser/precedence.go +++ b/pkg/parser/precedence.go @@ -24,6 +24,7 @@ var precedences = map[token.TokenType]int{ token.MINUS: SUM, token.ASTERISK: PRODUCT, token.SLASH: PRODUCT, + token.LPAREN: CALL, } func (p *Parser) peekPrecedence() int {