"""Build a large ternary search tree of random n-grams and time its teardown.

Run as a script: it inserts '#'-joined sliding windows of random lowercase
tokens into a ternary search tree, prints the build time and allocator
statistics, then prints how long it took to release the tree.
"""

import random
import sys
from collections import deque
from datetime import datetime

# Number of nodes released through Node.delete(explicit=True).
killed = 0


class Node():
    """A single character node of a ternary search tree."""

    __slots__ = ("char", "count", "lo", "eq", "hi")

    def __init__(self, char):
        self.char = char
        self.count = 0
        # Children: smaller character, next character, larger character.
        self.lo = None
        self.eq = None
        self.hi = None

    def delete(self, explicit=False):
        """Optionally unlink the subtree below this node.

        By default this is a no-op, so the tree is released by the
        interpreter when the last reference goes away.  With
        ``explicit=True`` the children are deleted recursively and unlinked,
        the module-level ``killed`` counter is incremented for every node
        visited, and a progress line is printed every 100,000 nodes.
        """
        global killed
        if not explicit:
            return
        for slot in ("lo", "eq", "hi"):
            child = getattr(self, slot)
            if child is not None:
                child.delete(explicit=True)
                setattr(self, slot, None)
        killed += 1
        if killed % 100000 == 0:
            print(f"{killed:,} deleted")

    def depth(self):
        """Return the number of nodes on the longest path from this node."""
        return 1 + max(
            self.lo.depth() if self.lo is not None else 0,
            self.eq.depth() if self.eq is not None else 0,
            self.hi.depth() if self.hi is not None else 0,
        )


class TernarySearchTree():
    """Ternary search tree that stores counts for n-grams and their
    subsequences.
    """

    def __init__(self, splitchar=None):
        self.root = None
        self.splitchar = splitchar

    def insert(self, string):
        """Insert *string*, updating the counts along its path."""
        self.root = self._insert(string, self.root)

    def _insert(self, string, node, index=0):
        """Insert ``string[index:]`` at a given node and return that node.

        ``string`` is a str (or other indexable sequence of characters).
        A node's count is incremented when the string ends at that node or
        when the following character is the split character, i.e. for the
        whole string and for every prefix that ends just before a split
        character.  An empty string leaves the tree unchanged.

        A position cursor is used instead of slicing so that each recursion
        level does not copy the remainder of the string.
        """
        if not string or index >= len(string):
            return node

        char = string[index]
        if node is None:
            node = Node(char)

        if char == node.char:
            next_index = index + 1
            if next_index == len(string):
                node.count += 1
            else:
                if string[next_index] == self.splitchar:
                    node.count += 1
                node.eq = self._insert(string, node.eq, next_index)
        elif char < node.char:
            node.lo = self._insert(string, node.lo, index)
        else:
            node.hi = self._insert(string, node.hi, index)
        return node


def random_strings(num_strings):
    """Yield *num_strings* random lowercase strings of length 5 to 15.

    The global random generator is re-seeded with 2 when iteration starts,
    so the sequence is the same on every run.
    """
    random.seed(2)
    symbols = "abcdefghijklmnopqrstuvwxyz"
    for _ in range(num_strings):
        length = random.randint(5, 15)
        yield "".join(random.choices(symbols, k=length))


def train(stime, num_strings=27_000_000):
    """Build the tree from random tokens and report the build time.

    Each token is appended to a sliding window of the last four tokens and
    the '#'-joined window is inserted into the tree.  After building, the
    elapsed time since *stime* and the allocator statistics are printed.

    Returns the time at which the function was about to return; the tree
    is a local, so it is released as the function exits.
    """
    tree = TernarySearchTree("#")
    grams = deque(maxlen=4)
    for token in random_strings(num_strings):
        grams.append(token)
        tree.insert("#".join(grams))
    sys.stdout.write("This gets printed!\n")
    sys.stdout.flush()
    sys._debugmallocstats()
    print("\nbuild time", datetime.now() - stime)
    if tree.root is not None:
        tree.root.delete()
    return datetime.now()


def main():
    """Run the benchmark and print how long releasing the tree took."""
    b = train(datetime.now())
    print("\nteardown time", datetime.now()- b)
    sys.stdout.write("This doesn't get printed\n")
    sys.stdout.flush()


if __name__ == "__main__":
    main()
    print(killed, "killed")
    sys._debugmallocstats()