From 500a058ff8cfa7f926863177c109be47e4d3dd26 Mon Sep 17 00:00:00 2001 From: jmug Date: Wed, 8 Jan 2025 19:41:36 -0800 Subject: [PATCH] Add variable bindings and references to the evaluator. Signed-off-by: jmug --- pkg/evaluator/evaluator.go | 72 ++++++++++++++++++++++++--------- pkg/evaluator/evaluator_test.go | 22 +++++++++- pkg/object/environment.go | 19 +++++++++ pkg/repl/repl.go | 4 +- 4 files changed, 97 insertions(+), 20 deletions(-) create mode 100644 pkg/object/environment.go diff --git a/pkg/evaluator/evaluator.go b/pkg/evaluator/evaluator.go index a242b39..c2c06c8 100644 --- a/pkg/evaluator/evaluator.go +++ b/pkg/evaluator/evaluator.go @@ -13,39 +13,60 @@ var ( _FALSE = &object.Boolean{Value: false} ) -func Eval(node ast.Node) object.Object { +func Eval(node ast.Node, env *object.Environment) object.Object { switch node := node.(type) { // Statements. case *ast.Program: - return evalProgram(node.Statements) + return evalProgram(node.Statements, env) case *ast.ExpressionStatement: - return Eval(node.Expression) + return Eval(node.Expression, env) // Expressions. case *ast.IntegerLiteral: return &object.Integer{Value: node.Value} case *ast.Boolean: return nativeBoolToBooleanObject(node.Value) case *ast.PrefixExpression: - right := Eval(node.Right) + right := Eval(node.Right, env) + if isError(right) { + return right + } return evalPrefixExpression(node.Operator, right) case *ast.InfixExpression: - left := Eval(node.Left) - right := Eval(node.Right) + left := Eval(node.Left, env) + if isError(left) { + return left + } + right := Eval(node.Right, env) + if isError(right) { + return right + } return evalInfixExpression(node.Operator, left, right) case *ast.BlockStatement: - return evalBlockStatement(node.Statements) + return evalBlockStatement(node.Statements, env) case *ast.IfExpression: - return evalIfExpression(node) + return evalIfExpression(node, env) case *ast.ReturnStatement: - return &object.ReturnValue{Value: Eval(node.ReturnValue)} + ret := Eval(node.ReturnValue, env) + if isError(ret) { + return ret + } + return &object.ReturnValue{Value: ret} + case *ast.LetStatement: + val := Eval(node.Value, env) + if isError(val) { + return val + } + env.Set(node.Name.Value, val) + case *ast.Identifier: + return evalIdentifier(node, env) } return nil } -func evalProgram(stmts []ast.Statement) object.Object { +func evalProgram(stmts []ast.Statement, env *object.Environment) object.Object { var res object.Object for _, stmt := range stmts { - res = Eval(stmt) + res = Eval(stmt, env) switch res := res.(type) { case *object.ReturnValue: return res.Value @@ -56,11 +77,11 @@ func evalProgram(stmts []ast.Statement) object.Object { return res } -func evalBlockStatement(stmts []ast.Statement) object.Object { +func evalBlockStatement(stmts []ast.Statement, env *object.Environment) object.Object { var res object.Object for _, stmt := range stmts { - res = Eval(stmt) - if res != nil && res.Type() == object.RETURN_VALUE_OBJ { + res = Eval(stmt, env) + if res != nil && (res.Type() == object.RETURN_VALUE_OBJ || res.Type() == object.ERROR_OBJ) { return res } } @@ -142,16 +163,27 @@ func evalIntegerInfixExpression(op string, left, right object.Object) object.Obj } } -func evalIfExpression(ifExp *ast.IfExpression) object.Object { - cond := Eval(ifExp.Condition) +func evalIfExpression(ifExp *ast.IfExpression, env *object.Environment) object.Object { + cond := Eval(ifExp.Condition, env) + if isError(cond) { + return cond + } if isTruthy(cond) { - return Eval(ifExp.Consequence) + return Eval(ifExp.Consequence, env) } else if ifExp.Alternative != nil { - return Eval(ifExp.Alternative) + return Eval(ifExp.Alternative, env) } return _NULL } +func evalIdentifier(exp *ast.Identifier, env *object.Environment) object.Object { + val, ok := env.Get(exp.Value) + if !ok { + return newError("identifier not found: " + exp.Value) + } + return val +} + func isTruthy(obj object.Object) bool { switch obj { case _TRUE: @@ -174,3 +206,7 @@ func nativeBoolToBooleanObject(b bool) object.Object { func newError(format string, a ...any) *object.Error { return &object.Error{Message: fmt.Sprintf(format, a...)} } + +func isError(obj object.Object) bool { + return obj != nil && obj.Type() == object.ERROR_OBJ +} diff --git a/pkg/evaluator/evaluator_test.go b/pkg/evaluator/evaluator_test.go index 978441c..ae953c5 100644 --- a/pkg/evaluator/evaluator_test.go +++ b/pkg/evaluator/evaluator_test.go @@ -182,6 +182,10 @@ if (10 > 1) { `, "unknown operator: BOOLEAN + BOOLEAN", }, + { + "foobar", + "identifier not found: foobar", + }, } for _, tt := range tests { @@ -201,6 +205,22 @@ if (10 > 1) { } } +func TestLetStatements(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"let a = 5; a;", 5}, + {"let a = 5 * 5; a;", 25}, + {"let a = 5; let b = a; b;", 5}, + {"let a = 5; let b = a; let c = a + b + 5; c;", 15}, + } + + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } +} + func testNullObject(t *testing.T, obj object.Object) bool { if obj != _NULL { t.Errorf("object is not NULL. got=%T (%+v)", obj, obj) @@ -214,7 +234,7 @@ func testEval(input string) object.Object { p := parser.New(l) program := p.ParseProgram() - return Eval(program) + return Eval(program, object.NewEnvironment()) } func testIntegerObject(t *testing.T, obj object.Object, expected int64) bool { diff --git a/pkg/object/environment.go b/pkg/object/environment.go new file mode 100644 index 0000000..9a06f43 --- /dev/null +++ b/pkg/object/environment.go @@ -0,0 +1,19 @@ +package object + +func NewEnvironment() *Environment { + return &Environment{store: map[string]Object{}} +} + +type Environment struct { + store map[string]Object +} + +func (e *Environment) Get(name string) (Object, bool) { + obj, ok := e.store[name] + return obj, ok +} + +func (e *Environment) Set(name string, obj Object) Object { + e.store[name] = obj + return obj +} diff --git a/pkg/repl/repl.go b/pkg/repl/repl.go index e311a7d..0878839 100644 --- a/pkg/repl/repl.go +++ b/pkg/repl/repl.go @@ -7,6 +7,7 @@ import ( "code.jmug.me/jmug/interpreter-in-go/pkg/evaluator" "code.jmug.me/jmug/interpreter-in-go/pkg/lexer" + "code.jmug.me/jmug/interpreter-in-go/pkg/object" "code.jmug.me/jmug/interpreter-in-go/pkg/parser" ) @@ -14,6 +15,7 @@ const PROMPT = ">> " func Start(in io.Reader, out io.Writer) { scanner := bufio.NewScanner(in) + env := object.NewEnvironment() for { fmt.Fprint(out, PROMPT) if !scanner.Scan() { @@ -26,7 +28,7 @@ func Start(in io.Reader, out io.Writer) { printParserErrors(out, p.Errors()) continue } - res := evaluator.Eval(program) + res := evaluator.Eval(program, env) if res != nil { io.WriteString(out, res.Inspect()) io.WriteString(out, "\n")