diff --git a/pkg/evaluator/evaluator.go b/pkg/evaluator/evaluator.go index c2c06c8..8898dfd 100644 --- a/pkg/evaluator/evaluator.go +++ b/pkg/evaluator/evaluator.go @@ -59,6 +59,20 @@ func Eval(node ast.Node, env *object.Environment) object.Object { env.Set(node.Name.Value, val) case *ast.Identifier: return evalIdentifier(node, env) + case *ast.FunctionLiteral: + params := node.Parameters + body := node.Body + return &object.Function{Parameters: params, Body: body, Env: env} + case *ast.CallExpression: + fn := Eval(node.Function, env) + if isError(fn) { + return fn + } + args := evalExpressions(node.Arguments, env) + if len(args) == 1 && isError(args[0]) { + return args[0] + } + return applyFunction(fn, args) } return nil } @@ -184,6 +198,46 @@ func evalIdentifier(exp *ast.Identifier, env *object.Environment) object.Object return val } +func evalExpressions( + exps []ast.Expression, + env *object.Environment, +) []object.Object { + var res []object.Object + for _, exp := range exps { + ev := Eval(exp, env) + if isError(ev) { + return []object.Object{ev} + } + res = append(res, ev) + } + return res +} + +func applyFunction(fnObj object.Object, args []object.Object) object.Object { + fn, ok := fnObj.(*object.Function) + if !ok { + return newError("not a function: %s", fn.Type()) + } + env := extendFunctionEnv(fn, args) + ret := Eval(fn.Body, env) + return unwrapReturnValue(ret) +} + +func extendFunctionEnv(fn *object.Function, args []object.Object) *object.Environment { + env := object.NewEnclosedEnvironment(fn.Env) + for pi, param := range fn.Parameters { + env.Set(param.Value, args[pi]) + } + return env +} + +func unwrapReturnValue(obj object.Object) object.Object { + if ret, ok := obj.(*object.ReturnValue); ok { + return ret.Value + } + return obj +} + func isTruthy(obj object.Object) bool { switch obj { case _TRUE: diff --git a/pkg/evaluator/evaluator_test.go b/pkg/evaluator/evaluator_test.go index ae953c5..d46828e 100644 --- a/pkg/evaluator/evaluator_test.go +++ b/pkg/evaluator/evaluator_test.go @@ -221,6 +221,49 @@ func TestLetStatements(t *testing.T) { } } +func TestFunctionObject(t *testing.T) { + input := "fn(x) { x + 2; };" + + evaluated := testEval(input) + fn, ok := evaluated.(*object.Function) + if !ok { + t.Fatalf("object is not Function. got=%T (%+v)", evaluated, evaluated) + } + + if len(fn.Parameters) != 1 { + t.Fatalf("function has wrong parameters. Parameters=%+v", + fn.Parameters) + } + + if fn.Parameters[0].String() != "x" { + t.Fatalf("parameter is not 'x'. got=%q", fn.Parameters[0]) + } + + expectedBody := "(x + 2)" + + if fn.Body.String() != expectedBody { + t.Fatalf("body is not %q. got=%q", expectedBody, fn.Body.String()) + } +} + +func TestFunctionApplication(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"let identity = fn(x) { x; }; identity(5);", 5}, + {"let identity = fn(x) { return x; }; identity(5);", 5}, + {"let double = fn(x) { x * 2; }; double(5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5, 5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5 + 5, add(5, 5));", 20}, + {"fn(x) { x; }(5)", 5}, + } + + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } +} + func testNullObject(t *testing.T, obj object.Object) bool { if obj != _NULL { t.Errorf("object is not NULL. got=%T (%+v)", obj, obj) diff --git a/pkg/object/environment.go b/pkg/object/environment.go index 9a06f43..04ace5e 100644 --- a/pkg/object/environment.go +++ b/pkg/object/environment.go @@ -4,12 +4,23 @@ func NewEnvironment() *Environment { return &Environment{store: map[string]Object{}} } +func NewEnclosedEnvironment(outer *Environment) *Environment { + return &Environment{ + store: map[string]Object{}, + outer: outer, + } +} + type Environment struct { store map[string]Object + outer *Environment } func (e *Environment) Get(name string) (Object, bool) { obj, ok := e.store[name] + if !ok && e.outer != nil { + obj, ok = e.outer.Get(name) + } return obj, ok } diff --git a/pkg/object/object.go b/pkg/object/object.go index f1446ad..bb88b69 100644 --- a/pkg/object/object.go +++ b/pkg/object/object.go @@ -1,6 +1,12 @@ package object -import "fmt" +import ( + "bytes" + "fmt" + "strings" + + "code.jmug.me/jmug/interpreter-in-go/pkg/ast" +) type ObjectType string @@ -10,6 +16,7 @@ const ( NULL_OBJ = "NULL" RETURN_VALUE_OBJ = "RETURN" ERROR_OBJ = "ERROR" + FUNCTION_OBJ = "FUNCTION" ) type Object interface { @@ -69,3 +76,24 @@ func (e *Error) Type() ObjectType { func (e *Error) Inspect() string { return "ERROR: " + e.Message } + +type Function struct { + Parameters []*ast.Identifier + Body *ast.BlockStatement + Env *Environment +} + +func (f *Function) Type() ObjectType { + return FUNCTION_OBJ +} +func (f *Function) Inspect() string { + var out bytes.Buffer + params := []string{} + for _, p := range f.Parameters { + params = append(params, p.Value) + } + out.WriteString("fn") + out.WriteString("(" + strings.Join(params, ", ") + ")") + out.WriteString(" {\n" + f.Body.String() + "\n}") + return out.String() +}