// Package eval evaluates argos AST nodes in an ast.Env.
//
// TODO: arity and argument-type checks are still partly duplicated across
// the special forms and primitives; checkArity and listItems factor out the
// most common ones.
package eval

import (
	"fmt"
	"reflect"
	"strings"

	"github.com/cevaris/ordered_map"
	"github.com/hellerve/argos/ast"
)

// Shared values returned by the evaluator. nilVal is the empty list, which is
// what the symbol nil evaluates to. Truth tests compare booleans by value
// (isTrue/isFalse), so freshly allocated booleans behave like these.
var (
	trueVal  = ast.AST{Tag: ast.Bool, Val: true}
	falseVal = ast.AST{Tag: ast.Bool, Val: false}
	nilVal   = ast.AST{Tag: ast.List, Val: []*ast.AST{}}
)

// RootEnv returns the environment produced by ast.ParentEnv with the builtin
// primitives (cons, car, cdr, null?, pr, table, type) bound by name.
func RootEnv() ast.Env {
	e := ast.ParentEnv()
	prims := map[string]ast.PrimFn{
		"cons":  evalCons,
		"car":   evalCar,
		"cdr":   evalCdr,
		"null?": evalNull,
		"pr":    evalPr,
		"table": evalTable,
		"type":  evalType,
	}
	for k, v := range prims {
		prim := ast.AST{Tag: ast.Prim, Val: ast.Primitive{k, v}}
		e.Values[k] = &prim
	}
	return e
}

// checkArity returns an error unless input has exactly arity elements; name
// is the form's name as shown in the error message.
func checkArity(input []*ast.AST, arity int, name string) error {
	if n := len(input); n != arity {
		return fmt.Errorf("Argument count to '%s' must be %d, was %d", name, arity, n)
	}
	return nil
}

// listItems returns the elements of a list value. Both a quoted list (as
// built by quote and by rest parameters) and a bare list value (as built by
// cons, cdr and nil) are accepted; ok is false for anything else, including
// a quoted non-list such as 'a.
func listItems(a *ast.AST) (items []*ast.AST, ok bool) {
	switch a.Tag {
	case ast.Quoted:
		inner, isAST := a.Val.(*ast.AST)
		if !isAST || inner == nil {
			return nil, false
		}
		items, ok = inner.Val.([]*ast.AST)
		return items, ok
	case ast.List:
		items, ok = a.Val.([]*ast.AST)
		return items, ok
	}
	return nil, false
}

// isTrue reports whether a is the boolean true. It compares by value so that
// booleans allocated by primitives count, not only the trueVal singleton.
func isTrue(a *ast.AST) bool {
	b, ok := a.Val.(bool)
	return a.Tag == ast.Bool && ok && b
}

// isFalse reports whether a is the boolean false (compared by value).
func isFalse(a *ast.AST) bool {
	b, ok := a.Val.(bool)
	return a.Tag == ast.Bool && ok && !b
}

// sameValue implements `is`: same node, or same tag and equal payload.
// Payloads whose dynamic type cannot be compared with == (list slices, table
// maps, function structs) would make a plain interface comparison panic, so
// they are equal only by node identity -- except that any two empty lists are
// equal, which makes (is nil nil) true.
func sameValue(x, y *ast.AST) bool {
	if x == y {
		return true
	}
	if x.Tag != y.Tag {
		return false
	}
	if xs, ok := x.Val.([]*ast.AST); ok {
		ys, ok := y.Val.([]*ast.AST)
		return ok && len(xs) == 0 && len(ys) == 0
	}
	if x.Val == nil || y.Val == nil {
		return x.Val == nil && y.Val == nil
	}
	if !reflect.TypeOf(x.Val).Comparable() || !reflect.TypeOf(y.Val).Comparable() {
		return false
	}
	return x.Val == y.Val
}

// resolveIndex maps idx into [0, n), counting negative indices from the end
// (-1 is the last element). ok is false when the index is out of range.
func resolveIndex(idx, n int64) (int64, bool) {
	if idx < 0 {
		idx += n
	}
	return idx, idx >= 0 && idx < n
}

// modulo truncates both operands to int64 before taking the remainder. A
// divisor that truncates to zero is reported as an error instead of letting
// the integer division panic.
func modulo(x, y float64) (float64, error) {
	d := int64(y)
	if d == 0 {
		return 0, fmt.Errorf("Cannot take %v modulo %v: divisor truncates to zero", x, y)
	}
	return float64(int64(x) % d), nil
}

// evalDef implements (= target value): evaluates value and binds it to the
// target symbol in e. A non-symbol target is evaluated first; if that does not
// yield a symbol, value is evaluated and returned without binding anything.
func evalDef(input []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(input, 2, "="); err != nil {
		return nil, err
	}
	target := input[0]
	if target.Tag != ast.Symbol {
		// TODO destructuring assignment
		evald, err := Eval(target, e)
		if err != nil {
			return nil, err
		}
		if evald.Tag != ast.Symbol {
			return Eval(input[1], e)
		}
		target = evald
	}
	val, err := Eval(input[1], e)
	if err != nil {
		return nil, err
	}
	e.Values[target.Val.(string)] = val
	return val, nil
}

// arithCast evaluates input and returns it as a number, or an error if the
// result is not a Num.
func arithCast(input *ast.AST, e ast.Env) (float64, error) {
	evald, err := Eval(input, e)
	if err != nil {
		return 0, err
	}
	if evald.Tag != ast.Num {
		return 0, fmt.Errorf("Cannot perform arithmetic on %s", evald.Pretty())
	}
	return evald.Val.(float64), nil
}

// evalArith left-folds fn over two or more numeric arguments.
func evalArith(input []*ast.AST, e ast.Env, fn func(float64, float64) float64) (*ast.AST, error) {
	return evalArithChecked(input, e, func(x, y float64) (float64, error) {
		return fn(x, y), nil
	})
}

// evalArithChecked is evalArith for operators that can fail (e.g. modulo by
// zero); the first error returned by fn aborts the fold.
func evalArithChecked(input []*ast.AST, e ast.Env, fn func(float64, float64) (float64, error)) (*ast.AST, error) {
	if n := len(input); n < 2 {
		return nil, fmt.Errorf("Arithmetic functions take at least 2 arguments, got %d", n)
	}
	acc, err := arithCast(input[0], e)
	if err != nil {
		return nil, err
	}
	for _, elem := range input[1:] {
		val, err := arithCast(elem, e)
		if err != nil {
			return nil, err
		}
		if acc, err = fn(acc, val); err != nil {
			return nil, err
		}
	}
	res := ast.AST{Tag: ast.Num, Val: acc}
	return &res, nil
}

// evalLog checks fn pairwise over consecutive numeric arguments, so
// (< 1 2 3) means 1 < 2 and 2 < 3. All arguments are evaluated.
func evalLog(input []*ast.AST, e ast.Env, fn func(float64, float64) bool) (*ast.AST, error) {
	if n := len(input); n < 2 {
		return nil, fmt.Errorf("Logic functions take at least 2 arguments, got %d", n)
	}
	prev, err := arithCast(input[0], e)
	if err != nil {
		return nil, err
	}
	acc := true
	for _, elem := range input[1:] {
		val, err := arithCast(elem, e)
		if err != nil {
			return nil, err
		}
		acc = acc && fn(prev, val)
		prev = val
	}
	if acc {
		return &trueVal, nil
	}
	return &falseVal, nil
}

// evalEq implements (is x y); see sameValue for the equality rules.
func evalEq(input []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(input, 2, "is"); err != nil {
		return nil, err
	}
	x, err := Eval(input[0], e)
	if err != nil {
		return nil, err
	}
	y, err := Eval(input[1], e)
	if err != nil {
		return nil, err
	}
	if sameValue(x, y) {
		return &trueVal, nil
	}
	return &falseVal, nil
}

// evalCons is the cons primitive: prepends its (already evaluated) first
// argument to the list given as the second, returning a new list value.
func evalCons(input []*ast.AST) (*ast.AST, error) {
	if err := checkArity(input, 2, "cons"); err != nil {
		return nil, err
	}
	items, ok := listItems(input[1])
	if !ok {
		return nil, fmt.Errorf("Cannot cons to non-list %s", input[1].Pretty())
	}
	res := ast.AST{Tag: ast.List, Val: append([]*ast.AST{input[0]}, items...)}
	return &res, nil
}

// evalCar is the car primitive: returns the first element of a list.
func evalCar(input []*ast.AST) (*ast.AST, error) {
	if err := checkArity(input, 1, "car"); err != nil {
		return nil, err
	}
	items, ok := listItems(input[0])
	if !ok {
		return nil, fmt.Errorf("Cannot car from non-list %s", input[0].Pretty())
	}
	if len(items) == 0 {
		return nil, fmt.Errorf("Cannot car from empty list %s", input[0].Pretty())
	}
	return items[0], nil
}

// evalCdr is the cdr primitive: returns all but the first element of a list.
// The result shares its backing array with the input list.
func evalCdr(input []*ast.AST) (*ast.AST, error) {
	if err := checkArity(input, 1, "cdr"); err != nil {
		return nil, err
	}
	items, ok := listItems(input[0])
	if !ok {
		return nil, fmt.Errorf("Cannot cdr from non-list %s", input[0].Pretty())
	}
	if len(items) == 0 {
		return nil, fmt.Errorf("Cannot cdr from empty list %s", input[0].Pretty())
	}
	res := ast.AST{Tag: ast.List, Val: items[1:]}
	return &res, nil
}

// evalNull is the null? primitive: reports whether a list is empty.
func evalNull(input []*ast.AST) (*ast.AST, error) {
	if err := checkArity(input, 1, "null?"); err != nil {
		return nil, err
	}
	items, ok := listItems(input[0])
	if !ok {
		return nil, fmt.Errorf("Cannot call null? on non-list %s", input[0].Pretty())
	}
	res := ast.AST{Tag: ast.Bool, Val: len(items) == 0}
	return &res, nil
}

// evalFn implements (fn (params...) body). A parameter list may contain plain
// symbols, optional parameters (o name) or (o name default), and a final
// `. rest` that collects remaining arguments. The function's environment is a
// child of e, so the body can see the defining scope.
func evalFn(input []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(input, 2, "fn"); err != nil {
		return nil, err
	}
	args := input[0]
	if args.Tag != ast.List {
		return nil, fmt.Errorf("Cannot call fn with argument list %s", args.Pretty())
	}
	var argsStr []string
	var rest *string
	opt := ordered_map.NewOrderedMap()
	argslst := args.Val.([]*ast.AST)
params:
	for i, a := range argslst {
		switch a.Tag {
		case ast.List:
			argv := a.Val.([]*ast.AST)
			argvln := len(argv)
			if (argvln != 2 && argvln != 3) ||
				argv[0].Tag != ast.Symbol || argv[0].Val.(string) != "o" ||
				argv[1].Tag != ast.Symbol {
				return nil, fmt.Errorf("Argument list cannot contain %s", a.Pretty())
			}
			name := argv[1].Val.(string)
			if argvln == 2 {
				// No default: funcApply binds nil when the argument is omitted.
				opt.Set(name, nil)
			} else {
				opt.Set(name, argv[2])
			}
		case ast.Symbol:
			name := a.Val.(string)
			if name != "." {
				argsStr = append(argsStr, name)
				continue
			}
			// `.` must be followed by exactly one symbol naming the rest list.
			if i+2 != len(argslst) || argslst[i+1].Tag != ast.Symbol {
				return nil, fmt.Errorf("Rest argument in %s must be exactly one symbol after '.'", args.Pretty())
			}
			restName := argslst[i+1].Val.(string)
			rest = &restName
			break params
		default:
			return nil, fmt.Errorf("Argument list cannot contain %s", a.Pretty())
		}
	}
	res := ast.AST{Tag: ast.Fn, Val: ast.Func{
		Params: argsStr,
		Rest:   rest,
		Opt:    opt,
		Body:   body(input),
		Env:    ast.NewEnv(&e),
	}}
	return &res, nil
}

// body returns the body expression of an (already arity-checked) fn form.
func body(input []*ast.AST) *ast.AST {
	return input[1]
}

// funcApply calls f with the unevaluated argument expressions args. Each
// argument is evaluated in the caller's environment e, then bound in a fresh
// environment whose parent is f's environment, so recursive calls and
// closures do not overwrite each other's bindings. Missing optionals get
// their default expression evaluated in the call environment (or nil); with a
// rest parameter, the surplus arguments are bound as a quoted list.
func funcApply(f ast.Func, args []*ast.AST, e ast.Env) (*ast.AST, error) {
	plen := len(f.Params)
	alen := len(args)
	if plen != alen && !f.HasRest() && !f.HasOpt() {
		return nil, fmt.Errorf("Function expected %d arguments, was called with %d.", plen, alen)
	}
	if alen < plen {
		return nil, fmt.Errorf("Function expected at least %d arguments, was called with %d.", plen, alen)
	}

	vals := make([]*ast.AST, alen)
	for i, a := range args {
		v, err := Eval(a, e)
		if err != nil {
			return nil, err
		}
		vals[i] = v
	}

	callEnv := ast.NewEnv(&f.Env)
	for i, p := range f.Params {
		callEnv.Values[p] = vals[i]
	}

	if f.HasRest() {
		lst := ast.AST{Tag: ast.List, Val: vals[plen:]}
		quoted := ast.AST{Tag: ast.Quoted, Val: &lst}
		callEnv.Values[*f.Rest] = &quoted
		return Eval(f.Body, callEnv)
	}

	i := plen
	iter := f.Opt.IterFunc()
	for kv, ok := iter(); ok; kv, ok = iter() {
		name := kv.Key.(string)
		if i < alen {
			callEnv.Values[name] = vals[i]
		} else {
			val := &nilVal
			// (o name) stores a nil default, so use the comma-ok form.
			if def, _ := kv.Value.(*ast.AST); def != nil {
				v, err := Eval(def, callEnv)
				if err != nil {
					return nil, err
				}
				val = v
			}
			callEnv.Values[name] = val
		}
		i++
	}
	if i < alen {
		return nil, fmt.Errorf("Function expected at most %d arguments, was called with %d.", i, alen)
	}
	return Eval(f.Body, callEnv)
}

// evalPr is the pr primitive: prints its arguments' Pretty forms separated
// by spaces, without a trailing newline, and returns nil.
func evalPr(l []*ast.AST) (*ast.AST, error) {
	parts := make([]string, 0, len(l))
	for _, elem := range l {
		parts = append(parts, elem.Pretty())
	}
	fmt.Print(strings.Join(parts, " "))
	return &nilVal, nil
}

// evalIf implements (if c1 e1 c2 e2 ... [else]): evaluates the expression
// after the first condition that is true; an odd trailing expression is the
// else branch. Returns nil when nothing matches.
func evalIf(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	for i := 0; i < len(l); i += 2 {
		if i == len(l)-1 {
			return Eval(l[i], e)
		}
		cond, err := Eval(l[i], e)
		if err != nil {
			return nil, err
		}
		if isTrue(cond) {
			return Eval(l[i+1], e)
		}
	}
	return &nilVal, nil
}

// evalDo evaluates each expression in order and returns the last result
// (nil for an empty do).
func evalDo(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	res := &nilVal
	for _, elem := range l {
		var err error
		if res, err = Eval(elem, e); err != nil {
			return nil, err
		}
	}
	return res, nil
}

// evalApply implements (apply f lst): calls f with the elements of the list
// that lst evaluates to.
func evalApply(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(l, 2, "apply"); err != nil {
		return nil, err
	}
	args, err := Eval(l[1], e)
	if err != nil {
		return nil, err
	}
	items, ok := listItems(args)
	if !ok {
		return nil, fmt.Errorf("Argument 2 to apply must be a list, got %s", args.Pretty())
	}
	return evalList(append([]*ast.AST{l[0]}, items...), e)
}

// evalWhile implements (while cond body): re-evaluates body until cond
// evaluates to false, returning the last body result (nil if it never ran).
func evalWhile(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(l, 2, "while"); err != nil {
		return nil, err
	}
	res := &nilVal
	for {
		cond, err := Eval(l[0], e)
		if err != nil {
			return nil, err
		}
		if isFalse(cond) {
			return res, nil
		}
		if res, err = Eval(l[1], e); err != nil {
			return nil, err
		}
	}
}

// evalTable is the table primitive: returns a new, empty table. Keys are
// compared by node pointer.
func evalTable(l []*ast.AST) (*ast.AST, error) {
	if err := checkArity(l, 0, "table"); err != nil {
		return nil, err
	}
	res := ast.AST{Tag: ast.Table, Val: make(map[*ast.AST]*ast.AST)}
	return &res, nil
}

// evalType is the type primitive: returns the argument's type via AST.Type.
func evalType(l []*ast.AST) (*ast.AST, error) {
	if err := checkArity(l, 1, "type"); err != nil {
		return nil, err
	}
	return l[0].Type(), nil
}

// evalSymbolList evaluates a call whose head l[0] is a symbol: special forms
// and built-in operators are dispatched here; any other symbol is looked up
// in e and the call is re-evaluated with the bound value as head.
func evalSymbolList(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	sym := l[0].Val.(string)
	l = l[1:]
	// TODO: make more of these primitives
	switch sym {
	case "=":
		return evalDef(l, e)
	case "is":
		return evalEq(l, e)
	case "apply":
		return evalApply(l, e)
	case "quote":
		if err := checkArity(l, 1, "quote"); err != nil {
			return nil, err
		}
		res := ast.AST{Tag: ast.Quoted, Val: l[0]}
		return &res, nil
	case "fn":
		return evalFn(l, e)
	case "if":
		return evalIf(l, e)
	case "do":
		return evalDo(l, e)
	case "while":
		return evalWhile(l, e)
	case "+":
		return evalArith(l, e, func(x, y float64) float64 { return x + y })
	case "-":
		return evalArith(l, e, func(x, y float64) float64 { return x - y })
	case "*":
		return evalArith(l, e, func(x, y float64) float64 { return x * y })
	case "/":
		return evalArith(l, e, func(x, y float64) float64 { return x / y })
	case "%":
		return evalArithChecked(l, e, modulo)
	case "<":
		return evalLog(l, e, func(x, y float64) bool { return x < y })
	case ">":
		return evalLog(l, e, func(x, y float64) bool { return x > y })
	}
	res, err := e.Lookup(sym)
	if err != nil {
		return nil, err
	}
	return evalList(append([]*ast.AST{res}, l...), e)
}

// evalIdx implements calling an indexable value: (lst i) and (str i) take an
// element (negative i counts from the end); (tbl k) looks k up by node
// pointer and yields nil when absent.
func evalIdx(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	head := l[0]
	if err := checkArity(l, 2, "indexing"); err != nil {
		return nil, err
	}
	arg, err := Eval(l[1], e)
	if err != nil {
		return nil, err
	}
	switch head.Tag {
	case ast.Quoted, ast.String:
		if arg.Tag != ast.Num {
			return nil, fmt.Errorf("Cannot index using %s", arg.Pretty())
		}
		idx := int64(arg.Val.(float64))
		if head.Tag == ast.Quoted {
			items, ok := listItems(head)
			if !ok {
				return nil, fmt.Errorf("Cannot index %s", head.Pretty())
			}
			if i, ok := resolveIndex(idx, int64(len(items))); ok {
				return items[i], nil
			}
		} else {
			runes := []rune(head.Val.(string))
			if i, ok := resolveIndex(idx, int64(len(runes))); ok {
				res := ast.AST{Tag: ast.Char, Val: runes[i]}
				return &res, nil
			}
		}
		return nil, fmt.Errorf("Out of bounds access (idx %d at %s)", idx, head.Pretty())
	case ast.Table:
		if res, ok := head.Val.(map[*ast.AST]*ast.AST)[arg]; ok {
			return res, nil
		}
		return &nilVal, nil
	}
	return nil, fmt.Errorf("Cannot index %s", head.Pretty())
}

// primCall evaluates each argument expression in e and passes the results to
// the primitive's Go function.
func primCall(f ast.Primitive, l []*ast.AST, e ast.Env) (*ast.AST, error) {
	args := make([]*ast.AST, 0, len(l))
	for _, a := range l {
		res, err := Eval(a, e)
		if err != nil {
			return nil, err
		}
		args = append(args, res)
	}
	return f.Fn(args)
}

// evalList evaluates a call form. An empty list evaluates to nil. A head that
// is itself a list is evaluated first; the call is then dispatched on the
// head's tag (symbol, function, primitive, or indexable value).
func evalList(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	if len(l) == 0 {
		return &nilVal, nil
	}
	head := l[0]
	if head.Tag == ast.List {
		evald, err := Eval(head, e)
		if err != nil {
			return nil, err
		}
		head = evald
		// Keep l[0] in sync so handlers that read the head from l see the
		// evaluated value rather than the original list expression.
		l = append([]*ast.AST{head}, l[1:]...)
	}
	switch head.Tag {
	case ast.Symbol:
		return evalSymbolList(l, e)
	case ast.Fn:
		return funcApply(head.Val.(ast.Func), l[1:], e)
	case ast.Prim:
		return primCall(head.Val.(ast.Primitive), l[1:], e)
	case ast.String, ast.Quoted, ast.Table:
		return evalIdx(l, e)
	}
	return nil, fmt.Errorf("Cannot perform call on %s", head.Pretty())
}

// Eval evaluates input in environment e. Lists are evaluated as calls or
// special forms; symbols resolve to true, false or nil, or otherwise via
// e.Lookup; every other node evaluates to itself.
func Eval(input *ast.AST, e ast.Env) (*ast.AST, error) {
	switch input.Tag {
	case ast.List:
		return evalList(input.Val.([]*ast.AST), e)
	case ast.Symbol:
		switch name := input.Val.(string); name {
		case "true":
			return &trueVal, nil
		case "false":
			return &falseVal, nil
		case "nil":
			return &nilVal, nil
		default:
			return e.Lookup(name)
		}
	}
	return input, nil
}