blob: 15b4ceae15030bf7412307f666e2b7b3f20fb606 [file] [log] [blame]
package parser
import (
"fmt"
"matheval/ast"
"matheval/token"
"strconv"
)
// Parse builds an AST from a token stream holding exactly one expression.
// It returns an error if the expression is malformed, or if any token other
// than EOF is left over once the expression has been parsed.
func Parse(tokens []token.Token) (ast.Node, error) {
	ps := &parser{tokens: tokens}
	root, err := ps.expr()
	if err != nil {
		return nil, err
	}
	// Anything besides EOF at this point is trailing input the grammar
	// could not absorb.
	if leftover := ps.current(); leftover.Type != token.EOF {
		return nil, fmt.Errorf("unexpected token %v at position %d", leftover.Type, leftover.Pos)
	}
	return root, nil
}
// ParseLine parses one full REPL line. A line shaped like
// `name(params) = body` (as detected by isFuncDef) produces an *ast.FuncDef.
// Any other line is parsed as a single expression and wrapped in an
// *ast.ExprStmt. A line with no tokens, or only EOF, is an error.
//
// On failure the returned Statement is always a true nil interface, so
// callers may safely compare it against nil.
func ParseLine(tokens []token.Token) (ast.Statement, error) {
	if len(tokens) == 0 || tokens[0].Type == token.EOF {
		return nil, fmt.Errorf("empty input")
	}
	if isFuncDef(tokens) {
		def, err := parseFuncDef(tokens)
		if err != nil {
			// Return a literal nil: passing def through would wrap a nil
			// *ast.FuncDef in a non-nil ast.Statement interface.
			return nil, err
		}
		return def, nil
	}
	node, err := Parse(tokens)
	if err != nil {
		return nil, err
	}
	return &ast.ExprStmt{Expr: node}, nil
}
// isFuncDef reports whether tokens start with `Ident ( ... ) =`, meaning the
// line should be parsed as a function definition rather than an expression.
// It locates the ')' that closes the opening '(' (accounting for nested
// parentheses) and checks that '=' immediately follows it. Streams shorter
// than five tokens, or streams that hit EOF before that ')' is found, are
// rejected.
func isFuncDef(tokens []token.Token) bool {
	if len(tokens) < 5 || tokens[0].Type != token.Ident || tokens[1].Type != token.LParen {
		return false
	}
	nesting := 0
	for i, tok := range tokens[1:] {
		switch tok.Type {
		case token.EOF:
			return false
		case token.LParen:
			nesting++
		case token.RParen:
			nesting--
			if nesting == 0 {
				// tok is tokens[i+1], so the token after it is tokens[i+2].
				next := i + 2
				return next < len(tokens) && tokens[next].Type == token.Equals
			}
		}
	}
	return false
}
// parseFuncDef parses a function definition of the form
//
//	name(param1, param2, ...) = body
//
// The whole stream must be consumed; any token other than EOF after the body
// is an error. Parameter names must be distinct identifiers. A repeated name
// is rejected, since only one of the bound arguments could ever be referenced
// from the body.
func parseFuncDef(tokens []token.Token) (*ast.FuncDef, error) {
	p := &parser{tokens: tokens}

	// Function name.
	nameTok, err := p.expect(token.Ident)
	if err != nil {
		return nil, fmt.Errorf("expected function name: %w", err)
	}
	name := nameTok.Literal

	// Opening paren.
	if _, err := p.expect(token.LParen); err != nil {
		return nil, fmt.Errorf("expected '(' after function name: %w", err)
	}

	// Parameters: zero or more comma-separated, distinct identifiers.
	var params []string
	seen := make(map[string]bool)
	if p.current().Type != token.RParen {
		for {
			paramTok, err := p.expect(token.Ident)
			if err != nil {
				if len(params) == 0 {
					return nil, fmt.Errorf("expected parameter name: %w", err)
				}
				return nil, fmt.Errorf("expected parameter name after ',': %w", err)
			}
			if seen[paramTok.Literal] {
				return nil, fmt.Errorf("duplicate parameter %q at position %d", paramTok.Literal, paramTok.Pos)
			}
			seen[paramTok.Literal] = true
			params = append(params, paramTok.Literal)
			if p.current().Type != token.Comma {
				break
			}
			p.advance() // consume ','
		}
	}

	// Closing paren.
	if _, err := p.expect(token.RParen); err != nil {
		return nil, fmt.Errorf("expected ')' after parameters: %w", err)
	}

	// Equals sign.
	if _, err := p.expect(token.Equals); err != nil {
		return nil, fmt.Errorf("expected '=' in function definition: %w", err)
	}

	// Body expression.
	body, err := p.expr()
	if err != nil {
		return nil, fmt.Errorf("error in function body: %w", err)
	}

	// Ensure all tokens were consumed.
	if tok := p.current(); tok.Type != token.EOF {
		return nil, fmt.Errorf("unexpected token %v at position %d after function body", tok.Type, tok.Pos)
	}

	return &ast.FuncDef{
		Name:   name,
		Params: params,
		Body:   body,
	}, nil
}
// parser is the cursor state for one recursive-descent parse over a token
// slice. Parse and parseFuncDef each construct a fresh parser.
type parser struct {
	// tokens is the input being parsed.
	tokens []token.Token

	// pos indexes the next unconsumed token. It may run past the end of
	// tokens, in which case current reports a synthetic EOF.
	pos int
}
// current peeks at the token under the cursor without consuming it. Once the
// cursor has moved past the last token, it yields an EOF token with an unset
// position, so callers never index out of range.
func (p *parser) current() token.Token {
	if p.pos < len(p.tokens) {
		return p.tokens[p.pos]
	}
	return token.Token{Type: token.EOF}
}
// advance consumes the token under the cursor and returns it. The cursor
// moves forward unconditionally, even when it is already past the end of the
// stream; current keeps returning EOF in that case.
func (p *parser) advance() token.Token {
	consumed := p.current()
	p.pos++
	return consumed
}
// expect consumes and returns the current token if its type is typ. On a
// mismatch nothing is consumed, and the offending token is returned together
// with an error naming the expected and actual types.
func (p *parser) expect(typ token.Type) (token.Token, error) {
	got := p.current()
	if got.Type == typ {
		p.advance()
		return got, nil
	}
	return got, fmt.Errorf("expected %v but got %v at position %d", typ, got.Type, got.Pos)
}
// expr parses an additive expression:
//
//	expr → term (('+' | '-') term)*
//
// Operators associate to the left, so "a - b - c" becomes ((a - b) - c).
func (p *parser) expr() (ast.Node, error) {
	acc, err := p.term()
	if err != nil {
		return nil, err
	}
	for {
		op := p.current().Type
		if op != token.Plus && op != token.Minus {
			return acc, nil
		}
		p.advance() // consume the operator
		rhs, err := p.term()
		if err != nil {
			return nil, err
		}
		acc = &ast.BinaryExpr{Op: op, Left: acc, Right: rhs}
	}
}
// term parses a multiplicative expression:
//
//	term → factor (('*' | '/') factor)*
//
// It binds tighter than expr's '+'/'-' and also associates to the left.
func (p *parser) term() (ast.Node, error) {
	product, err := p.factor()
	if err != nil {
		return nil, err
	}
	for {
		switch op := p.current().Type; op {
		case token.Star, token.Slash:
			p.advance() // consume the operator
			rhs, err := p.factor()
			if err != nil {
				return nil, err
			}
			product = &ast.BinaryExpr{Op: op, Left: product, Right: rhs}
		default:
			return product, nil
		}
	}
}
// factor parses the highest-precedence forms:
//
//	factor → NUMBER | IDENT | IDENT '(' args ')' | '(' expr ')'
//	args   → expr (',' expr)*
//
// A NUMBER becomes an *ast.NumberLit and a bare IDENT becomes an *ast.Ident.
// An IDENT immediately followed by '(' becomes an *ast.FuncCall. A
// parenthesised expression is returned as its inner node. When a closing ')'
// is missing, the underlying expect error is wrapped with %w so the full
// cause stays available to callers.
func (p *parser) factor() (ast.Node, error) {
	tok := p.current()
	switch tok.Type {
	case token.Number:
		p.advance()
		val, err := strconv.ParseFloat(tok.Literal, 64)
		if err != nil {
			return nil, fmt.Errorf("invalid number %q at position %d: %w", tok.Literal, tok.Pos, err)
		}
		return &ast.NumberLit{Value: val}, nil

	case token.Ident:
		p.advance()
		// An identifier directly followed by '(' is a call. Otherwise it is a
		// variable reference.
		if p.current().Type != token.LParen {
			return &ast.Ident{Name: tok.Literal}, nil
		}
		p.advance() // consume '('
		var args []ast.Node
		if p.current().Type != token.RParen {
			for {
				arg, err := p.expr()
				if err != nil {
					return nil, err
				}
				args = append(args, arg)
				if p.current().Type != token.Comma {
					break
				}
				p.advance() // consume ','
			}
		}
		if _, err := p.expect(token.RParen); err != nil {
			return nil, fmt.Errorf("expected ')' after function arguments at position %d: %w", p.current().Pos, err)
		}
		return &ast.FuncCall{Name: tok.Literal, Args: args}, nil

	case token.LParen:
		p.advance() // consume '('
		node, err := p.expr()
		if err != nil {
			return nil, err
		}
		if _, err := p.expect(token.RParen); err != nil {
			return nil, fmt.Errorf("missing closing parenthesis at position %d: %w", p.current().Pos, err)
		}
		return node, nil

	default:
		return nil, fmt.Errorf("unexpected token %v at position %d", tok.Type, tok.Pos)
	}
}