blob: e44fe7dc38b77c0830c1fa796842b64ee8fb81e9 [file] [log] [blame]
package evaluator
import (
"fmt"
"matheval/ast"
"matheval/token"
)
// Evaluator evaluates AST nodes against a table of user-defined functions
// that is populated through Define.
type Evaluator struct {
	// funcs maps each defined function name to its definition.
	funcs map[string]*ast.FuncDef
}
// New returns an Evaluator whose function table is allocated but empty,
// ready for Define and Eval.
func New() *Evaluator {
	ev := new(Evaluator)
	ev.funcs = map[string]*ast.FuncDef{}
	return ev
}
// Define registers a function definition so that later Eval calls can invoke
// it by name.
//
// It returns an error, and leaves the evaluator unchanged, when:
//   - def is nil;
//   - a function with the same name is already defined;
//   - def lists the same parameter name more than once. Duplicate names would
//     make the later argument silently overwrite the earlier one when the call
//     environment is built.
func (ev *Evaluator) Define(def *ast.FuncDef) error {
	if def == nil {
		return fmt.Errorf("nil function definition")
	}
	if _, exists := ev.funcs[def.Name]; exists {
		return fmt.Errorf("function %q already defined", def.Name)
	}
	seen := make(map[string]bool, len(def.Params))
	for _, p := range def.Params {
		if seen[p] {
			return fmt.Errorf("function %q has duplicate parameter %q", def.Name, p)
		}
		seen[p] = true
	}
	// Allocate lazily so a zero-value Evaluator works too, not only one from New.
	if ev.funcs == nil {
		ev.funcs = make(map[string]*ast.FuncDef)
	}
	ev.funcs[def.Name] = def
	return nil
}
// maxCallDepth caps how deeply user-defined function calls may nest during a
// single evaluation. The expression language handled here has no conditional
// construct, so any self- or mutually-recursive definition can never
// terminate. Without this cap it would exhaust the goroutine stack, which is a
// fatal runtime error rather than a recoverable panic.
const maxCallDepth = 1000

// Eval evaluates an AST node with the given variable environment.
// env maps variable names to their values; nil is treated as empty.
//
// A function call evaluates its arguments in the caller's environment. It then
// evaluates the function body in a fresh environment that binds only the
// function's parameters; the caller's variables are not visible there.
//
// Eval returns an error for:
//   - undefined variables or functions;
//   - argument-count mismatches;
//   - division by zero;
//   - unknown operators or node types;
//   - calls nested deeper than maxCallDepth.
func (ev *Evaluator) Eval(node ast.Node, env map[string]float64) (float64, error) {
	return ev.evalAtDepth(node, env, 0)
}

// evalAtDepth implements Eval. depth is the number of user-function bodies
// currently being evaluated on this call path.
func (ev *Evaluator) evalAtDepth(node ast.Node, env map[string]float64, depth int) (float64, error) {
	switch n := node.(type) {
	case *ast.NumberLit:
		return n.Value, nil
	case *ast.Ident:
		// Indexing a nil map is safe and reports ok == false, so a nil env
		// behaves like an empty one.
		if val, ok := env[n.Name]; ok {
			return val, nil
		}
		return 0, fmt.Errorf("undefined variable %q", n.Name)
	case *ast.BinaryExpr:
		left, err := ev.evalAtDepth(n.Left, env, depth)
		if err != nil {
			return 0, err
		}
		right, err := ev.evalAtDepth(n.Right, env, depth)
		if err != nil {
			return 0, err
		}
		switch n.Op {
		case token.Plus:
			return left + right, nil
		case token.Minus:
			return left - right, nil
		case token.Star:
			return left * right, nil
		case token.Slash:
			if right == 0 {
				return 0, fmt.Errorf("division by zero")
			}
			return left / right, nil
		default:
			return 0, fmt.Errorf("unknown operator: %v", n.Op)
		}
	case *ast.FuncCall:
		def, ok := ev.funcs[n.Name]
		if !ok {
			return 0, fmt.Errorf("undefined function %q", n.Name)
		}
		if len(n.Args) != len(def.Params) {
			return 0, fmt.Errorf("function %q expects %d arguments, got %d", n.Name, len(def.Params), len(n.Args))
		}
		// Stop runaway recursion with an error before it overflows the stack.
		if depth >= maxCallDepth {
			return 0, fmt.Errorf("function %q: maximum call depth %d exceeded", n.Name, maxCallDepth)
		}
		// Evaluate arguments in the caller's environment, at the caller's depth.
		newEnv := make(map[string]float64, len(def.Params))
		for i, param := range def.Params {
			val, err := ev.evalAtDepth(n.Args[i], env, depth)
			if err != nil {
				return 0, err
			}
			newEnv[param] = val
		}
		// Evaluate the function body in the parameter-only environment.
		return ev.evalAtDepth(def.Body, newEnv, depth+1)
	default:
		return 0, fmt.Errorf("unknown node type: %T", node)
	}
}
// Eval evaluates node with a fresh Evaluator that has no user-defined
// functions and an empty variable environment. It is kept for backward
// compatibility with callers that do not use the Evaluator type.
func Eval(node ast.Node) (float64, error) {
	fresh := New()
	return fresh.Eval(node, nil)
}