blob: 6a86a2e1dc3fd1ad0fcb1e1af799c4f2cbdf31ad [file] [log] [blame]
package repl
import (
"bufio"
"fmt"
"io"
"math"
"strings"
"matheval/ast"
"matheval/evaluator"
"matheval/lexer"
"matheval/parser"
)
// prompt is written to the output before each line of input is read.
const prompt = ">> "

// Run starts the read-eval-print loop, reading from r and writing to w.
//
// Each non-blank input line (after trimming surrounding whitespace) is
// tokenized, parsed, and then handled by statement kind:
//   - *ast.FuncDef is passed to the evaluator's Define, and "defined <name>"
//     is printed on success;
//   - *ast.ExprStmt is evaluated and its result printed via formatResult.
//
// Other statement kinds produce no output. An error from any stage is
// written to w as "error: <message>", and the loop moves on to the next line.
// A single evaluator instance handles every line, so any state it keeps
// (such as definitions registered via Define) carries across lines.
//
// The loop ends when r is exhausted. If reading from r fails (for example,
// a line longer than bufio.MaxScanTokenSize), that error is also reported
// to w instead of being silently dropped.
func Run(r io.Reader, w io.Writer) {
	scanner := bufio.NewScanner(r)
	ev := evaluator.New()
	for {
		// Printing the prompt before every read (including the first) gives
		// the same output as printing it once up front and again after each
		// processed line, without repeating the call on every exit path.
		fmt.Fprint(w, prompt)
		if !scanner.Scan() {
			break
		}
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		tokens, err := lexer.Tokenize(line)
		if err != nil {
			fmt.Fprintf(w, "error: %s\n", err)
			continue
		}
		stmt, err := parser.ParseLine(tokens)
		if err != nil {
			fmt.Fprintf(w, "error: %s\n", err)
			continue
		}
		switch s := stmt.(type) {
		case *ast.FuncDef:
			if err := ev.Define(s); err != nil {
				fmt.Fprintf(w, "error: %s\n", err)
				continue
			}
			fmt.Fprintf(w, "defined %s\n", s.Name)
		case *ast.ExprStmt:
			// NOTE(review): the nil second argument presumably means "no
			// local bindings" for a top-level expression — confirm against
			// evaluator.Eval.
			result, err := ev.Eval(s.Expr, nil)
			if err != nil {
				fmt.Fprintf(w, "error: %s\n", err)
				continue
			}
			fmt.Fprintln(w, formatResult(result))
		}
	}
	// Scan returns false both at EOF and on a read error. Err is nil at a
	// clean EOF, so any non-nil value here is a real failure worth surfacing.
	if err := scanner.Err(); err != nil {
		fmt.Fprintf(w, "error: %s\n", err)
	}
}
// formatResult formats a float64 for display.
// Whole numbers are printed without decimal points.
func formatResult(val float64) string {
if val == math.Trunc(val) && !math.IsInf(val, 0) && !math.IsNaN(val) {
return fmt.Sprintf("%g", val)
}
return fmt.Sprintf("%g", val)
}