Files
argos/eval/eval.go

773 lines
15 KiB
Go

package eval
// TODO: a lot of code duplication that can be solved by a type checking helper
import (
"fmt"
"strings"
"github.com/cevaris/ordered_map"
"github.com/hellerve/argos/ast"
)
// Singleton values for the language's true, false and nil. Evaluators hand
// out pointers to these (&trueVal, &falseVal, &nilVal) instead of allocating
// fresh values; nil is represented as the empty list.
var (
	trueVal  = ast.AST{Tag: ast.Bool, Val: true}
	falseVal = ast.AST{Tag: ast.Bool, Val: false}
	nilVal   = ast.AST{Tag: ast.List, Val: []*ast.AST{}}
)
// RootEnv returns a fresh top-level environment obtained from ast.ParentEnv,
// with every primitive implemented in Go bound under its Lisp name.
func RootEnv() ast.Env {
	env := ast.ParentEnv()
	builtins := map[string]ast.PrimFn{
		"cons":  evalCons,
		"car":   evalCar,
		"cdr":   evalCdr,
		"null?": evalNull,
		"pr":    evalPr,
		"write": evalWrite,
		"table": evalTable,
		"type":  evalType,
		"len":   evalLen,
	}
	for name, fn := range builtins {
		// Each binding gets its own heap-allocated AST node.
		env.Values[name] = &ast.AST{Tag: ast.Prim, Val: ast.Primitive{name, fn}}
	}
	return env
}
// checkArity returns an error unless input holds exactly arity arguments.
// name is the operator shown in the error message.
func checkArity(input []*ast.AST, arity int, name string) error {
	if got := len(input); got != arity {
		return fmt.Errorf("Argument count to '%s' must be %d, was %d", name, arity, got)
	}
	return nil
}
// evalDef implements (= target value).
//
// When target is a symbol, value is evaluated in e and bound to that name in
// e. Otherwise target is evaluated first. If that yields a symbol, the symbol
// names the binding. If not, value is simply evaluated and returned without
// binding anything.
func evalDef(input []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(input, 2, "="); err != nil {
		return nil, err
	}
	target := input[0]
	if target.Tag != ast.Symbol {
		// TODO destructuring assignment
		resolved, err := Eval(target, e)
		if err != nil {
			return nil, err
		}
		if resolved.Tag != ast.Symbol {
			return Eval(input[1], e)
		}
		target = resolved
	}
	name := target.Val.(string)
	value, err := Eval(input[1], e)
	if err != nil {
		return nil, err
	}
	e.Values[name] = value
	return value, nil
}
// arithCast evaluates input in e and returns its numeric value.
// It returns an error if evaluation fails or the result is not a number.
func arithCast(input *ast.AST, e ast.Env) (float64, error) {
	evald, err := Eval(input, e)
	if err != nil {
		return 0, err
	}
	num, ok := evald.Val.(float64)
	if evald.Tag != ast.Num || !ok {
		// The format string needs a verb for the offending value; without one
		// fmt appends a "%!(EXTRA ...)" marker instead of the value.
		return 0, fmt.Errorf("Cannot perform arithmetic on %s", evald.Pretty())
	}
	return num, nil
}
// evalArith evaluates every operand in input as a number and folds the
// operands left to right with fn, starting from the first one. At least two
// operands are required.
func evalArith(input []*ast.AST, e ast.Env, fn func(float64, float64) float64) (*ast.AST, error) {
	if n := len(input); n < 2 {
		return nil, fmt.Errorf("Arithmetic functions take at least 2 arguments, got %d", n)
	}
	var acc float64
	for i, operand := range input {
		val, err := arithCast(operand, e)
		if err != nil {
			return nil, err
		}
		if i == 0 {
			acc = val
		} else {
			acc = fn(acc, val)
		}
	}
	return &ast.AST{Tag: ast.Num, Val: acc}, nil
}
// evalLog evaluates every operand in input as a number. It returns true when
// fn holds for each adjacent pair (a chained comparison such as (< 1 2 3)),
// and false otherwise. All operands are evaluated even after a pair fails.
func evalLog(input []*ast.AST, e ast.Env, fn func(float64, float64) bool) (*ast.AST, error) {
	if n := len(input); n < 2 {
		return nil, fmt.Errorf("Logic functions take at least 2 arguments, got %d", n)
	}
	prev, err := arithCast(input[0], e)
	if err != nil {
		return nil, err
	}
	holds := true
	for _, operand := range input[1:] {
		cur, err := arithCast(operand, e)
		if err != nil {
			return nil, err
		}
		if holds && !fn(prev, cur) {
			holds = false
		}
		prev = cur
	}
	if !holds {
		return &falseVal, nil
	}
	return &trueVal, nil
}
// evalEq implements (is a b). It evaluates both arguments and returns true
// when they are the same object, or when they share a tag and hold equal
// values.
//
// Lists, tables, functions and primitives are compared by identity only.
// Their Go representations (slices, maps, structs with slice or func fields)
// are not comparable, and comparing such values through interfaces panics at
// run time. Without this guard, (is nil nil) would panic.
func evalEq(input []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(input, 2, "is"); err != nil {
		return nil, err
	}
	x, err := Eval(input[0], e)
	if err != nil {
		return nil, err
	}
	y, err := Eval(input[1], e)
	if err != nil {
		return nil, err
	}
	if x == y {
		return &trueVal, nil
	}
	if x.Tag != y.Tag {
		return &falseVal, nil
	}
	switch x.Tag {
	case ast.List, ast.Table, ast.Fn, ast.Prim:
		// Distinct objects of an uncomparable kind are never identical.
		return &falseVal, nil
	}
	if x.Val == y.Val {
		return &trueVal, nil
	}
	return &falseVal, nil
}
// evalCons implements (cons x lst). It returns a new list whose first
// element is x, followed by the elements of lst.
//
// lst may be a quoted list or a bare list value, such as nil or a list
// produced by cons/cdr. This lets (cons 1 nil) and nested conses work.
// Anything else, including a quoted non-list, is reported as an error
// rather than panicking on a type assertion.
func evalCons(input []*ast.AST) (*ast.AST, error) {
	if err := checkArity(input, 2, "cons"); err != nil {
		return nil, err
	}
	lst := input[1]
	var tail []*ast.AST
	ok := false
	switch lst.Tag {
	case ast.Quoted:
		if inner, isAST := lst.Val.(*ast.AST); isAST && inner != nil {
			tail, ok = inner.Val.([]*ast.AST)
		}
	case ast.List:
		tail, ok = lst.Val.([]*ast.AST)
	}
	if !ok {
		return nil, fmt.Errorf("Cannot cons to non-list %s", lst.Pretty())
	}
	elems := make([]*ast.AST, 0, len(tail)+1)
	elems = append(elems, input[0])
	elems = append(elems, tail...)
	return &ast.AST{Tag: ast.List, Val: elems}, nil
}
// evalCar implements (car lst) and returns the first element of lst.
//
// lst may be a quoted list or a bare list value, such as nil or a list built
// by cons/cdr, so (car (cdr x)) composes. A non-list or an empty list is
// reported as an error rather than panicking.
func evalCar(input []*ast.AST) (*ast.AST, error) {
	if err := checkArity(input, 1, "car"); err != nil {
		return nil, err
	}
	lst := input[0]
	var items []*ast.AST
	ok := false
	switch lst.Tag {
	case ast.Quoted:
		if inner, isAST := lst.Val.(*ast.AST); isAST && inner != nil {
			items, ok = inner.Val.([]*ast.AST)
		}
	case ast.List:
		items, ok = lst.Val.([]*ast.AST)
	}
	if !ok {
		return nil, fmt.Errorf("Cannot car from non-list %s", lst.Pretty())
	}
	if len(items) == 0 {
		return nil, fmt.Errorf("Cannot car from empty list %s", lst.Pretty())
	}
	return items[0], nil
}
// evalCdr implements (cdr lst). It returns a list of every element of lst
// after the first.
//
// lst may be a quoted list or a bare list value, such as nil or a list built
// by cons/cdr, so (cdr (cdr x)) composes. A non-list or an empty list is
// reported as an error instead of panicking on the type assertion or slice.
// The result shares its backing array with lst.
func evalCdr(input []*ast.AST) (*ast.AST, error) {
	if err := checkArity(input, 1, "cdr"); err != nil {
		return nil, err
	}
	lst := input[0]
	var items []*ast.AST
	ok := false
	switch lst.Tag {
	case ast.Quoted:
		if inner, isAST := lst.Val.(*ast.AST); isAST && inner != nil {
			items, ok = inner.Val.([]*ast.AST)
		}
	case ast.List:
		items, ok = lst.Val.([]*ast.AST)
	}
	if !ok {
		return nil, fmt.Errorf("Cannot cdr from non-list %s", lst.Pretty())
	}
	if len(items) == 0 {
		return nil, fmt.Errorf("Cannot cdr from empty list %s", lst.Pretty())
	}
	return &ast.AST{Tag: ast.List, Val: items[1:]}, nil
}
// evalNull implements (null? lst) and reports whether lst has no elements.
//
// lst may be a quoted list or a bare list value (nil, or a list built by
// cons/cdr), so (null? nil) is true. The result is the shared true/false
// singleton, so pointer-based truth tests recognise it. A non-list is
// reported as an error instead of panicking.
func evalNull(input []*ast.AST) (*ast.AST, error) {
	if err := checkArity(input, 1, "null?"); err != nil {
		return nil, err
	}
	lst := input[0]
	var items []*ast.AST
	ok := false
	switch lst.Tag {
	case ast.Quoted:
		if inner, isAST := lst.Val.(*ast.AST); isAST && inner != nil {
			items, ok = inner.Val.([]*ast.AST)
		}
	case ast.List:
		items, ok = lst.Val.([]*ast.AST)
	}
	if !ok {
		return nil, fmt.Errorf("Cannot call null? on non-list %s", lst.Pretty())
	}
	if len(items) == 0 {
		return &trueVal, nil
	}
	return &falseVal, nil
}
// evalFn implements (fn params body) and returns a new function value.
//
// params is a list whose entries may be:
//   - a symbol: a required positional parameter;
//   - (o name) or (o name default): an optional parameter, whose default
//     is nil when omitted;
//   - a "." followed by exactly one symbol: the rest parameter.
//
// The function's environment is created from e via ast.NewEnv.
func evalFn(input []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(input, 2, "fn"); err != nil {
		return nil, err
	}
	args := input[0]
	if args.Tag != ast.List {
		return nil, fmt.Errorf("Cannot call fn with argument list %s", args.Pretty())
	}
	var argsStr []string
	var rest *string
	opt := ordered_map.NewOrderedMap()
	argslst := args.Val.([]*ast.AST)
params:
	for i, a := range argslst {
		switch a.Tag {
		case ast.List:
			argv := a.Val.([]*ast.AST)
			argvln := len(argv)
			// The name (argv[1]) must be a symbol, because its value is
			// asserted to a string below.
			if (argvln != 2 && argvln != 3) ||
				argv[0].Tag != ast.Symbol || argv[0].Val.(string) != "o" ||
				argv[1].Tag != ast.Symbol {
				return nil, fmt.Errorf("Argument list cannot contain %s", a.Pretty())
			}
			def := &nilVal
			if argvln == 3 {
				def = argv[2]
			}
			opt.Set(argv[1].Val.(string), def)
		case ast.Symbol:
			val := a.Val.(string)
			if val != "." {
				argsStr = append(argsStr, val)
				continue
			}
			// Exactly one symbol must follow the dot; anything else would
			// index past the end or fail the string assertion.
			if i+2 != len(argslst) || argslst[i+1].Tag != ast.Symbol {
				return nil, fmt.Errorf("Rest parameter must be a single symbol after '.' in %s", args.Pretty())
			}
			str := argslst[i+1].Val.(string)
			rest = &str
			break params
		default:
			return nil, fmt.Errorf("Argument list cannot contain %s", a.Pretty())
		}
	}
	body := input[1]
	fe := ast.NewEnv(&e)
	return &ast.AST{Tag: ast.Fn, Val: ast.Func{argsStr, rest, opt, body, fe}}, nil
}
// funcApply calls the user-defined function f. args holds the unevaluated
// argument expressions, which are evaluated in the caller's environment e.
//
// Each call binds its parameters in a fresh environment chained to the
// function's closure environment. Recursive or nested calls therefore never
// overwrite each other's bindings, and one argument's evaluation cannot see
// a parameter bound for another argument.
//
// Binding order:
//  1. Required parameters are bound first.
//  2. If f has a rest parameter, the remaining arguments are collected into
//     it as a quoted list.
//  3. Otherwise they fill the optional parameters in declaration order.
//     Optional parameters without an argument get their default expression
//     evaluated in the call environment.
//
// Calls with too few required arguments, or with more arguments than a
// function without a rest parameter accepts, are reported as errors.
func funcApply(f ast.Func, args []*ast.AST, e ast.Env) (*ast.AST, error) {
	plen := len(f.Params)
	alen := len(args)
	if plen != alen && !f.HasRest() && !f.HasOpt() {
		return nil, fmt.Errorf("Function expected %d arguments, was called with %d.", plen, alen)
	}
	if alen < plen {
		return nil, fmt.Errorf("Function expected at least %d arguments, was called with %d.", plen, alen)
	}

	// Optional parameters only receive arguments when there is no rest
	// parameter; collect them in declaration order.
	type optParam struct {
		name string
		def  *ast.AST
	}
	var opts []optParam
	if !f.HasRest() {
		next := f.Opt.IterFunc()
		for kv, ok := next(); ok; kv, ok = next() {
			opts = append(opts, optParam{kv.Key.(string), kv.Value.(*ast.AST)})
		}
		if maxArgs := plen + len(opts); alen > maxArgs {
			return nil, fmt.Errorf("Function expected at most %d arguments, was called with %d.", maxArgs, alen)
		}
	}

	callEnv := ast.NewEnv(&f.Env)

	for i, name := range f.Params {
		evald, err := Eval(args[i], e)
		if err != nil {
			return nil, err
		}
		callEnv.Values[name] = evald
	}

	if f.HasRest() {
		var l []*ast.AST
		for _, arg := range args[plen:] {
			evald, err := Eval(arg, e)
			if err != nil {
				return nil, err
			}
			l = append(l, evald)
		}
		lst := ast.AST{Tag: ast.List, Val: l}
		callEnv.Values[*f.Rest] = &ast.AST{Tag: ast.Quoted, Val: &lst}
		return Eval(f.Body, callEnv)
	}

	for j, p := range opts {
		if i := plen + j; i < alen {
			evald, err := Eval(args[i], e)
			if err != nil {
				return nil, err
			}
			callEnv.Values[p.name] = evald
			continue
		}
		// An empty-list default, such as the implicit nil of (o name), is
		// bound as-is rather than evaluated.
		if items, isList := p.def.Val.([]*ast.AST); p.def.Tag == ast.List && isList && len(items) == 0 {
			callEnv.Values[p.name] = p.def
			continue
		}
		val, err := Eval(p.def, callEnv)
		if err != nil {
			return nil, err
		}
		callEnv.Values[p.name] = val
	}
	return Eval(f.Body, callEnv)
}
// evalPr implements (pr x...). It prints the Pretty form of each argument,
// separated by single spaces and with no trailing newline, then returns nil.
func evalPr(l []*ast.AST) (*ast.AST, error) {
	var out strings.Builder
	for i, elem := range l {
		if i > 0 {
			out.WriteByte(' ')
		}
		out.WriteString(elem.Pretty())
	}
	fmt.Print(out.String())
	return &nilVal, nil
}
// evalWrite implements (write x...). It prints the String form of each
// argument, separated by single spaces and with no trailing newline, then
// returns nil.
func evalWrite(l []*ast.AST) (*ast.AST, error) {
	var out strings.Builder
	for i, elem := range l {
		if i > 0 {
			out.WriteByte(' ')
		}
		out.WriteString(elem.String())
	}
	fmt.Print(out.String())
	return &nilVal, nil
}
// evalIf implements (if c1 e1 c2 e2 ... [else]).
//
// Conditions are evaluated in order, and the expression paired with the
// first true condition is evaluated and returned. A trailing unpaired
// expression acts as the else branch. With no match and no else, the
// result is nil.
//
// A condition counts as true when it is a Bool holding true. Truth is
// checked by value, not by identity with the trueVal singleton, so freshly
// built Bool results are honoured.
func evalIf(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	for i := 0; i < len(l); i += 2 {
		if i == len(l)-1 {
			return Eval(l[i], e)
		}
		cond, err := Eval(l[i], e)
		if err != nil {
			return nil, err
		}
		if b, ok := cond.Val.(bool); ok && cond.Tag == ast.Bool && b {
			return Eval(l[i+1], e)
		}
	}
	return &nilVal, nil
}
// evalDo implements (do expr...). It evaluates each expression in order in e
// and returns the value of the last one. An empty (do) yields nil rather than
// a nil pointer, which callers would dereference.
func evalDo(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	res := &nilVal
	for _, elem := range l {
		var err error
		res, err = Eval(elem, e)
		if err != nil {
			return nil, err
		}
	}
	return res, nil
}
// evalApply implements (apply f lst). It evaluates lst and calls f with the
// list's elements spliced in as arguments, i.e. it evaluates (f x1 x2 ...).
//
// lst may be a quoted list or a bare list value (nil, or a list built by
// cons/cdr). Anything else, including a quoted non-list, is reported as an
// error instead of panicking on a type assertion.
func evalApply(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(l, 2, "apply"); err != nil {
		return nil, err
	}
	args, err := Eval(l[1], e)
	if err != nil {
		return nil, err
	}
	var argslst []*ast.AST
	ok := false
	switch args.Tag {
	case ast.Quoted:
		if inner, isAST := args.Val.(*ast.AST); isAST && inner != nil {
			argslst, ok = inner.Val.([]*ast.AST)
		}
	case ast.List:
		argslst, ok = args.Val.([]*ast.AST)
	}
	if !ok {
		return nil, fmt.Errorf("Argument 2 to apply must be a list, got %s", args.Pretty())
	}
	call := make([]*ast.AST, 0, len(argslst)+1)
	call = append(call, l[0])
	call = append(call, argslst...)
	return evalList(call, e)
}
// evalWhile implements (while cond body).
//
// It re-evaluates cond before each iteration and stops once cond is a Bool
// holding false. Falseness is checked by value, so freshly built false
// Bools terminate the loop. It returns the value of the last body
// evaluation, or nil if the body never ran.
func evalWhile(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(l, 2, "while"); err != nil {
		return nil, err
	}
	res := &nilVal
	for {
		cond, err := Eval(l[0], e)
		if err != nil {
			return nil, err
		}
		if b, ok := cond.Val.(bool); ok && cond.Tag == ast.Bool && !b {
			return res, nil
		}
		res, err = Eval(l[1], e)
		if err != nil {
			return nil, err
		}
	}
}
// evalTable implements (table). It takes no arguments and returns a new,
// empty table keyed by AST node pointers.
func evalTable(l []*ast.AST) (*ast.AST, error) {
	if err := checkArity(l, 0, "table"); err != nil {
		return nil, err
	}
	return &ast.AST{Tag: ast.Table, Val: map[*ast.AST]*ast.AST{}}, nil
}
// evalType implements (type x) and returns whatever ast.AST.Type reports
// for its single argument.
func evalType(l []*ast.AST) (*ast.AST, error) {
	if err := checkArity(l, 1, "type"); err != nil {
		return nil, err
	}
	arg := l[0]
	return arg.Type(), nil
}
// evalLen implements (len x) and returns the length of x as a number:
//   - a list (quoted, or a bare list value such as nil or a cons/cdr
//     result): its number of elements;
//   - a string: its number of runes, matching how string indexing counts
//     characters;
//   - a table: its number of entries.
//
// Any other value, including a quoted non-list, is reported as an error.
func evalLen(l []*ast.AST) (*ast.AST, error) {
	if err := checkArity(l, 1, "len"); err != nil {
		return nil, err
	}
	elem := l[0]
	cnt := -1 // stays negative when elem has no length
	switch elem.Tag {
	case ast.Quoted:
		if inner, ok := elem.Val.(*ast.AST); ok && inner != nil {
			if items, ok := inner.Val.([]*ast.AST); ok {
				cnt = len(items)
			}
		}
	case ast.List:
		if items, ok := elem.Val.([]*ast.AST); ok {
			cnt = len(items)
		}
	case ast.String:
		cnt = len([]rune(elem.Val.(string)))
	case ast.Table:
		cnt = len(elem.Val.(map[*ast.AST]*ast.AST))
	}
	if cnt < 0 {
		return nil, fmt.Errorf("Cannot get length of %s", elem.Pretty())
	}
	return &ast.AST{Tag: ast.Num, Val: float64(cnt)}, nil
}
// evalMapTable implements (maptable f tbl). It calls f with each key and
// value of tbl and returns the results as a quoted list. The results follow
// Go map iteration order, so their order is unspecified.
func evalMapTable(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(l, 2, "maptable"); err != nil {
		return nil, err
	}
	fn, err := Eval(l[0], e)
	if err != nil {
		return nil, err
	}
	target, err := Eval(l[1], e)
	if err != nil {
		return nil, err
	}
	if target.Tag != ast.Table {
		return nil, fmt.Errorf("Second argument to maptable must be table")
	}
	var results []*ast.AST
	for key, val := range target.Val.(map[*ast.AST]*ast.AST) {
		call := &ast.AST{Tag: ast.List, Val: []*ast.AST{fn, key, val}}
		out, err := Eval(call, e)
		if err != nil {
			return nil, err
		}
		results = append(results, out)
	}
	list := &ast.AST{Tag: ast.List, Val: results}
	return &ast.AST{Tag: ast.Quoted, Val: list}, nil
}
// evalQuasiquote implements (quasiquote (x1 x2 ...)) and returns a quoted
// list.
//
// Each top-level element of the form (unquote expr) is replaced by the value
// of expr evaluated in e. Every other element, including other lists and
// empty lists, is copied unevaluated. Nested lists are not searched for
// unquotes.
func evalQuasiquote(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	if err := checkArity(l, 1, "quasiquote"); err != nil {
		return nil, err
	}
	arg := l[0]
	if arg.Tag != ast.List {
		return nil, fmt.Errorf("Argument to quasiquote needs to be list, got %s", arg.Pretty())
	}
	var res []*ast.AST
	for _, elem := range arg.Val.([]*ast.AST) {
		if elem.Tag != ast.List {
			res = append(res, elem)
			continue
		}
		lst := elem.Val.([]*ast.AST)
		// Only a list headed by the symbol "unquote" is an unquote form. The
		// emptiness and tag checks guard the index and the string assertion.
		if len(lst) == 0 || lst[0].Tag != ast.Symbol || lst[0].Val.(string) != "unquote" {
			res = append(res, elem)
			continue
		}
		if len(lst) != 2 {
			return nil, fmt.Errorf("Unquote takes one argument.")
		}
		value, err := Eval(lst[1], e)
		if err != nil {
			return nil, err
		}
		res = append(res, value)
	}
	lst := ast.AST{Tag: ast.List, Val: res}
	return &ast.AST{Tag: ast.Quoted, Val: &lst}, nil
}
// evalSymbolList evaluates a call whose head l[0] is a symbol.
//
// Special forms and built-in operators are dispatched by name. They receive
// their arguments unevaluated. Any other symbol is looked up in e, and the
// call is re-dispatched with the looked-up value as its head.
func evalSymbolList(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	sym := l[0].Val.(string)
	l = l[1:]
	// TODO: make more of these primitives
	switch sym {
	case "=":
		return evalDef(l, e)
	case "is":
		return evalEq(l, e)
	case "apply":
		return evalApply(l, e)
	case "quote":
		// Guard l[0]: (quote) used to panic with an index out of range.
		if err := checkArity(l, 1, "quote"); err != nil {
			return nil, err
		}
		return &ast.AST{Tag: ast.Quoted, Val: l[0]}, nil
	case "quasiquote":
		return evalQuasiquote(l, e)
	case "fn":
		return evalFn(l, e)
	case "if":
		return evalIf(l, e)
	case "do":
		return evalDo(l, e)
	case "while":
		return evalWhile(l, e)
	case "+":
		return evalArith(l, e, func(x float64, y float64) float64 { return x + y })
	case "-":
		return evalArith(l, e, func(x float64, y float64) float64 { return x - y })
	case "*":
		return evalArith(l, e, func(x float64, y float64) float64 { return x * y })
	case "/":
		return evalArith(l, e, func(x float64, y float64) float64 { return x / y })
	case "%":
		// Integer remainder by zero panics in Go, so it is reported as an
		// error instead. The fold callback cannot return an error itself,
		// so the failure is recorded in modErr.
		var modErr error
		res, err := evalArith(l, e, func(x float64, y float64) float64 {
			if int64(y) == 0 {
				modErr = fmt.Errorf("Cannot take modulo by zero")
				return 0
			}
			return float64(int64(x) % int64(y))
		})
		if err != nil {
			return nil, err
		}
		if modErr != nil {
			return nil, modErr
		}
		return res, nil
	case "<":
		return evalLog(l, e, func(x float64, y float64) bool { return x < y })
	case ">":
		return evalLog(l, e, func(x float64, y float64) bool { return x > y })
	case "maptable":
		return evalMapTable(l, e)
	}
	res, err := e.Lookup(sym)
	if err != nil {
		return nil, err
	}
	return evalList(append([]*ast.AST{res}, l...), e)
}
// evalIdx implements indexing calls of the form (coll key), where l[0] is
// the already-evaluated collection and l[1] is the unevaluated key.
//
//   - Quoted lists and strings take a numeric index. The index is truncated
//     to an integer; 0 is the first element, and negative indices count from
//     the end (-1 is the last). Strings are indexed by rune and yield a Char.
//   - Tables look the key up by node pointer and yield nil when it is absent.
//
// A quoted value that is not a list cannot be indexed.
func evalIdx(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	head := l[0]
	if err := checkArity(l, 2, "indexing"); err != nil {
		return nil, err
	}
	arg, err := Eval(l[1], e)
	if err != nil {
		return nil, err
	}
	switch head.Tag {
	case ast.Quoted:
		if arg.Tag != ast.Num {
			return nil, fmt.Errorf("Cannot index using %s", arg.Pretty())
		}
		inner, ok := head.Val.(*ast.AST)
		if !ok || inner == nil {
			break
		}
		val, ok := inner.Val.([]*ast.AST)
		if !ok {
			break
		}
		idx := int64(arg.Val.(float64))
		actualLen := int64(len(val))
		actualIdx := idx
		if idx < 0 {
			actualIdx = actualLen + idx
		}
		if actualIdx >= 0 && actualIdx < actualLen {
			return val[actualIdx], nil
		}
		return nil, fmt.Errorf("Out of bounds access (idx %d at %s)", idx, head.Pretty())
	case ast.String:
		if arg.Tag != ast.Num {
			return nil, fmt.Errorf("Cannot index using %s", arg.Pretty())
		}
		idx := int64(arg.Val.(float64))
		val := []rune(head.Val.(string))
		actualLen := int64(len(val))
		actualIdx := idx
		if idx < 0 {
			actualIdx = actualLen + idx
		}
		if actualIdx >= 0 && actualIdx < actualLen {
			return &ast.AST{Tag: ast.Char, Val: val[actualIdx]}, nil
		}
		return nil, fmt.Errorf("Out of bounds access (idx %d at %s)", idx, head.Pretty())
	case ast.Table:
		val := head.Val.(map[*ast.AST]*ast.AST)
		res, ok := val[arg]
		if !ok {
			return &nilVal, nil
		}
		return res, nil
	}
	return nil, fmt.Errorf("Cannot index %s", head.Pretty())
}
// primCall evaluates each argument expression in l, left to right in e, and
// passes the resulting values to the Go primitive f. It stops at the first
// evaluation error.
func primCall(f ast.Primitive, l []*ast.AST, e ast.Env) (*ast.AST, error) {
	var evaluated []*ast.AST
	for i := range l {
		v, err := Eval(l[i], e)
		if err != nil {
			return nil, err
		}
		evaluated = append(evaluated, v)
	}
	return f.Fn(evaluated)
}
// evalList evaluates the call form l in e.
//
// The empty list () evaluates to nil. A head that is itself a list is
// evaluated first. The head then decides how the call proceeds:
//   - a symbol: special-form, operator or variable dispatch;
//   - a user function: application;
//   - a Go primitive: application;
//   - a string, quoted list or table: indexing.
//
// Any other head cannot be called.
func evalList(l []*ast.AST, e ast.Env) (*ast.AST, error) {
	if len(l) == 0 {
		// Without this guard, l[0] below panics on ().
		return &nilVal, nil
	}
	head := l[0]
	if head.Tag == ast.List {
		evald, err := Eval(head, e)
		if err != nil {
			return nil, err
		}
		head = evald
	}
	switch head.Tag {
	case ast.Symbol:
		return evalSymbolList(l, e)
	case ast.Fn:
		return funcApply(head.Val.(ast.Func), l[1:], e)
	case ast.Prim:
		return primCall(head.Val.(ast.Primitive), l[1:], e)
	case ast.String, ast.Quoted, ast.Table:
		return evalIdx(append([]*ast.AST{head}, l[1:]...), e)
	}
	return nil, fmt.Errorf("Cannot perform call on %s", head.Pretty())
}
// Eval evaluates input in the environment e.
//
// Lists are evaluated as calls. The symbols true, false and nil evaluate to
// the shared singletons. Every other symbol is looked up in e. All other
// values evaluate to themselves.
func Eval(input *ast.AST, e ast.Env) (*ast.AST, error) {
	switch input.Tag {
	case ast.List:
		return evalList(input.Val.([]*ast.AST), e)
	case ast.Symbol:
		switch name := input.Val.(string); name {
		case "true":
			return &trueVal, nil
		case "false":
			return &falseVal, nil
		case "nil":
			return &nilVal, nil
		default:
			return e.Lookup(name)
		}
	default:
		return input, nil
	}
}