diff --git a/pkg/ast/block.go b/pkg/ast/block.go new file mode 100644 index 0000000..108dadc --- /dev/null +++ b/pkg/ast/block.go @@ -0,0 +1,24 @@ +package ast + +import ( + "bytes" + + "code.jmug.me/jmug/interpreter-in-go/pkg/token" +) + +type BlockStatement struct { + Token token.Token // The `{` token. + Statements []Statement +} + +func (bs *BlockStatement) statementNode() {} +func (bs *BlockStatement) TokenLiteral() string { + return bs.Token.Literal +} +func (bs *BlockStatement) String() string { + var out bytes.Buffer + for _, s := range bs.Statements { + out.WriteString(s.String()) + } + return out.String() +} diff --git a/pkg/ast/if_expression.go b/pkg/ast/if_expression.go new file mode 100644 index 0000000..e28d122 --- /dev/null +++ b/pkg/ast/if_expression.go @@ -0,0 +1,31 @@ +package ast + +import ( + "bytes" + + "code.jmug.me/jmug/interpreter-in-go/pkg/token" +) + +type IfExpression struct { + Token token.Token // The `if` token. + Condition Expression + Consequence *BlockStatement + Alternative *BlockStatement +} + +func (ie *IfExpression) expressionNode() {} +func (ie *IfExpression) TokenLiteral() string { + return ie.Token.Literal +} +func (ie *IfExpression) String() string { + var out bytes.Buffer + out.WriteString("if") + out.WriteString(ie.Condition.String()) + out.WriteString(" ") + out.WriteString(ie.Consequence.String()) + if ie.Alternative != nil { + out.WriteString("else ") + out.WriteString(ie.Alternative.String()) + } + return out.String() +} diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 9223f43..45a220d 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -38,6 +38,7 @@ func New(l *lexer.Lexer) *Parser { p.registerPrefix(token.TRUE, p.parseBoolean) p.registerPrefix(token.FALSE, p.parseBoolean) p.registerPrefix(token.LPAREN, p.parseGroupedExpression) + p.registerPrefix(token.IF, p.parseIfExpression) // Infix registrations p.registerInfix(token.PLUS, p.parseInfixExpression) p.registerInfix(token.MINUS, p.parseInfixExpression) @@ -79,13 +80,28 @@ func (p *Parser) parseStatement() ast.Statement { return p.parseExpressionStatement() } +func (p *Parser) parseBlockStatement() *ast.BlockStatement { + block := &ast.BlockStatement{Token: p.curToken} + block.Statements = []ast.Statement{} + p.nextToken() + for !p.curTokenIs(token.RBRACE) && !p.curTokenIs(token.EOF) { + stmt := p.parseStatement() + if stmt != nil { + block.Statements = append(block.Statements, stmt) + } + // Consume the semicolon. + p.nextToken() + } + return block +} + func (p *Parser) parseLetStatement() ast.Statement { stmt := &ast.LetStatement{Token: p.curToken} - if !p.expectPeek(token.IDENT) { + if !p.nextTokenIfPeekIs(token.IDENT) { return nil } stmt.Name = &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal} - if !p.expectPeek(token.ASSIGN) { + if !p.nextTokenIfPeekIs(token.ASSIGN) { return nil } // TODO: Skipping until we find the semicolon to avoid parsing the expression. @@ -180,13 +196,43 @@ func (p *Parser) parseInfixExpression(left ast.Expression) ast.Expression { func (p *Parser) parseGroupedExpression() ast.Expression { p.nextToken() exp := p.parseExpression(LOWEST) - if !p.expectPeek(token.RPAREN) { + if !p.nextTokenIfPeekIs(token.RPAREN) { // TODO: Would probably be good to emit an error here? return nil } return exp } +func (p *Parser) parseIfExpression() ast.Expression { + exp := &ast.IfExpression{Token: p.curToken} + if !p.nextTokenIfPeekIs(token.LPAREN) { + // TODO: Would be good to emit an error here. + return nil + } + p.nextToken() + exp.Condition = p.parseExpression(LOWEST) + if !p.nextTokenIfPeekIs(token.RPAREN) { + // TODO: Would be good to emit an error here. + return nil + } + if !p.nextTokenIfPeekIs(token.LBRACE) { + // TODO: Would be good to emit an error here. + return nil + } + exp.Consequence = p.parseBlockStatement() + if p.peekTokenIs(token.ELSE) { + p.nextToken() + if !p.nextTokenIfPeekIs(token.LBRACE) { + // TODO: Would be good to emit an error here. + return nil + } + exp.Alternative = p.parseBlockStatement() + } + // We don't consume the RBRACE because it acts as our "end of statement" + // token, and it's consumed by parseProgram. + return exp +} + func (p *Parser) curTokenIs(typ token.TokenType) bool { return p.curToken.Type == typ } @@ -198,7 +244,7 @@ func (p *Parser) peekTokenIs(typ token.TokenType) bool { // NOTE: I'll leave the name as-is to avoid deviating from the book (maybe a // rename at the end?), but I think `nextTokenIfPeek` would be a much better // name for this. -func (p *Parser) expectPeek(typ token.TokenType) bool { +func (p *Parser) nextTokenIfPeekIs(typ token.TokenType) bool { if p.peekTokenIs(typ) { p.nextToken() return true diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go index ca2a979..0f200a9 100644 --- a/pkg/parser/parser_test.go +++ b/pkg/parser/parser_test.go @@ -388,6 +388,113 @@ func TestBooleanExpression(t *testing.T) { } } +func TestIfExpression(t *testing.T) { + input := `if (x < y) { x }` + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain %d statements. got=%d\n", + 1, len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T", + program.Statements[0]) + } + + exp, ok := stmt.Expression.(*ast.IfExpression) + if !ok { + t.Fatalf("stmt.Expression is not ast.IfExpression. got=%T", + stmt.Expression) + } + + if !testInfixExpression(t, exp.Condition, "x", "<", "y") { + return + } + + if len(exp.Consequence.Statements) != 1 { + t.Errorf("consequence is not 1 statements. got=%d\n", + len(exp.Consequence.Statements)) + } + + consequence, ok := exp.Consequence.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T", + exp.Consequence.Statements[0]) + } + + if !testIdentifier(t, consequence.Expression, "x") { + return + } + + if exp.Alternative != nil { + t.Errorf("exp.Alternative.Statements was not nil. got=%+v", exp.Alternative) + } +} + +func TestIfElseExpression(t *testing.T) { + input := `if (x < y) { x } else { y }` + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain %d statements. got=%d\n", + 1, len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T", + program.Statements[0]) + } + + exp, ok := stmt.Expression.(*ast.IfExpression) + if !ok { + t.Fatalf("stmt.Expression is not ast.IfExpression. got=%T", + stmt.Expression) + } + + if !testInfixExpression(t, exp.Condition, "x", "<", "y") { + return + } + + if len(exp.Consequence.Statements) != 1 { + t.Errorf("consequence is not 1 statements. got=%d\n", + len(exp.Consequence.Statements)) + } + + consequence, ok := exp.Consequence.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("Statements[0] of consequence is not ast.ExpressionStatement. got=%T", + exp.Consequence.Statements[0]) + } + + if !testIdentifier(t, consequence.Expression, "x") { + return + } + + if len(exp.Alternative.Statements) != 1 { + t.Errorf("alternative is not 1 statement. got=%d\n", + len(exp.Alternative.Statements)) + } + + alternative, ok := exp.Alternative.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("Statements[0] of alternative is not ast.ExpressionStatement. got=%T", + exp.Alternative.Statements[0]) + } + if !testIdentifier(t, alternative.Expression, "y") { + return + } +} func testIdentifier(t *testing.T, exp ast.Expression, value string) bool { ident, ok := exp.(*ast.Identifier) if !ok {