156 lines
3.3 KiB
Python
156 lines
3.3 KiB
Python
class Node(object):
    """Base matcher state; a bare ``Node`` accepts nothing.

    Subclasses override the hooks below to describe what they consume and
    whether they accept when the input ends.
    """

    def derive(self, char):
        """Return the state reached after consuming ``char``.

        The base implementation ignores ``char`` and always answers with the
        module-level ``NeverMatches`` state.
        """
        return NeverMatches

    def match_end(self):
        """Report whether this state accepts if the input ends here."""
        return False

    def can_match_more(self):
        """Report whether further input may still be consumed.

        By default this is simply the negation of ``match_end()``.
        """
        accepts_now = self.match_end()
        return not accepts_now
|
|
|
|
|
|
class EmptyString(Node):
    """State with nothing left to consume; it accepts at end of input."""

    def match_end(self):
        """Always accept: no further characters are required."""
        return True


# The class name is deliberately rebound to its only instance, so the rest of
# the module uses ``EmptyString`` directly as a ready-made state object.
EmptyString = EmptyString()

# Shared dead state: a plain ``Node``, whose ``match_end()`` is False and whose
# ``derive()`` answers ``NeverMatches`` again.
NeverMatches = Node()
|
|
|
|
|
|
class CharacterNode(Node):
    """State that consumes exactly one specific character."""

    def __init__(self, char, nxt):
        # ``nxt`` is the state to continue with once ``char`` has been read.
        self.nxt = nxt
        self.char = char

    def derive(self, char):
        """Advance to ``nxt`` on the expected character, else the dead state."""
        return self.nxt if char == self.char else NeverMatches
|
|
|
|
|
|
def new_alternation(alts):
    """Build the simplest state equivalent to "any one of ``alts``".

    Args:
        alts: iterable of states to choose between.

    Returns:
        ``NeverMatches`` when no live alternative remains, the single
        survivor itself when exactly one remains, otherwise an
        ``AlternationNode`` over the survivors in their original order.
    """
    live = []
    seen_ids = set()
    for alt in alts:
        # NeverMatches is a singleton sentinel, so test identity; this also
        # avoids invoking any custom ``__eq__``/``__ne__`` defined on a node.
        if alt is NeverMatches:
            continue
        # The very same state object listed twice adds nothing to the union
        # and would only make every later derive() repeat the same work.
        # (ids are stable here because ``live`` keeps the objects alive.)
        if id(alt) in seen_ids:
            continue
        seen_ids.add(id(alt))
        live.append(alt)

    if not live:
        return NeverMatches
    if len(live) == 1:
        return live[0]
    return AlternationNode(live)
|
|
|
|
|
|
class AlternationNode(Node):
    """State that is in several sub-states at the same time."""

    def __init__(self, alts):
        # Stored as given; callers pass a list of states.
        self.alts = alts

    def derive(self, char):
        """Derive every branch by ``char`` and recombine the results."""
        derived = []
        for branch in self.alts:
            derived.append(branch.derive(char))
        return new_alternation(derived)

    def match_end(self):
        """Accept at end of input if at least one branch accepts."""
        for branch in self.alts:
            if branch.match_end():
                return True
        return False

    def can_match_more(self):
        """Report True if at least one branch can still consume input."""
        for branch in self.alts:
            if branch.can_match_more():
                return True
        return False
|
|
|
|
|
|
class AnyNode(Node):
    """Wildcard state: consumes one character, whatever it is."""

    def __init__(self, nxt):
        # State to continue with after the wildcard character.
        self.nxt = nxt

    def derive(self, char):
        """Move on to ``nxt`` regardless of which character was read."""
        following = self.nxt
        return following
|
|
|
|
|
|
class RepetitionNode(Node):
    """Loop point of a zero-or-more repetition.

    ``head`` is the entry state of the repeated body.  It starts out as
    ``NeverMatches`` and is meant to be assigned after construction, once the
    body has been compiled with this node as its tail.  ``nxt`` is the state
    that follows the repetition.
    """

    def __init__(self, nxt):
        self.head = NeverMatches
        self.nxt = nxt
        # True only while derive() is running on this node; see derive().
        self._deriving = False

    def derive(self, char):
        """Return the state after ``char``: one more lap of the body, or exit.

        If the body can match the empty string (``head`` is this node itself,
        or leads back to it without consuming input), deriving ``head``
        re-enters this method before it has returned.  Such a re-entrant call
        could only contribute what the outer call is already computing, so it
        answers ``NeverMatches`` instead of recursing without bound.
        """
        if self._deriving:
            return NeverMatches
        self._deriving = True
        try:
            return new_alternation(
                [self.head.derive(char), self.nxt.derive(char)]
            )
        finally:
            # Always clear the flag so the node stays usable after an error.
            self._deriving = False

    def match_end(self):
        """Zero laps are allowed, so acceptance is decided by ``nxt``."""
        return self.nxt.match_end()

    def can_match_more(self):
        """A repetition is always willing to try consuming more input."""
        return True
|
|
|
|
|
|
class Or:
    """Pattern element holding a collection of alternative sub-patterns."""

    def __init__(self, alts):
        # Kept exactly as given; nothing is copied or validated here.
        self.alts = alts
|
|
|
|
|
|
class ZeroOrMore:
    """Pattern element wrapping the sub-pattern that is to be repeated."""

    def __init__(self, rep):
        # Kept exactly as given; nothing is copied or validated here.
        self.rep = rep
|
|
|
|
|
|
class Any:
    """Marker pattern element; it carries no data of its own."""
|
|
|
|
|
|
def compile_str(s, tail=None):
    """Compile a literal string into a chain of ``CharacterNode`` states.

    Args:
        s: the characters to match, in order.
        tail: state to continue with after the last character; defaults to
            ``EmptyString``.

    Returns:
        The first state of the chain, or the tail itself when ``s`` is empty.
    """
    node = EmptyString if tail is None else tail
    # Build back to front so each node can point at the one that follows it.
    for ch in reversed(s):
        node = CharacterNode(ch, node)
    return node
|
|
|
|
|
|
def compile_list(s, tail=None):
    """Compile a sequence of sub-patterns that must match one after another.

    Args:
        s: reversible sequence of pattern elements accepted by ``compile``.
        tail: state to continue with after the last element; defaults to
            ``EmptyString``.

    Returns:
        The first state of the sequence, or the tail itself when ``s`` is
        empty.
    """
    node = EmptyString if tail is None else tail
    # Walk backwards: each element is compiled with the already-built rest of
    # the sequence as its tail.
    for element in reversed(s):
        node = compile(element, node)
    return node
|
|
|
|
|
|
def compile_or(or_, tail=None):
    """Compile an ``Or`` element: every alternative shares the same tail.

    Args:
        or_: object whose ``alts`` attribute lists the alternative patterns.
        tail: state to continue with after whichever alternative matched;
            defaults to ``EmptyString``.

    Returns:
        The state produced by ``new_alternation`` over the compiled branches.
    """
    if tail is None:
        tail = EmptyString

    branches = []
    for alternative in or_.alts:
        branches.append(compile(alternative, tail))
    return new_alternation(branches)
|
|
|
|
|
|
def compile_zero_or_more(zero_or_more, tail=None):
    """Compile a ``ZeroOrMore`` element into a ``RepetitionNode`` loop.

    Args:
        zero_or_more: object whose ``rep`` attribute is the repeated pattern.
        tail: state to continue with after the repetition; defaults to
            ``EmptyString``.

    Returns:
        The ``RepetitionNode`` that serves as entry point of the loop.
    """
    loop = RepetitionNode(EmptyString if tail is None else tail)
    # The body's tail is the loop node itself, which closes the cycle; the
    # loop can therefore only learn its body after that body was compiled.
    loop.head = compile(zero_or_more.rep, loop)
    return loop
|
|
|
|
|
|
def compile_any(tail=None):
    """Compile a wildcard: an ``AnyNode`` followed by ``tail``.

    Args:
        tail: state to continue with after the wildcard character; defaults
            to ``EmptyString``.
    """
    return AnyNode(EmptyString if tail is None else tail)
|
|
|
|
|
|
def compile(expr, tail=None):
    """Compile a pattern description into its start state.

    Note that, inside this module, this function shadows the built-in
    ``compile``.

    Args:
        expr: an ``Or``, ``ZeroOrMore`` or ``Any`` element, a ``str`` of
            literal characters, or a ``list``/``tuple`` of such elements to
            be matched in sequence.
        tail: state to continue with after ``expr``; defaults to
            ``EmptyString``.

    Returns:
        The first state of the compiled pattern.

    Raises:
        TypeError: if ``expr`` is none of the supported kinds.
    """
    if tail is None:
        tail = EmptyString

    # isinstance() rather than an exact type() comparison, so subclasses of
    # the pattern classes, of str and of list are compiled like their base.
    if isinstance(expr, Or):
        return compile_or(expr, tail)
    if isinstance(expr, ZeroOrMore):
        return compile_zero_or_more(expr, tail)
    if isinstance(expr, Any):
        return compile_any(tail)
    if isinstance(expr, str):
        return compile_str(expr, tail)
    # compile_list only needs reversed(), so tuples work as sequences too.
    if isinstance(expr, (list, tuple)):
        return compile_list(expr, tail)
    raise TypeError("{} is not a compilable type.".format(type(expr)))
|
|
|
|
|
|
class RE:
    """A compiled pattern that can be tested against whole strings."""

    def __init__(self, reg):
        """Compile the pattern description ``reg`` and keep its start state."""
        self.start = compile(reg)

    def match(self, s):
        """Return True if the pattern matches all of ``s``, else False.

        The start state is advanced by one ``derive()`` per character; the
        string matches when the state reached after the last character
        accepts at end of input.  This also covers the empty string: it
        matches whenever the start state itself accepts, not only when that
        state happens to be the ``EmptyString`` object.

        Args:
            s: iterable of characters to test.
        """
        state = self.start
        for char in s:
            # Input remains but this state cannot consume anything further,
            # so the whole string can no longer match.  (can_match_more is a
            # method and has to be called; the bare attribute is always
            # truthy.)
            if not state.can_match_more():
                return False
            state = state.derive(char)
        return state.match_end()
|