Parse function literals.

Signed-off-by: jmug <u.g.a.mariano@gmail.com>
This commit is contained in:
Mariano Uvalle 2025-01-05 16:13:08 -08:00
parent 9e9324bb56
commit 985cf24fbc
3 changed files with 149 additions and 0 deletions

32
pkg/ast/function.go Normal file
View file

@ -0,0 +1,32 @@
package ast
import (
"bytes"
"strings"
"code.jmug.me/jmug/interpreter-in-go/pkg/token"
)
type FunctionLiteral struct {
Token token.Token // The fn token
Parameters []*Identifier
Body *BlockStatement
}
func (fl *FunctionLiteral) expressionNode() {}
func (fl *FunctionLiteral) TokenLiteral() string {
return fl.Token.Literal
}
func (fl *FunctionLiteral) String() string {
var out bytes.Buffer
params := []string{}
for _, p := range fl.Parameters {
params = append(params, p.String())
}
out.WriteString(fl.TokenLiteral())
out.WriteString("(")
out.WriteString(strings.Join(params, ", "))
out.WriteString(") ")
out.WriteString(fl.Body.String())
return out.String()
}

View file

@ -39,6 +39,7 @@ func New(l *lexer.Lexer) *Parser {
p.registerPrefix(token.FALSE, p.parseBoolean)
p.registerPrefix(token.LPAREN, p.parseGroupedExpression)
p.registerPrefix(token.IF, p.parseIfExpression)
p.registerPrefix(token.FUNCTION, p.parseFunctionLiteral)
// Infix registrations
p.registerInfix(token.PLUS, p.parseInfixExpression)
p.registerInfix(token.MINUS, p.parseInfixExpression)
@ -233,6 +234,44 @@ func (p *Parser) parseIfExpression() ast.Expression {
return exp
}
func (p *Parser) parseFunctionLiteral() ast.Expression {
fn := &ast.FunctionLiteral{Token: p.curToken}
if !p.nextTokenIfPeekIs(token.LPAREN) {
// TODO: Would be good to emit an error here.
return nil
}
fn.Parameters = p.parseFunctionParameters()
if !p.nextTokenIfPeekIs(token.LBRACE) {
// TODO: Would be good to emit an error here.
return nil
}
fn.Body = p.parseBlockStatement()
return fn
}
func (p *Parser) parseFunctionParameters() []*ast.Identifier {
params := []*ast.Identifier{}
if p.peekTokenIs(token.RPAREN) {
p.nextToken()
return params
}
// Consume the LPAREN
p.nextToken()
params = append(params, &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal})
for p.peekTokenIs(token.COMMA) {
// Consume the previous identifier.
p.nextToken()
// Consume the comma.
p.nextToken()
params = append(params, &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal})
}
if !p.nextTokenIfPeekIs(token.RPAREN) {
// TODO: Would be good to emit an error here.
return nil
}
return params
}
func (p *Parser) curTokenIs(typ token.TokenType) bool {
return p.curToken.Type == typ
}

View file

@ -495,6 +495,84 @@ func TestIfElseExpression(t *testing.T) {
return
}
}
func TestFunctionLiteralParsing(t *testing.T) {
input := `fn(x, y) { x + y; }`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
1, len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
program.Statements[0])
}
function, ok := stmt.Expression.(*ast.FunctionLiteral)
if !ok {
t.Fatalf("stmt.Expression is not ast.FunctionLiteral. got=%T",
stmt.Expression)
}
if len(function.Parameters) != 2 {
t.Fatalf("function literal parameters wrong. want 2, got=%d\n",
len(function.Parameters))
}
testLiteralExpression(t, function.Parameters[0], "x")
testLiteralExpression(t, function.Parameters[1], "y")
if len(function.Body.Statements) != 1 {
t.Fatalf("function.Body.Statements has not 1 statements. got=%d\n",
len(function.Body.Statements))
}
bodyStmt, ok := function.Body.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("function body stmt is not ast.ExpressionStatement. got=%T",
function.Body.Statements[0])
}
testInfixExpression(t, bodyStmt.Expression, "x", "+", "y")
}
func TestFunctionParameterParsing(t *testing.T) {
tests := []struct {
input string
expectedParams []string
}{
{input: "fn() {};", expectedParams: []string{}},
{input: "fn(x) {};", expectedParams: []string{"x"}},
{input: "fn(x, y, z) {};", expectedParams: []string{"x", "y", "z"}},
}
for _, tt := range tests {
l := lexer.New(tt.input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
stmt := program.Statements[0].(*ast.ExpressionStatement)
function := stmt.Expression.(*ast.FunctionLiteral)
if len(function.Parameters) != len(tt.expectedParams) {
t.Errorf("length parameters wrong. want %d, got=%d\n",
len(tt.expectedParams), len(function.Parameters))
}
for i, ident := range tt.expectedParams {
testLiteralExpression(t, function.Parameters[i], ident)
}
}
}
func testIdentifier(t *testing.T, exp ast.Expression, value string) bool {
ident, ok := exp.(*ast.Identifier)
if !ok {