diff --git a/pkg/ast/function.go b/pkg/ast/function.go
new file mode 100644
index 0000000..8cdfb9e
--- /dev/null
+++ b/pkg/ast/function.go
@@ -0,0 +1,32 @@
+package ast
+
+import (
+	"bytes"
+	"strings"
+
+	"code.jmug.me/jmug/interpreter-in-go/pkg/token"
+)
+
+type FunctionLiteral struct {
+	Token      token.Token // The fn token
+	Parameters []*Identifier
+	Body       *BlockStatement
+}
+
+func (fl *FunctionLiteral) expressionNode() {}
+func (fl *FunctionLiteral) TokenLiteral() string {
+	return fl.Token.Literal
+}
+func (fl *FunctionLiteral) String() string {
+	var out bytes.Buffer
+	params := []string{}
+	for _, p := range fl.Parameters {
+		params = append(params, p.String())
+	}
+	out.WriteString(fl.TokenLiteral())
+	out.WriteString("(")
+	out.WriteString(strings.Join(params, ", "))
+	out.WriteString(") ")
+	out.WriteString(fl.Body.String())
+	return out.String()
+}
diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go
index 45a220d..f7a985e 100644
--- a/pkg/parser/parser.go
+++ b/pkg/parser/parser.go
@@ -39,6 +39,7 @@ func New(l *lexer.Lexer) *Parser {
 	p.registerPrefix(token.FALSE, p.parseBoolean)
 	p.registerPrefix(token.LPAREN, p.parseGroupedExpression)
 	p.registerPrefix(token.IF, p.parseIfExpression)
+	p.registerPrefix(token.FUNCTION, p.parseFunctionLiteral)
 	// Infix registrations
 	p.registerInfix(token.PLUS, p.parseInfixExpression)
 	p.registerInfix(token.MINUS, p.parseInfixExpression)
@@ -233,6 +234,44 @@ func (p *Parser) parseIfExpression() ast.Expression {
 	return exp
 }
 
+func (p *Parser) parseFunctionLiteral() ast.Expression {
+	fn := &ast.FunctionLiteral{Token: p.curToken}
+	if !p.nextTokenIfPeekIs(token.LPAREN) {
+		// TODO: Would be good to emit an error here.
+		return nil
+	}
+	fn.Parameters = p.parseFunctionParameters()
+	if !p.nextTokenIfPeekIs(token.LBRACE) {
+		// TODO: Would be good to emit an error here.
+		return nil
+	}
+	fn.Body = p.parseBlockStatement()
+	return fn
+}
+
+func (p *Parser) parseFunctionParameters() []*ast.Identifier {
+	params := []*ast.Identifier{}
+	if p.peekTokenIs(token.RPAREN) {
+		p.nextToken()
+		return params
+	}
+	// Consume the LPAREN.
+	p.nextToken()
+	params = append(params, &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal})
+	for p.peekTokenIs(token.COMMA) {
+		// Consume the previous identifier.
+		p.nextToken()
+		// Consume the comma.
+		p.nextToken()
+		params = append(params, &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal})
+	}
+	if !p.nextTokenIfPeekIs(token.RPAREN) {
+		// TODO: Would be good to emit an error here.
+		return nil
+	}
+	return params
+}
+
 func (p *Parser) curTokenIs(typ token.TokenType) bool {
 	return p.curToken.Type == typ
 }
diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go
index 0f200a9..46094cc 100644
--- a/pkg/parser/parser_test.go
+++ b/pkg/parser/parser_test.go
@@ -495,6 +495,84 @@ func TestIfElseExpression(t *testing.T) {
 		return
 	}
 }
+
+func TestFunctionLiteralParsing(t *testing.T) {
+	input := `fn(x, y) { x + y; }`
+
+	l := lexer.New(input)
+	p := New(l)
+	program := p.ParseProgram()
+	checkParserErrors(t, p)
+
+	if len(program.Statements) != 1 {
+		t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
+			1, len(program.Statements))
+	}
+
+	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
+	if !ok {
+		t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
+			program.Statements[0])
+	}
+
+	function, ok := stmt.Expression.(*ast.FunctionLiteral)
+	if !ok {
+		t.Fatalf("stmt.Expression is not ast.FunctionLiteral. got=%T",
+			stmt.Expression)
+	}
+
+	if len(function.Parameters) != 2 {
+		t.Fatalf("function literal parameters wrong. want 2, got=%d\n",
+			len(function.Parameters))
+	}
+
+	testLiteralExpression(t, function.Parameters[0], "x")
+	testLiteralExpression(t, function.Parameters[1], "y")
+
+	if len(function.Body.Statements) != 1 {
+		t.Fatalf("function.Body.Statements has not 1 statements. got=%d\n",
+			len(function.Body.Statements))
+	}
+
+	bodyStmt, ok := function.Body.Statements[0].(*ast.ExpressionStatement)
+	if !ok {
+		t.Fatalf("function body stmt is not ast.ExpressionStatement. got=%T",
+			function.Body.Statements[0])
+	}
+
+	testInfixExpression(t, bodyStmt.Expression, "x", "+", "y")
+}
+
+func TestFunctionParameterParsing(t *testing.T) {
+	tests := []struct {
+		input          string
+		expectedParams []string
+	}{
+		{input: "fn() {};", expectedParams: []string{}},
+		{input: "fn(x) {};", expectedParams: []string{"x"}},
+		{input: "fn(x, y, z) {};", expectedParams: []string{"x", "y", "z"}},
+	}
+
+	for _, tt := range tests {
+		l := lexer.New(tt.input)
+		p := New(l)
+		program := p.ParseProgram()
+		checkParserErrors(t, p)
+
+		stmt := program.Statements[0].(*ast.ExpressionStatement)
+		function := stmt.Expression.(*ast.FunctionLiteral)
+
+		if len(function.Parameters) != len(tt.expectedParams) {
+			t.Errorf("length parameters wrong. want %d, got=%d\n",
+				len(tt.expectedParams), len(function.Parameters))
+		}
+
+		for i, ident := range tt.expectedParams {
+			testLiteralExpression(t, function.Parameters[i], ident)
+		}
+	}
+}
+
 func testIdentifier(t *testing.T, exp ast.Expression, value string) bool {
 	ident, ok := exp.(*ast.Identifier)
 	if !ok {