From 8514ead895129638cb14ea055d7f057b769da8b9 Mon Sep 17 00:00:00 2001 From: jmug Date: Tue, 7 Jan 2025 18:15:12 -0800 Subject: [PATCH] Evaluate if and return, error validation. Signed-off-by: jmug --- pkg/evaluator/evaluator.go | 84 +++++++++++++++++-- pkg/evaluator/evaluator_test.go | 139 ++++++++++++++++++++++++++++++++ pkg/object/object.go | 30 ++++++- 3 files changed, 244 insertions(+), 9 deletions(-) diff --git a/pkg/evaluator/evaluator.go b/pkg/evaluator/evaluator.go index 859af83..a242b39 100644 --- a/pkg/evaluator/evaluator.go +++ b/pkg/evaluator/evaluator.go @@ -1,6 +1,8 @@ package evaluator import ( + "fmt" + "code.jmug.me/jmug/interpreter-in-go/pkg/ast" "code.jmug.me/jmug/interpreter-in-go/pkg/object" ) @@ -15,7 +17,7 @@ func Eval(node ast.Node) object.Object { switch node := node.(type) { // Statements. case *ast.Program: - return evalStatements(node.Statements) + return evalProgram(node.Statements) case *ast.ExpressionStatement: return Eval(node.Expression) // Expressions. @@ -30,14 +32,37 @@ func Eval(node ast.Node) object.Object { left := Eval(node.Left) right := Eval(node.Right) return evalInfixExpression(node.Operator, left, right) + case *ast.BlockStatement: + return evalBlockStatement(node.Statements) + case *ast.IfExpression: + return evalIfExpression(node) + case *ast.ReturnStatement: + return &object.ReturnValue{Value: Eval(node.ReturnValue)} } return nil } -func evalStatements(stmts []ast.Statement) object.Object { +func evalProgram(stmts []ast.Statement) object.Object { var res object.Object for _, stmt := range stmts { res = Eval(stmt) + switch res := res.(type) { + case *object.ReturnValue: + return res.Value + case *object.Error: + return res + } + } + return res +} + +func evalBlockStatement(stmts []ast.Statement) object.Object { + var res object.Object + for _, stmt := range stmts { + res = Eval(stmt) + if res != nil && res.Type() == object.RETURN_VALUE_OBJ { + return res + } } return res } @@ -48,8 +73,9 @@ func evalPrefixExpression(op string, right object.Object) object.Object { return evalBangOperatorExpression(right) case "-": return evalMinusPrefixOperatorExpression(right) + default: + return newError("unknown operator: %s%s", op, right.Type()) } - return _NULL } func evalBangOperatorExpression(obj object.Object) object.Object { @@ -67,17 +93,27 @@ func evalBangOperatorExpression(obj object.Object) object.Object { func evalMinusPrefixOperatorExpression(obj object.Object) object.Object { if obj.Type() != object.INTEGER_OBJ { - return _NULL + return newError("unknown operator: -%s", obj.Type()) } val := obj.(*object.Integer).Value return &object.Integer{Value: -val} } func evalInfixExpression(op string, left, right object.Object) object.Object { - if left.Type() == object.INTEGER_OBJ && right.Type() == object.INTEGER_OBJ { + switch { + case left.Type() == object.INTEGER_OBJ && right.Type() == object.INTEGER_OBJ: return evalIntegerInfixExpression(op, left, right) + case op == "==": + return nativeBoolToBooleanObject(left == right) + case op == "!=": + return nativeBoolToBooleanObject(left != right) + case left.Type() != right.Type(): + return newError("type mismatch: %s %s %s", + left.Type(), op, right.Type()) + default: + return newError("unknown operator: %s %s %s", + left.Type(), op, right.Type()) } - return _NULL } func evalIntegerInfixExpression(op string, left, right object.Object) object.Object { @@ -92,13 +128,49 @@ func evalIntegerInfixExpression(op string, left, right object.Object) object.Obj return &object.Integer{Value: l * r} case "/": return &object.Integer{Value: l / r} + case "<": + return nativeBoolToBooleanObject(l < r) + case ">": + return nativeBoolToBooleanObject(l > r) + case "==": + return nativeBoolToBooleanObject(l == r) + case "!=": + return nativeBoolToBooleanObject(l != r) + default: + return newError("unknown operator: %s %s %s", + left.Type(), op, right.Type()) + } +} + +func evalIfExpression(ifExp *ast.IfExpression) object.Object { + cond := Eval(ifExp.Condition) + if isTruthy(cond) { + return Eval(ifExp.Consequence) + } else if ifExp.Alternative != nil { + return Eval(ifExp.Alternative) } return _NULL } +func isTruthy(obj object.Object) bool { + switch obj { + case _TRUE: + return true + case _FALSE: + return false + case _NULL: + return false + } + return true +} + func nativeBoolToBooleanObject(b bool) object.Object { if b { return _TRUE } return _FALSE } + +func newError(format string, a ...any) *object.Error { + return &object.Error{Message: fmt.Sprintf(format, a...)} +} diff --git a/pkg/evaluator/evaluator_test.go b/pkg/evaluator/evaluator_test.go index 7b11567..978441c 100644 --- a/pkg/evaluator/evaluator_test.go +++ b/pkg/evaluator/evaluator_test.go @@ -43,6 +43,14 @@ func TestEvalBooleanExpression(t *testing.T) { }{ {"true", true}, {"false", false}, + {"1 < 2", true}, + {"1 > 2", false}, + {"1 < 1", false}, + {"1 > 1", false}, + {"1 == 1", true}, + {"1 != 1", false}, + {"1 == 2", false}, + {"1 != 2", true}, } for _, tt := range tests { @@ -62,6 +70,15 @@ func TestBangOperator(t *testing.T) { {"!!true", true}, {"!!false", false}, {"!!5", true}, + {"true == true", true}, + {"false == false", true}, + {"true == false", false}, + {"true != false", true}, + {"false != true", true}, + {"(1 < 2) == true", true}, + {"(1 < 2) == false", false}, + {"(1 > 2) == true", false}, + {"(1 > 2) == false", true}, } for _, tt := range tests { @@ -70,6 +87,128 @@ func TestBangOperator(t *testing.T) { } } +func TestIfElseExpressions(t *testing.T) { + tests := []struct { + input string + expected interface{} + }{ + {"if (true) { 10 }", 10}, + {"if (false) { 10 }", nil}, + {"if (1) { 10 }", 10}, + {"if (1 < 2) { 10 }", 10}, + {"if (1 > 2) { 10 }", nil}, + {"if (1 > 2) { 10 } else { 20 }", 20}, + {"if (1 < 2) { 10 } else { 20 }", 10}, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + integer, ok := tt.expected.(int) + if ok { + testIntegerObject(t, evaluated, int64(integer)) + } else { + testNullObject(t, evaluated) + } + } +} + +func TestReturnStatements(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"return 10;", 10}, + {"return 10; 9;", 10}, + {"return 2 * 5; 9;", 10}, + {"9; return 2 * 5; 9;", 10}, + { + ` +if (10 > 1) { + if (10 > 1) { + return 10; + } + + return 1; +} +`, + 10, + }, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + testIntegerObject(t, evaluated, tt.expected) + } +} + +func TestErrorHandling(t *testing.T) { + tests := []struct { + input string + expectedMessage string + }{ + { + "5 + true;", + "type mismatch: INTEGER + BOOLEAN", + }, + { + "5 + true; 5;", + "type mismatch: INTEGER + BOOLEAN", + }, + { + "-true", + "unknown operator: -BOOLEAN", + }, + { + "true + false;", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + "5; true + false; 5", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + "if (10 > 1) { true + false; }", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + ` +if (10 > 1) { + if (10 > 1) { + return true + false; + } + + return 1; +} +`, + "unknown operator: BOOLEAN + BOOLEAN", + }, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + + errObj, ok := evaluated.(*object.Error) + if !ok { + t.Errorf("no error object returned. got=%T(%+v)", + evaluated, evaluated) + continue + } + + if errObj.Message != tt.expectedMessage { + t.Errorf("wrong error message. expected=%q, got=%q", + tt.expectedMessage, errObj.Message) + } + } +} + +func testNullObject(t *testing.T, obj object.Object) bool { + if obj != _NULL { + t.Errorf("object is not NULL. got=%T (%+v)", obj, obj) + return false + } + return true +} + func testEval(input string) object.Object { l := lexer.New(input) p := parser.New(l) diff --git a/pkg/object/object.go b/pkg/object/object.go index 5f3d2b7..f1446ad 100644 --- a/pkg/object/object.go +++ b/pkg/object/object.go @@ -5,9 +5,11 @@ import "fmt" type ObjectType string const ( - INTEGER_OBJ = "INTEGER" - BOOLEAN_OBJ = "BOOLEAN" - NULL_OBJ = "NULL" + INTEGER_OBJ = "INTEGER" + BOOLEAN_OBJ = "BOOLEAN" + NULL_OBJ = "NULL" + RETURN_VALUE_OBJ = "RETURN" + ERROR_OBJ = "ERROR" ) type Object interface { @@ -45,3 +47,25 @@ func (n *Null) Type() ObjectType { func (n *Null) Inspect() string { return "null" } + +type ReturnValue struct { + Value Object +} + +func (rv *ReturnValue) Type() ObjectType { + return RETURN_VALUE_OBJ +} +func (rv *ReturnValue) Inspect() string { + return rv.Value.Inspect() +} + +type Error struct { + Message string +} + +func (e *Error) Type() ObjectType { + return ERROR_OBJ +} +func (e *Error) Inspect() string { + return "ERROR: " + e.Message +}