package eval // TODO: a lot of code duplication that can be solved by a type checking helper import ( "fmt" "strings" "github.com/cevaris/ordered_map" "github.com/hellerve/argos/ast" ) var trueVal = ast.AST{ast.Bool, true} var falseVal = ast.AST{ast.Bool, false} var nilVal = ast.AST{ast.List, []*ast.AST{}} func RootEnv() ast.Env { e := ast.ParentEnv() prims := map[string]ast.PrimFn{ "cons": evalCons, "car": evalCar, "cdr": evalCdr, "null?": evalNull, "pr": evalPr, "write": evalWrite, "table": evalTable, "type": evalType, "len": evalLen, } for k, v := range prims { prim := ast.AST{ast.Prim, ast.Primitive{k, v}} e.Values[k] = &prim } return e } func checkArity(input []*ast.AST, arity int, name string) error { ilen := len(input) if ilen != arity { return fmt.Errorf("Argument count to '%s' must be %d, was %d", name, arity, ilen) } return nil } func evalDef(input []*ast.AST, e ast.Env) (*ast.AST, error) { err := checkArity(input, 2, "=") if err != nil { return nil, err } variable := input[0] if variable.Tag != ast.Symbol { // TODO destructuring assignment variable, err = Eval(variable, e) if err != nil { return nil, err } if variable.Tag != ast.Symbol { variable, err = Eval(input[1], e) return variable, err } } if err != nil { return nil, err } sym := variable.Val.(string) evald, err := Eval(input[1], e) if err != nil { return nil, err } e.Values[sym] = evald return evald, nil } func arithCast(input *ast.AST, e ast.Env) (float64, error) { evald, err := Eval(input, e) if err != nil { return 0, err } if evald.Tag != ast.Num { return 0, fmt.Errorf("Cannot perform arithmetic on ", evald.Pretty()) } return evald.Val.(float64), nil } func evalArith(input []*ast.AST, e ast.Env, fn func(float64, float64) float64) (*ast.AST, error) { ilen := len(input) if ilen < 2 { return nil, fmt.Errorf("Arithmetic functions take at least 2 arguments, got %d", ilen) } acc, err := arithCast(input[0], e) if err != nil { return nil, err } for _, elem := range input[1:] { val, err := arithCast(elem, e) if err != nil { return nil, err } acc = fn(acc, val) } res := ast.AST{ast.Num, acc} return &res, nil } func evalLog(input []*ast.AST, e ast.Env, fn func(float64, float64) bool) (*ast.AST, error) { ilen := len(input) if ilen < 2 { return nil, fmt.Errorf("Logic functions take at least 2 arguments, got %d", ilen) } old, err := arithCast(input[0], e) if err != nil { return nil, err } acc := true for _, elem := range input[1:] { val, err := arithCast(elem, e) if err != nil { return nil, err } acc = acc && fn(old, val) old = val } if acc { return &trueVal, nil } return &falseVal, nil } func evalEq(input []*ast.AST, e ast.Env) (*ast.AST, error) { err := checkArity(input, 2, "is") if err != nil { return nil, err } x, err := Eval(input[0], e) if err != nil { return nil, err } y, err := Eval(input[1], e) if err != nil { return nil, err } if x.Tag != y.Tag { return &falseVal, nil } if x.Val == y.Val { return &trueVal, nil } return &falseVal, nil } func evalCons(input []*ast.AST) (*ast.AST, error) { err := checkArity(input, 2, "cons") if err != nil { return nil, err } lst := input[1] if lst.Tag != ast.Quoted { return nil, fmt.Errorf("Cannot cons to non-list %s", lst.Pretty()) } fst := input[0] res := ast.AST{ast.List, append([]*ast.AST{fst}, lst.Val.(*ast.AST).Val.([]*ast.AST)...)} return &res, nil } func evalCar(input []*ast.AST) (*ast.AST, error) { err := checkArity(input, 1, "car") if err != nil { return nil, err } lst := input[0] if lst.Tag != ast.Quoted { return nil, fmt.Errorf("Cannot car from non-list %s", lst.Pretty()) } res := lst.Val.(*ast.AST).Val.([]*ast.AST)[0] return res, nil } func evalCdr(input []*ast.AST) (*ast.AST, error) { err := checkArity(input, 1, "cdr") if err != nil { return nil, err } lst := input[0] if lst.Tag != ast.Quoted { return nil, fmt.Errorf("Cannot cdr from non-list %s", lst.Pretty()) } res := ast.AST{ast.List, lst.Val.(*ast.AST).Val.([]*ast.AST)[1:]} return &res, nil } func evalNull(input []*ast.AST) (*ast.AST, error) { err := checkArity(input, 1, "null") if err != nil { return nil, err } lst := input[0] if lst.Tag != ast.Quoted { return nil, fmt.Errorf("Cannot call null? on non-list %s", lst.Pretty()) } res := ast.AST{ast.Bool, len(lst.Val.(*ast.AST).Val.([]*ast.AST)) == 0} return &res, nil } func evalFn(input []*ast.AST, e ast.Env) (*ast.AST, error) { err := checkArity(input, 2, "fn") if err != nil { return nil, err } args := input[0] if args.Tag != ast.List { return nil, fmt.Errorf("Cannot call fn with argument list %s", args.Pretty()) } var argsStr []string var rest *string = nil opt := ordered_map.NewOrderedMap() argslst := args.Val.([]*ast.AST) for i, a := range argslst { if a.Tag != ast.Symbol && a.Tag != ast.List { return nil, fmt.Errorf("Argument list cannot contain %s", a.Pretty()) } if a.Tag == ast.List { argv := a.Val.([]*ast.AST) argvln := len(argv) if (argvln != 3 && argvln != 2) || argv[0].Tag != ast.Symbol || argv[0].Val.(string) != "o" || argv[0].Tag != ast.Symbol { return nil, fmt.Errorf("Argument list cannot contain %s", a.Pretty()) } if argvln == 2 { opt.Set(argv[1].Val.(string), &nilVal) } else if argvln == 3 { opt.Set(argv[1].Val.(string), argv[2]) } continue } val := a.Val.(string) // TODO: error handling if val == "." { str := argslst[i+1].Val.(string) rest = &str break } argsStr = append(argsStr, val) } body := input[1] fe := ast.NewEnv(&e) res := ast.AST{ast.Fn, ast.Func{argsStr, rest, opt, body, fe}} return &res, nil } func funcApply(f ast.Func, args []*ast.AST, e ast.Env) (*ast.AST, error) { plen := len(f.Params) alen := len(args) if plen != alen && !f.HasRest() && !f.HasOpt() { return nil, fmt.Errorf("Function expected %d arguments, was called with %d.", plen, alen) } for i, a := range f.Params { evald, err := Eval(args[i], e) if err != nil { return nil, err } f.Env.Values[a] = evald } if f.HasRest() { var l []*ast.AST for _, arg := range args[plen:] { evald, err := Eval(arg, e) if err != nil { return nil, err } l = append(l, evald) } lst := ast.AST{ast.List, l} quoted := ast.AST{ast.Quoted, &lst} f.Env.Values[*f.Rest] = "ed } else { iter := f.Opt.IterFunc() i := plen for kv, ok := iter(); ok; kv, ok = iter() { if i < alen { evald, err := Eval(args[i], e) if err != nil { return nil, err } f.Env.Values[kv.Key.(string)] = evald } else { f.Env.Values[kv.Key.(string)] = kv.Value.(*ast.AST) } i++ } } return Eval(f.Body, f.Env) } func evalPr(l []*ast.AST) (*ast.AST, error) { var toPrint []string for _, elem := range l { toPrint = append(toPrint, elem.Pretty()) } fmt.Print(strings.Join(toPrint, " ")) return &nilVal, nil } func evalWrite(l []*ast.AST) (*ast.AST, error) { var toPrint []string for _, elem := range l { toPrint = append(toPrint, elem.String()) } fmt.Print(strings.Join(toPrint, " ")) return &nilVal, nil } func evalIf(l []*ast.AST, e ast.Env) (*ast.AST, error) { i := 0 for i < len(l) { if i == len(l)-1 { return Eval(l[i], e) } cond, err := Eval(l[i], e) if err != nil { return nil, err } if cond == &trueVal { return Eval(l[i+1], e) } i += 2 } return &nilVal, nil } func evalDo(l []*ast.AST, e ast.Env) (*ast.AST, error) { var res *ast.AST var err error for _, elem := range l { res, err = Eval(elem, e) if err != nil { return nil, err } } return res, nil } func evalApply(l []*ast.AST, e ast.Env) (*ast.AST, error) { err := checkArity(l, 2, "apply") if err != nil { return nil, err } // TODO: error handling args, err := Eval(l[1], e) if err != nil { return nil, err } if args.Tag != ast.Quoted { return nil, fmt.Errorf("Argument 2 to apply must be a list, got %s", args.Pretty()) } argslst := args.Val.(*ast.AST).Val.([]*ast.AST) return evalList(append([]*ast.AST{l[0]}, argslst...), e) } func evalWhile(l []*ast.AST, e ast.Env) (*ast.AST, error) { err := checkArity(l, 2, "while") if err != nil { return nil, err } var res *ast.AST for { cond, err := Eval(l[0], e) if err != nil { return nil, err } if cond == &falseVal { break } res, err = Eval(l[1], e) if err != nil { return nil, err } } return res, nil } func evalTable(l []*ast.AST) (*ast.AST, error) { err := checkArity(l, 0, "table") if err != nil { return nil, err } res := ast.AST{ast.Table, make(map[*ast.AST]*ast.AST)} return &res, nil } func evalType(l []*ast.AST) (*ast.AST, error) { err := checkArity(l, 1, "type") if err != nil { return nil, err } return l[0].Type(), nil } func evalLen(l []*ast.AST) (*ast.AST, error) { err := checkArity(l, 1, "len") if err != nil { return nil, err } cnt := 0 elem := l[0] switch elem.Tag { case ast.Quoted: // TODO: error handling cnt = len(elem.Val.(*ast.AST).Val.([]*ast.AST)) case ast.String: cnt = len(elem.Val.(string)) case ast.Table: cnt = len(elem.Val.(map[*ast.AST]*ast.AST)) default: return nil, fmt.Errorf("Cannot get length of %s", elem.Pretty()) } res := ast.AST{ast.Num, float64(cnt)} return &res, nil } func evalMapTable(l []*ast.AST, e ast.Env) (*ast.AST, error) { err := checkArity(l, 2, "maptable") if err != nil { return nil, err } f, err := Eval(l[0], e) if err != nil { return nil, err } maybeLst, err := Eval(l[1], e) if err != nil { return nil, err } if maybeLst.Tag != ast.Table { return nil, fmt.Errorf("Second argument to maptable must be table") } lst := maybeLst.Val.(map[*ast.AST]*ast.AST) var res []*ast.AST for k, v := range lst { toEval := ast.AST{ast.List, []*ast.AST{f, k, v}} evald, err := Eval(&toEval, e) if err != nil { return nil, err } res = append(res, evald) } reslst := ast.AST{ast.List, res} quoted := ast.AST{ast.Quoted, &reslst} return "ed, nil } func evalQuasiquote(l []*ast.AST, e ast.Env) (*ast.AST, error) { err := checkArity(l, 1, "quasiquote") if err != nil { return nil, err } arg := l[0] if arg.Tag != ast.List { return nil, fmt.Errorf("Argument to quasiquote needs to be list, got %s", arg.Pretty()) } var res []*ast.AST for _, elem := range arg.Val.([]*ast.AST) { if elem.Tag != ast.List { res = append(res, elem) continue } lst := elem.Val.([]*ast.AST) if lst[0].Tag != ast.Symbol && lst[0].Val.(string) == "unquote" { res = append(res, elem) continue } if len(lst) != 2 { return nil, fmt.Errorf("Unquote takes one argument.") } value, err := Eval(lst[1], e) if err != nil { return nil, err } res = append(res, value) } lst := ast.AST{ast.List, res} quotedLst := ast.AST{ast.Quoted, &lst} return "edLst, nil } func evalSymbolList(l []*ast.AST, e ast.Env) (*ast.AST, error) { sym := l[0].Val.(string) l = l[1:] // TODO: make more of these primitives switch sym { case "=": return evalDef(l, e) case "is": return evalEq(l, e) case "apply": return evalApply(l, e) case "quote": lst := l[0] res := ast.AST{ast.Quoted, lst} return &res, nil case "quasiquote": return evalQuasiquote(l, e) case "fn": return evalFn(l, e) case "if": return evalIf(l, e) case "do": return evalDo(l, e) case "while": return evalWhile(l, e) case "+": return evalArith(l, e, func(x float64, y float64) float64 { return x + y }) case "-": return evalArith(l, e, func(x float64, y float64) float64 { return x - y }) case "*": return evalArith(l, e, func(x float64, y float64) float64 { return x * y }) case "/": return evalArith(l, e, func(x float64, y float64) float64 { return x / y }) case "%": return evalArith(l, e, func(x float64, y float64) float64 { return float64(int64(x) % int64(y)) }) case "<": return evalLog(l, e, func(x float64, y float64) bool { return x < y }) case ">": return evalLog(l, e, func(x float64, y float64) bool { return x > y }) case "maptable": return evalMapTable(l, e) } res, err := e.Lookup(sym) if err != nil { return nil, err } return evalList(append([]*ast.AST{res}, l...), e) } func evalIdx(l []*ast.AST, e ast.Env) (*ast.AST, error) { head := l[0] err := checkArity(l, 2, "indexing") if err != nil { return nil, err } arg, err := Eval(l[1], e) if err != nil { return nil, err } switch head.Tag { case ast.Quoted: if arg.Tag != ast.Num { return nil, fmt.Errorf("Cannot index using %s", arg.Pretty()) } idx := int64(arg.Val.(float64)) var actualLen int64 val := head.Val.(*ast.AST).Val.([]*ast.AST) actualLen = int64(len(val)) actualIdx := idx if idx < 0 { actualIdx = actualLen + idx } if actualLen > actualIdx && actualIdx > 0 { return val[actualIdx], nil } return nil, fmt.Errorf("Out of bounds access (idx %d at %s)", idx, head.Pretty()) case ast.String: if arg.Tag != ast.Num { return nil, fmt.Errorf("Cannot index using %s", arg.Pretty()) } idx := int64(arg.Val.(float64)) var actualLen int64 val := []rune(head.Val.(string)) actualLen = int64(len(val)) actualIdx := idx if idx < 0 { actualIdx = actualLen + idx } if actualLen > actualIdx && actualIdx > 0 { res := ast.AST{ast.Char, val[actualIdx]} return &res, nil } return nil, fmt.Errorf("Out of bounds access (idx %d at %s)", idx, head.Pretty()) case ast.Table: val := head.Val.(map[*ast.AST]*ast.AST) res, ok := val[arg] if !ok { return &nilVal, nil } return res, nil } return nil, fmt.Errorf("Cannot index %s", head.Pretty()) } func primCall(f ast.Primitive, l []*ast.AST, e ast.Env) (*ast.AST, error) { var args []*ast.AST for _, a := range l { res, err := Eval(a, e) if err != nil { return nil, err } args = append(args, res) } return f.Fn(args) } func evalList(l []*ast.AST, e ast.Env) (*ast.AST, error) { head := l[0] var err error if head.Tag == ast.List { head, err = Eval(head, e) if err != nil { return nil, err } } switch head.Tag { case ast.Symbol: return evalSymbolList(l, e) case ast.Fn: return funcApply(head.Val.(ast.Func), l[1:], e) case ast.Prim: return primCall(head.Val.(ast.Primitive), l[1:], e) case ast.String, ast.Quoted, ast.Table: return evalIdx(append([]*ast.AST{head}, l[1:]...), e) } return nil, fmt.Errorf("Cannot perform call on %s", head.Pretty()) } func Eval(input *ast.AST, e ast.Env) (*ast.AST, error) { if input.Tag == ast.List { l := input.Val.([]*ast.AST) return evalList(l, e) } if input.Tag == ast.Symbol { val := input.Val.(string) var err error var res *ast.AST switch val { case "true": res = &trueVal case "false": res = &falseVal case "nil": res = &nilVal default: res, err = e.Lookup(val) } return res, err } return input, nil }