Step 5: Add Evaluator struct with env, Define, FuncCall support; backward-compat Eval preserved
diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go
index 411463c..cab1472 100644
--- a/evaluator/evaluator_test.go
+++ b/evaluator/evaluator_test.go
@@ -4,9 +4,12 @@
"math"
"matheval/ast"
"matheval/token"
+ "strings"
"testing"
)
+// --- Backward-compatible package-level Eval ---
+
func TestEvalNumberLit(t *testing.T) {
result, err := Eval(&ast.NumberLit{Value: 42.5})
if err != nil {
@@ -173,3 +176,311 @@
t.Fatalf("expected 3.8, got %v", result)
}
}
+
+// --- Evaluator struct: Ident ---
+
+func TestEvaluator_Ident(t *testing.T) {
+ ev := New()
+ env := map[string]float64{"x": 7}
+ result, err := ev.Eval(&ast.Ident{Name: "x"}, env)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != 7 {
+ t.Fatalf("expected 7, got %v", result)
+ }
+}
+
+// TestEvaluator_IdentUndefined checks that an unbound identifier is an error.
+func TestEvaluator_IdentUndefined(t *testing.T) {
+	_, err := New().Eval(&ast.Ident{Name: "x"}, nil)
+	switch {
+	case err == nil:
+		t.Fatal("expected error for undefined variable")
+	case !strings.Contains(err.Error(), "undefined variable"):
+		t.Errorf("expected 'undefined variable' in error, got: %v", err)
+	}
+}
+
+// TestEvaluator_IdentInExpr checks that identifiers inside a binary expression
+// are looked up in env: x + y with x=3, y=4 must give 7.
+func TestEvaluator_IdentInExpr(t *testing.T) {
+	sum := &ast.BinaryExpr{
+		Op:    token.Plus,
+		Left:  &ast.Ident{Name: "x"},
+		Right: &ast.Ident{Name: "y"},
+	}
+	vars := map[string]float64{"x": 3, "y": 4}
+	got, err := New().Eval(sum, vars)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if got != 7 {
+		t.Fatalf("expected 7, got %v", got)
+	}
+}
+
+// --- Evaluator struct: Define + FuncCall ---
+
+// TestEvaluator_DefineAndCall defines f(x) = x + 1 and checks that f(5) = 6.
+func TestEvaluator_DefineAndCall(t *testing.T) {
+	e := New()
+	def := &ast.FuncDef{
+		Name:   "f",
+		Params: []string{"x"},
+		Body: &ast.BinaryExpr{
+			Op:    token.Plus,
+			Left:  &ast.Ident{Name: "x"},
+			Right: &ast.NumberLit{Value: 1},
+		},
+	}
+	if err := e.Define(def); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	call := &ast.FuncCall{
+		Name: "f",
+		Args: []ast.Node{&ast.NumberLit{Value: 5}},
+	}
+	got, err := e.Eval(call, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if got != 6 {
+		t.Fatalf("expected 6, got %v", got)
+	}
+}
+
+// TestEvaluator_DefineMultiParam defines add(x, y) = x + y and checks add(3, 4) = 7.
+func TestEvaluator_DefineMultiParam(t *testing.T) {
+	e := New()
+	def := &ast.FuncDef{
+		Name:   "add",
+		Params: []string{"x", "y"},
+		Body: &ast.BinaryExpr{
+			Op:    token.Plus,
+			Left:  &ast.Ident{Name: "x"},
+			Right: &ast.Ident{Name: "y"},
+		},
+	}
+	if err := e.Define(def); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	call := &ast.FuncCall{
+		Name: "add",
+		Args: []ast.Node{
+			&ast.NumberLit{Value: 3},
+			&ast.NumberLit{Value: 4},
+		},
+	}
+	got, err := e.Eval(call, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if got != 7 {
+		t.Fatalf("expected 7, got %v", got)
+	}
+}
+
+// TestEvaluator_DefineRedefinitionError checks that defining the same function
+// name twice is rejected on the second attempt.
+func TestEvaluator_DefineRedefinitionError(t *testing.T) {
+	e := New()
+	def := &ast.FuncDef{Name: "f", Params: []string{"x"}, Body: &ast.Ident{Name: "x"}}
+	if err := e.Define(def); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// The second Define with an identical name must fail.
+	err := e.Define(def)
+	switch {
+	case err == nil:
+		t.Fatal("expected error for redefining function")
+	case !strings.Contains(err.Error(), "already defined"):
+		t.Errorf("expected 'already defined' in error, got: %v", err)
+	}
+}
+
+// TestEvaluator_UndefinedFunction checks that calling a name that was never
+// passed to Define is reported as an error.
+func TestEvaluator_UndefinedFunction(t *testing.T) {
+	e := New()
+	call := &ast.FuncCall{Name: "f", Args: []ast.Node{&ast.NumberLit{Value: 1}}}
+	_, err := e.Eval(call, nil)
+	switch {
+	case err == nil:
+		t.Fatal("expected error for undefined function")
+	case !strings.Contains(err.Error(), "undefined function"):
+		t.Errorf("expected 'undefined function' in error, got: %v", err)
+	}
+}
+
+// TestEvaluator_WrongArgCount checks that calling unary f with two args errors.
+func TestEvaluator_WrongArgCount(t *testing.T) {
+	ev := New()
+	def := &ast.FuncDef{Name: "f", Params: []string{"x"}, Body: &ast.Ident{Name: "x"}}
+	if err := ev.Define(def); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	_, err := ev.Eval(&ast.FuncCall{
+		Name: "f",
+		Args: []ast.Node{
+			&ast.NumberLit{Value: 1},
+			&ast.NumberLit{Value: 2},
+		},
+	}, nil)
+	if err == nil {
+		t.Fatal("expected error for wrong argument count")
+	}
+	if !strings.Contains(err.Error(), "expects 1 arguments, got 2") {
+		t.Errorf("expected arg count error, got: %v", err)
+	}
+}
+
+// TestEvaluator_FuncCallInExpr checks a call used as an operand:
+// with f(x) = x * 2, f(3) + 1 must evaluate to 7.
+func TestEvaluator_FuncCallInExpr(t *testing.T) {
+	ev := New()
+	def := &ast.FuncDef{
+		Name:   "f",
+		Params: []string{"x"},
+		Body: &ast.BinaryExpr{
+			Op:    token.Star,
+			Left:  &ast.Ident{Name: "x"},
+			Right: &ast.NumberLit{Value: 2},
+		},
+	}
+	if err := ev.Define(def); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	node := &ast.BinaryExpr{
+		Op:    token.Plus,
+		Left:  &ast.FuncCall{Name: "f", Args: []ast.Node{&ast.NumberLit{Value: 3}}},
+		Right: &ast.NumberLit{Value: 1},
+	}
+	result, err := ev.Eval(node, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if result != 7 {
+		t.Fatalf("expected 7, got %v", result)
+	}
+}
+
+// TestEvaluator_NestedFuncCall checks that a call argument may itself be a
+// call: with f(x) = x + 1, f(f(1)) = f(2) = 3.
+func TestEvaluator_NestedFuncCall(t *testing.T) {
+	ev := New()
+	def := &ast.FuncDef{
+		Name:   "f",
+		Params: []string{"x"},
+		Body: &ast.BinaryExpr{
+			Op:    token.Plus,
+			Left:  &ast.Ident{Name: "x"},
+			Right: &ast.NumberLit{Value: 1},
+		},
+	}
+	if err := ev.Define(def); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	inner := &ast.FuncCall{
+		Name: "f",
+		Args: []ast.Node{&ast.NumberLit{Value: 1}},
+	}
+	node := &ast.FuncCall{Name: "f", Args: []ast.Node{inner}}
+	result, err := ev.Eval(node, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if result != 3 {
+		t.Fatalf("expected 3, got %v", result)
+	}
+}
+
+// TestEvaluator_CrossFunctionCall checks that one defined function's body can
+// call another: quad(x) = double(double(x)) with double(x) = x * 2.
+func TestEvaluator_CrossFunctionCall(t *testing.T) {
+	ev := New()
+	double := &ast.FuncDef{
+		Name:   "double",
+		Params: []string{"x"},
+		Body: &ast.BinaryExpr{
+			Op:    token.Star,
+			Left:  &ast.Ident{Name: "x"},
+			Right: &ast.NumberLit{Value: 2},
+		},
+	}
+	// quad's body nests two calls to double around its own parameter.
+	inner := &ast.FuncCall{Name: "double", Args: []ast.Node{&ast.Ident{Name: "x"}}}
+	quad := &ast.FuncDef{
+		Name:   "quad",
+		Params: []string{"x"},
+		Body:   &ast.FuncCall{Name: "double", Args: []ast.Node{inner}},
+	}
+	for _, def := range []*ast.FuncDef{double, quad} {
+		if err := ev.Define(def); err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+	}
+
+	// quad(3) = double(double(3)) = double(6) = 12
+	result, err := ev.Eval(&ast.FuncCall{
+		Name: "quad",
+		Args: []ast.Node{&ast.NumberLit{Value: 3}},
+	}, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if result != 12 {
+		t.Fatalf("expected 12, got %v", result)
+	}
+}
+
+// TestEvaluator_FuncNoParams checks a zero-arity function: c() = 42.
+func TestEvaluator_FuncNoParams(t *testing.T) {
+	ev := New()
+	def := &ast.FuncDef{Name: "c", Params: []string{}, Body: &ast.NumberLit{Value: 42}}
+	if err := ev.Define(def); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	result, err := ev.Eval(&ast.FuncCall{
+		Name: "c",
+		Args: []ast.Node{},
+	}, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if result != 42 {
+		t.Fatalf("expected 42, got %v", result)
+	}
+}
+
+// TestEvaluator_ArgEvaluatedInCallerEnv checks that call arguments are resolved
+// in the caller's env: f(y) with y=10 and f(x) = x + 1 gives 11.
+func TestEvaluator_ArgEvaluatedInCallerEnv(t *testing.T) {
+	ev := New()
+	def := &ast.FuncDef{
+		Name:   "f",
+		Params: []string{"x"},
+		Body: &ast.BinaryExpr{
+			Op:    token.Plus,
+			Left:  &ast.Ident{Name: "x"},
+			Right: &ast.NumberLit{Value: 1},
+		},
+	}
+	if err := ev.Define(def); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	callerEnv := map[string]float64{"y": 10}
+	call := &ast.FuncCall{Name: "f", Args: []ast.Node{&ast.Ident{Name: "y"}}}
+	result, err := ev.Eval(call, callerEnv)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if result != 11 {
+		t.Fatalf("expected 11, got %v", result)
+	}
+}