diff --git a/rx.py b/rx.py index 6c3b4bb..1030bb1 100644 --- a/rx.py +++ b/rx.py @@ -2,8 +2,14 @@ class Node(object): def derive(self, char): return NeverMatches + def match_end(self): return False + def can_match_more(self): return not self.match_end() -EmptyString = Node() + +class EmptyString(Node): + def match_end(self): return True + +EmptyString = EmptyString() NeverMatches = Node() @@ -35,6 +41,12 @@ class AlternationNode(Node): def derive(self, char): return new_alternation([alt.derive(char) for alt in self.alts]) + def match_end(self): + return any(alt.match_end() for alt in self.alts) + + def can_match_more(self): + return any(alt.can_match_more() for alt in self.alts) + class AnyNode(Node): def __init__(self, nxt): @@ -52,6 +64,9 @@ class RepetitionNode(Node): def derive(self, char): return new_alternation([self.head.derive(char), self.nxt.derive(char)]) + def match_end(self): return self.nxt.match_end() + def can_match_more(self): return True + class Or: def __init__(self, alts): @@ -134,7 +149,7 @@ class RE: for i, char in enumerate(s): state = state.derive(char) - if state is EmptyString: return i == ls-1 - if state is NeverMatches: return False + if state.match_end() and i == ls-1: return True + if state.match_end() and not state.can_match_more : return False return False