// Package eval evaluates ast.AST values. Lists whose head is a known symbol
// are dispatched to built-in primitives; symbols are resolved against an
// environment; everything else evaluates to itself.
package eval

// TODO: a lot of code duplication that can be solved by a primitive registry
// TODO: a lot of code duplication that can be solved by a type checking helper

import (
	"fmt"
	"math"

	"github.com/hellerve/argos/ast"
)

// Shared results for the built-in constants. They are handed out by pointer,
// so callers must treat them as read-only.
var (
	trueVal  = ast.AST{Tag: ast.Bool, Val: true}
	falseVal = ast.AST{Tag: ast.Bool, Val: false}
	nilVal   = ast.AST{Tag: ast.List, Val: []ast.AST{}}
)

// env is a scope: a symbol table plus an optional enclosing scope.
type env struct {
	parent *env
	values map[string]*ast.AST
}

// newEnv creates an empty scope nested inside parent (which may be nil).
func newEnv(parent *env) env {
	return env{parent: parent, values: make(map[string]*ast.AST)}
}

// ParentEnv returns a fresh environment that has no enclosing scope.
func ParentEnv() env {
	return newEnv(nil)
}

// Lookup resolves elem in this scope, falling back to the enclosing scopes.
// It returns an error if no scope in the chain binds the symbol.
func (e env) Lookup(elem string) (*ast.AST, error) {
	if res, ok := e.values[elem]; ok {
		return res, nil
	}
	if e.parent == nil {
		return nil, fmt.Errorf("Symbol not found: %s", elem)
	}
	return e.parent.Lookup(elem)
}

// checkArity verifies that the call form input (head symbol followed by its
// arguments) carries exactly arity arguments. name is used in the error.
func checkArity(input []ast.AST, arity int, name string) error {
	// input[0] is the primitive's own name, so it does not count as an argument.
	got := len(input) - 1
	if got == arity {
		return nil
	}
	if got < 0 {
		got = 0
	}
	return fmt.Errorf("Argument count to '%s' must be %d, was %d", name, arity, got)
}

// evalDef implements `(= target value)`: it binds the evaluated value to
// target in e and returns the value. If target is not a symbol literal it is
// evaluated once and the result must be a symbol.
func evalDef(input []ast.AST, e env) (*ast.AST, error) {
	if err := checkArity(input, 2, "="); err != nil {
		return nil, err
	}

	target := &input[1]
	if target.Tag != ast.Symbol {
		var err error
		target, err = Eval(target, e)
		// The error must be checked before target is touched: on failure
		// Eval may return a nil result.
		if err != nil {
			return nil, err
		}
		if target.Tag != ast.Symbol {
			return nil, fmt.Errorf("First argument to '=' must be symbol, was %v", target.Tag)
		}
	}
	sym := target.Val.(string)

	evald, err := Eval(&input[2], e)
	if err != nil {
		return nil, err
	}

	e.values[sym] = evald
	return evald, nil
}

// arithCast evaluates input and returns it as a float64, or an error if the
// result is not a number.
func arithCast(input ast.AST, e env) (float64, error) {
	evald, err := Eval(&input, e)
	if err != nil {
		return 0, err
	}
	if evald.Tag != ast.Num {
		return 0, fmt.Errorf("Cannot perform arithmetic on %s", evald.Pretty())
	}
	return evald.Val.(float64), nil
}

// evalArith left-folds fn over the evaluated arguments of the call form
// input. At least two arguments are required.
func evalArith(input []ast.AST, e env, fn func(float64, float64) float64) (*ast.AST, error) {
	if len(input) < 3 {
		// Report the number of arguments, not the length of the whole form.
		got := len(input) - 1
		if got < 0 {
			got = 0
		}
		return nil, fmt.Errorf("Arithmetic functions take at least 2 arguments, got %d", got)
	}

	acc, err := arithCast(input[1], e)
	if err != nil {
		return nil, err
	}
	for _, elem := range input[2:] {
		val, err := arithCast(elem, e)
		if err != nil {
			return nil, err
		}
		acc = fn(acc, val)
	}

	return &ast.AST{Tag: ast.Num, Val: acc}, nil
}

// isoEqual reports whether x and y are structurally equal. Lists are compared
// element by element because slice values stored in an interface cannot be
// compared with == (doing so panics at runtime).
func isoEqual(x, y *ast.AST) bool {
	if x.Tag != y.Tag {
		return false
	}
	if x.Tag != ast.List {
		return x.Val == y.Val
	}

	xs, xok := x.Val.([]ast.AST)
	ys, yok := y.Val.([]ast.AST)
	if !xok || !yok || len(xs) != len(ys) {
		return false
	}
	for i := range xs {
		if !isoEqual(&xs[i], &ys[i]) {
			return false
		}
	}
	return true
}

// evalEq implements `(iso a b)`: both arguments are evaluated and compared
// structurally.
func evalEq(input []ast.AST, e env) (*ast.AST, error) {
	if err := checkArity(input, 2, "iso"); err != nil {
		return nil, err
	}

	x, err := Eval(&input[1], e)
	if err != nil {
		return nil, err
	}
	y, err := Eval(&input[2], e)
	if err != nil {
		return nil, err
	}

	if isoEqual(x, y) {
		return &trueVal, nil
	}
	return &falseVal, nil
}

// evalListArg evaluates arg and returns the elements of the resulting list.
// action completes the phrase "Cannot <action> non-list ..." used when the
// result is not a list.
func evalListArg(arg *ast.AST, e env, action string) ([]ast.AST, error) {
	res, err := Eval(arg, e)
	if err != nil {
		return nil, err
	}
	if res.Tag != ast.List {
		return nil, fmt.Errorf("Cannot %s non-list %s", action, res.Pretty())
	}
	return res.Val.([]ast.AST), nil
}

// evalCons implements `(cons x lst)`: it returns a new list with the value of
// x in front of the elements of lst.
func evalCons(input []ast.AST, e env) (*ast.AST, error) {
	if err := checkArity(input, 2, "cons"); err != nil {
		return nil, err
	}

	// The list argument is evaluated before the element.
	tail, err := evalListArg(&input[2], e, "cons to")
	if err != nil {
		return nil, err
	}
	head, err := Eval(&input[1], e)
	if err != nil {
		return nil, err
	}

	elems := make([]ast.AST, 0, len(tail)+1)
	elems = append(elems, *head)
	elems = append(elems, tail...)
	return &ast.AST{Tag: ast.List, Val: elems}, nil
}

// evalCar implements `(car lst)`: it returns the first element of lst, or an
// error if lst is empty or not a list.
func evalCar(input []ast.AST, e env) (*ast.AST, error) {
	if err := checkArity(input, 1, "car"); err != nil {
		return nil, err
	}

	elems, err := evalListArg(&input[1], e, "car from")
	if err != nil {
		return nil, err
	}
	if len(elems) == 0 {
		return nil, fmt.Errorf("Cannot car from empty list")
	}

	first := elems[0]
	return &first, nil
}

// evalCdr implements `(cdr lst)`: it returns lst without its first element,
// or an error if lst is empty or not a list.
func evalCdr(input []ast.AST, e env) (*ast.AST, error) {
	if err := checkArity(input, 1, "cdr"); err != nil {
		return nil, err
	}

	elems, err := evalListArg(&input[1], e, "cdr from")
	if err != nil {
		return nil, err
	}
	if len(elems) == 0 {
		return nil, fmt.Errorf("Cannot cdr from empty list")
	}

	return &ast.AST{Tag: ast.List, Val: elems[1:]}, nil
}

// evalNull implements `(null? lst)`: it reports whether lst is empty.
func evalNull(input []ast.AST, e env) (*ast.AST, error) {
	if err := checkArity(input, 1, "null"); err != nil {
		return nil, err
	}

	elems, err := evalListArg(&input[1], e, "call null? on")
	if err != nil {
		return nil, err
	}

	return &ast.AST{Tag: ast.Bool, Val: len(elems) == 0}, nil
}

// evalList evaluates a list form. The head must be a symbol; known primitives
// are dispatched to their implementation, and a list headed by any other
// symbol is returned unchanged. An empty list evaluates to itself.
func evalList(input *ast.AST, e env) (*ast.AST, error) {
	l := input.Val.([]ast.AST)
	if len(l) == 0 {
		return input, nil
	}

	head := l[0]
	if head.Tag != ast.Symbol {
		return nil, fmt.Errorf("Calling non-symbol: %s", head.Pretty())
	}

	switch head.Val.(string) {
	case "=":
		return evalDef(l, e)
	case "iso":
		return evalEq(l, e)
	case "quote":
		if len(l) < 2 {
			return nil, fmt.Errorf("Argument count to 'quote' must be 1, was 0")
		}
		quoted := l[1]
		return &quoted, nil
	case "cons":
		return evalCons(l, e)
	case "car":
		return evalCar(l, e)
	case "cdr":
		return evalCdr(l, e)
	case "null?":
		return evalNull(l, e)
	case "+":
		return evalArith(l, e, func(x, y float64) float64 { return x + y })
	case "-":
		return evalArith(l, e, func(x, y float64) float64 { return x - y })
	case "*":
		return evalArith(l, e, func(x, y float64) float64 { return x * y })
	case "/":
		return evalArith(l, e, func(x, y float64) float64 { return x / y })
	case "%":
		return evalArith(l, e, func(x, y float64) float64 {
			// Operands are truncated to integers. Integer modulo by zero
			// panics in Go, so yield NaN instead.
			d := int64(y)
			if d == 0 {
				return math.NaN()
			}
			return float64(int64(x) % d)
		})
	}

	return input, nil
}

// Eval evaluates input in the environment e. Lists are treated as call forms,
// the symbols true, false and nil map to built-in constants, other symbols
// are looked up in e, and every other value evaluates to itself.
func Eval(input *ast.AST, e env) (*ast.AST, error) {
	switch input.Tag {
	case ast.List:
		return evalList(input, e)
	case ast.Symbol:
		switch name := input.Val.(string); name {
		case "true":
			return &trueVal, nil
		case "false":
			return &falseVal, nil
		case "nil":
			return &nilVal, nil
		default:
			return e.Lookup(name)
		}
	}
	return input, nil
}