Step 5: Add Evaluator struct with env, Define, FuncCall support; backward-compat Eval preserved
diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go
index 447adbe..e44fe7d 100644
--- a/evaluator/evaluator.go
+++ b/evaluator/evaluator.go
@@ -6,19 +6,47 @@
"matheval/token"
)
-// Eval evaluates an AST node and returns the result.
-// Returns an error on division by zero.
-func Eval(node ast.Node) (float64, error) {
+// Evaluator holds function definitions and evaluates AST nodes.
+type Evaluator struct {
+ funcs map[string]*ast.FuncDef
+}
+
+// New creates a new Evaluator with no defined functions.
+func New() *Evaluator {
+ return &Evaluator{funcs: make(map[string]*ast.FuncDef)}
+}
+
+// Define registers a function definition.
+// Returns an error if a function with the same name is already defined.
+func (ev *Evaluator) Define(def *ast.FuncDef) error {
+ if _, exists := ev.funcs[def.Name]; exists {
+ return fmt.Errorf("function %q already defined", def.Name)
+ }
+ ev.funcs[def.Name] = def
+ return nil
+}
+
+// Eval evaluates an AST node with the given variable environment.
+// env maps variable names to their values; nil is treated as empty.
+func (ev *Evaluator) Eval(node ast.Node, env map[string]float64) (float64, error) {
switch n := node.(type) {
case *ast.NumberLit:
return n.Value, nil
+ case *ast.Ident:
+ if env != nil {
+ if val, ok := env[n.Name]; ok {
+ return val, nil
+ }
+ }
+ return 0, fmt.Errorf("undefined variable %q", n.Name)
+
case *ast.BinaryExpr:
- left, err := Eval(n.Left)
+ left, err := ev.Eval(n.Left, env)
if err != nil {
return 0, err
}
- right, err := Eval(n.Right)
+ right, err := ev.Eval(n.Right, env)
if err != nil {
return 0, err
}
@@ -39,7 +67,33 @@
return 0, fmt.Errorf("unknown operator: %v", n.Op)
}
+ case *ast.FuncCall:
+ def, ok := ev.funcs[n.Name]
+ if !ok {
+ return 0, fmt.Errorf("undefined function %q", n.Name)
+ }
+ if len(n.Args) != len(def.Params) {
+ return 0, fmt.Errorf("function %q expects %d arguments, got %d", n.Name, len(def.Params), len(n.Args))
+ }
+ // Evaluate arguments in caller's environment.
+ newEnv := make(map[string]float64, len(def.Params))
+ for i, param := range def.Params {
+ val, err := ev.Eval(n.Args[i], env)
+ if err != nil {
+ return 0, err
+ }
+ newEnv[param] = val
+ }
+ // Evaluate function body in new environment.
+ return ev.Eval(def.Body, newEnv)
+
default:
return 0, fmt.Errorf("unknown node type: %T", node)
}
}
+
+// Eval is a backward-compatible package-level function.
+// It evaluates an AST node without any variable/function context.
+func Eval(node ast.Node) (float64, error) {
+ return New().Eval(node, nil)
+}
diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go
index 411463c..cab1472 100644
--- a/evaluator/evaluator_test.go
+++ b/evaluator/evaluator_test.go
@@ -4,9 +4,12 @@
"math"
"matheval/ast"
"matheval/token"
+ "strings"
"testing"
)
+// --- Backward-compatible package-level Eval ---
+
func TestEvalNumberLit(t *testing.T) {
result, err := Eval(&ast.NumberLit{Value: 42.5})
if err != nil {
@@ -173,3 +176,311 @@
t.Fatalf("expected 3.8, got %v", result)
}
}
+
+// --- Evaluator struct: Ident ---
+
+func TestEvaluator_Ident(t *testing.T) {
+ ev := New()
+ env := map[string]float64{"x": 7}
+ result, err := ev.Eval(&ast.Ident{Name: "x"}, env)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != 7 {
+ t.Fatalf("expected 7, got %v", result)
+ }
+}
+
+func TestEvaluator_IdentUndefined(t *testing.T) {
+ ev := New()
+ _, err := ev.Eval(&ast.Ident{Name: "x"}, nil)
+ if err == nil {
+ t.Fatal("expected error for undefined variable")
+ }
+ if !strings.Contains(err.Error(), "undefined variable") {
+ t.Errorf("expected 'undefined variable' in error, got: %v", err)
+ }
+}
+
+func TestEvaluator_IdentInExpr(t *testing.T) {
+ ev := New()
+ env := map[string]float64{"x": 3, "y": 4}
+ // x + y
+ node := &ast.BinaryExpr{
+ Op: token.Plus,
+ Left: &ast.Ident{Name: "x"},
+ Right: &ast.Ident{Name: "y"},
+ }
+ result, err := ev.Eval(node, env)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != 7 {
+ t.Fatalf("expected 7, got %v", result)
+ }
+}
+
+// --- Evaluator struct: Define + FuncCall ---
+
+func TestEvaluator_DefineAndCall(t *testing.T) {
+ ev := New()
+ // f(x) = x + 1
+ err := ev.Define(&ast.FuncDef{
+ Name: "f",
+ Params: []string{"x"},
+ Body: &ast.BinaryExpr{
+ Op: token.Plus,
+ Left: &ast.Ident{Name: "x"},
+ Right: &ast.NumberLit{Value: 1},
+ },
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // f(5) = 6
+ result, err := ev.Eval(&ast.FuncCall{
+ Name: "f",
+ Args: []ast.Node{&ast.NumberLit{Value: 5}},
+ }, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != 6 {
+ t.Fatalf("expected 6, got %v", result)
+ }
+}
+
+func TestEvaluator_DefineMultiParam(t *testing.T) {
+ ev := New()
+ // add(x, y) = x + y
+ err := ev.Define(&ast.FuncDef{
+ Name: "add",
+ Params: []string{"x", "y"},
+ Body: &ast.BinaryExpr{
+ Op: token.Plus,
+ Left: &ast.Ident{Name: "x"},
+ Right: &ast.Ident{Name: "y"},
+ },
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // add(3, 4) = 7
+ result, err := ev.Eval(&ast.FuncCall{
+ Name: "add",
+ Args: []ast.Node{
+ &ast.NumberLit{Value: 3},
+ &ast.NumberLit{Value: 4},
+ },
+ }, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != 7 {
+ t.Fatalf("expected 7, got %v", result)
+ }
+}
+
+func TestEvaluator_DefineRedefinitionError(t *testing.T) {
+ ev := New()
+ def := &ast.FuncDef{
+ Name: "f",
+ Params: []string{"x"},
+ Body: &ast.Ident{Name: "x"},
+ }
+ if err := ev.Define(def); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ err := ev.Define(def)
+ if err == nil {
+ t.Fatal("expected error for redefining function")
+ }
+ if !strings.Contains(err.Error(), "already defined") {
+ t.Errorf("expected 'already defined' in error, got: %v", err)
+ }
+}
+
+func TestEvaluator_UndefinedFunction(t *testing.T) {
+ ev := New()
+ _, err := ev.Eval(&ast.FuncCall{
+ Name: "f",
+ Args: []ast.Node{&ast.NumberLit{Value: 1}},
+ }, nil)
+ if err == nil {
+ t.Fatal("expected error for undefined function")
+ }
+ if !strings.Contains(err.Error(), "undefined function") {
+ t.Errorf("expected 'undefined function' in error, got: %v", err)
+ }
+}
+
+func TestEvaluator_WrongArgCount(t *testing.T) {
+ ev := New()
+ ev.Define(&ast.FuncDef{
+ Name: "f",
+ Params: []string{"x"},
+ Body: &ast.Ident{Name: "x"},
+ })
+ _, err := ev.Eval(&ast.FuncCall{
+ Name: "f",
+ Args: []ast.Node{
+ &ast.NumberLit{Value: 1},
+ &ast.NumberLit{Value: 2},
+ },
+ }, nil)
+ if err == nil {
+ t.Fatal("expected error for wrong argument count")
+ }
+ if !strings.Contains(err.Error(), "expects 1 arguments, got 2") {
+ t.Errorf("expected arg count error, got: %v", err)
+ }
+}
+
+func TestEvaluator_FuncCallInExpr(t *testing.T) {
+ ev := New()
+ // f(x) = x * 2
+ ev.Define(&ast.FuncDef{
+ Name: "f",
+ Params: []string{"x"},
+ Body: &ast.BinaryExpr{
+ Op: token.Star,
+ Left: &ast.Ident{Name: "x"},
+ Right: &ast.NumberLit{Value: 2},
+ },
+ })
+ // f(3) + 1 = 7
+ node := &ast.BinaryExpr{
+ Op: token.Plus,
+ Left: &ast.FuncCall{
+ Name: "f",
+ Args: []ast.Node{&ast.NumberLit{Value: 3}},
+ },
+ Right: &ast.NumberLit{Value: 1},
+ }
+ result, err := ev.Eval(node, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != 7 {
+ t.Fatalf("expected 7, got %v", result)
+ }
+}
+
+func TestEvaluator_NestedFuncCall(t *testing.T) {
+ ev := New()
+ // f(x) = x + 1
+ ev.Define(&ast.FuncDef{
+ Name: "f",
+ Params: []string{"x"},
+ Body: &ast.BinaryExpr{
+ Op: token.Plus,
+ Left: &ast.Ident{Name: "x"},
+ Right: &ast.NumberLit{Value: 1},
+ },
+ })
+ // f(f(1)) = f(2) = 3
+ node := &ast.FuncCall{
+ Name: "f",
+ Args: []ast.Node{
+ &ast.FuncCall{
+ Name: "f",
+ Args: []ast.Node{&ast.NumberLit{Value: 1}},
+ },
+ },
+ }
+ result, err := ev.Eval(node, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != 3 {
+ t.Fatalf("expected 3, got %v", result)
+ }
+}
+
+func TestEvaluator_CrossFunctionCall(t *testing.T) {
+ ev := New()
+ // double(x) = x * 2
+ ev.Define(&ast.FuncDef{
+ Name: "double",
+ Params: []string{"x"},
+ Body: &ast.BinaryExpr{
+ Op: token.Star,
+ Left: &ast.Ident{Name: "x"},
+ Right: &ast.NumberLit{Value: 2},
+ },
+ })
+ // quad(x) = double(double(x))
+ ev.Define(&ast.FuncDef{
+ Name: "quad",
+ Params: []string{"x"},
+ Body: &ast.FuncCall{
+ Name: "double",
+ Args: []ast.Node{
+ &ast.FuncCall{
+ Name: "double",
+ Args: []ast.Node{&ast.Ident{Name: "x"}},
+ },
+ },
+ },
+ })
+ // quad(3) = double(double(3)) = double(6) = 12
+ result, err := ev.Eval(&ast.FuncCall{
+ Name: "quad",
+ Args: []ast.Node{&ast.NumberLit{Value: 3}},
+ }, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != 12 {
+ t.Fatalf("expected 12, got %v", result)
+ }
+}
+
+func TestEvaluator_FuncNoParams(t *testing.T) {
+ ev := New()
+ // c() = 42
+ ev.Define(&ast.FuncDef{
+ Name: "c",
+ Params: []string{},
+ Body: &ast.NumberLit{Value: 42},
+ })
+ result, err := ev.Eval(&ast.FuncCall{
+ Name: "c",
+ Args: []ast.Node{},
+ }, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != 42 {
+ t.Fatalf("expected 42, got %v", result)
+ }
+}
+
+func TestEvaluator_ArgEvaluatedInCallerEnv(t *testing.T) {
+ ev := New()
+ // f(x) = x + 1
+ ev.Define(&ast.FuncDef{
+ Name: "f",
+ Params: []string{"x"},
+ Body: &ast.BinaryExpr{
+ Op: token.Plus,
+ Left: &ast.Ident{Name: "x"},
+ Right: &ast.NumberLit{Value: 1},
+ },
+ })
+ // Evaluate f(y) with y=10 in caller env
+ callerEnv := map[string]float64{"y": 10}
+ result, err := ev.Eval(&ast.FuncCall{
+ Name: "f",
+ Args: []ast.Node{&ast.Ident{Name: "y"}},
+ }, callerEnv)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ // f(10) = 11
+ if result != 11 {
+ t.Fatalf("expected 11, got %v", result)
+ }
+}
diff --git a/repl/repl.go b/repl/repl.go
index bc14567..6a86a2e 100644
--- a/repl/repl.go
+++ b/repl/repl.go
@@ -7,6 +7,7 @@
"math"
"strings"
+ "matheval/ast"
"matheval/evaluator"
"matheval/lexer"
"matheval/parser"
@@ -17,6 +18,7 @@
// Run starts the read-eval-print loop, reading from r and writing to w.
func Run(r io.Reader, w io.Writer) {
scanner := bufio.NewScanner(r)
+ ev := evaluator.New()
fmt.Fprint(w, prompt)
for scanner.Scan() {
@@ -26,32 +28,40 @@
continue
}
- result, err := evalLine(line)
+ tokens, err := lexer.Tokenize(line)
if err != nil {
fmt.Fprintf(w, "error: %s\n", err)
- } else {
- fmt.Fprintln(w, formatResult(result))
+ fmt.Fprint(w, prompt)
+ continue
+ }
+
+ stmt, err := parser.ParseLine(tokens)
+ if err != nil {
+ fmt.Fprintf(w, "error: %s\n", err)
+ fmt.Fprint(w, prompt)
+ continue
+ }
+
+ switch s := stmt.(type) {
+ case *ast.FuncDef:
+ if err := ev.Define(s); err != nil {
+ fmt.Fprintf(w, "error: %s\n", err)
+ } else {
+ fmt.Fprintf(w, "defined %s\n", s.Name)
+ }
+ case *ast.ExprStmt:
+ result, err := ev.Eval(s.Expr, nil)
+ if err != nil {
+ fmt.Fprintf(w, "error: %s\n", err)
+ } else {
+ fmt.Fprintln(w, formatResult(result))
+ }
}
fmt.Fprint(w, prompt)
}
}
-// evalLine tokenizes, parses, and evaluates a single expression string.
-func evalLine(line string) (float64, error) {
- tokens, err := lexer.Tokenize(line)
- if err != nil {
- return 0, err
- }
-
- tree, err := parser.Parse(tokens)
- if err != nil {
- return 0, err
- }
-
- return evaluator.Eval(tree)
-}
-
// formatResult formats a float64 for display.
// Whole numbers are printed without decimal points.
func formatResult(val float64) string {