"""Arithmetic encoding of SMILES strings driven by a small state machine.

The grammar is written as text (see ``states_descrip``) and parsed by
``load()`` into a ``StateTable``.  ``traverse()`` walks a string through the
machine and counts how often each edge is taken.  ``encode()`` then uses
those per-state counts as symbol frequencies.  It narrows the interval
[0, 1) one character at a time, finishes with the ``None`` end-of-text
symbol, and returns ``(k, nbits)`` such that ``k / 2**nbits`` lies inside
the final interval.

This is Python 2 code (print statements, ``0L`` literals, in-place sorting
of ``dict.items()`` lists).  It relies on a ``Rational`` module imported
after a sandbox directory is added to ``sys.path``.
"""
import sys
# Rational is not a stdlib module; it is looked up in this hard-coded
# sandbox checkout first, so this path (or some other sys.path entry)
# must provide it.
sys.path.insert(0, "/Users/dalke/cvses/python/nondist/sandbox/rational")
import Rational

# arithmetic encoding of SMILES

# Grammar format (parsed by load()): each non-blank line is
#     FROM -> CHARS -> TO
# Every character in CHARS becomes its own edge from state FROM to state TO.
# An empty CHARS field is the end-of-text edge, stored under the key None.
#
# Earlier, simpler grammar.  Branches "(" ... ")" are allowed, but the
# nesting depth is not tracked, so ")" is accepted even without a
# matching "(".  None of the intact functions below reference it.
old_states_descrip = """
START -> C -> expect_any
expect_any -> C12345 -> expect_any
expect_any -> ( -> expect_atom
expect_any -> ) -> expect_any
expect_any -> -> END
expect_atom -> C -> expect_any
"""

# Branch-nesting statistics, presumably measured on the SMILES data set this
# model was built for (TODO confirm):
# Only one has 4 embedded parens
# 146 with 3 or more
# 9118 with 2 or more
# 20013 with 1 or more
# 30 with no branches

# Grammar actually used.  It has the same alphabet as above ("C", the digits
# 1-5, "(" and ")"), but the state names record how many branches are open:
# - "_close1" .. "_close4" give the current depth.
# - Every "(" must be followed by "C".
# - ")" returns to the next shallower level.
# - expect_any_close4 has no "(" edge, so a fifth nested branch has no
#   transition.  traverse() reports it and re-raises the KeyError.
# - No *_closeN state has an end-of-text edge, so every "(" must be closed.
# - The "expect_atom" line is carried over from the old grammar, but no edge
#   leads to it, so that state is unreachable from START.
states_descrip = """
START -> C -> expect_any
expect_any -> C12345 -> expect_any
expect_any -> ( -> expect_atom_close1
expect_any -> ) -> expect_any
expect_any -> -> END
expect_atom -> C -> expect_any

expect_any_close1 -> C12345 -> expect_any_close1
expect_any_close1 -> ( -> expect_atom_close2
expect_any_close1 -> ) -> expect_any
expect_atom_close1 -> C -> expect_any_close1

expect_any_close2 -> C12345 -> expect_any_close2
expect_any_close2 -> ( -> expect_atom_close3
expect_any_close2 -> ) -> expect_any_close1
expect_atom_close2 -> C -> expect_any_close2

expect_any_close3 -> C12345 -> expect_any_close3
expect_any_close3 -> ( -> expect_atom_close4
expect_any_close3 -> ) -> expect_any_close2
expect_atom_close3 -> C -> expect_any_close3

expect_any_close4 -> C12345 -> expect_any_close4
expect_any_close4 -> ) -> expect_any_close3
expect_atom_close4 -> C -> expect_any_close4
"""


def _get_ranges(table):
    """Lay out symbol counts as adjacent half-open integer ranges.

    table maps each symbol (a one-character string, or None for the
    end-of-text edge) to its count.  Returns (ranges, rev_ranges):

    - ranges[c] is the (low, high) slice of [0, total) assigned to c.
    - rev_ranges is a list of (high, c) pairs in the same order, used by
      Probabilities.find_char to map a scaled value back to a symbol.

    A symbol with a count of 0 gets an empty range (low == high).
    """
    ranges = {}
    rev_ranges = []
    items = table.items()
    items.sort()
    # Under Python 2, None sorts before every string, so the end-of-text
    # symbol would come first; move it to the end of the layout instead.
    if items and items[0][0] is None:
        items = items[1:] + [items[0]]
    i = 0
    for c, count in items:
        ranges[c] = i, i+count
        i = i + count
        # Cumulative upper bound of c's range.
        rev_ranges.append( (i, c) )
    return ranges, rev_ranges


class Probabilities:
    """Symbol frequency model for one state.

    Built from a State's ``counts`` dict.  N is the total count, and the
    ranges are laid out by _get_ranges().
    """

    def __init__(self, table):
        self.table = table
        self.N = sum(table.values())
        self.ranges, self.rev_ranges = _get_ranges(table)

    def find_char(self, val):
        """Return the symbol whose range contains val, a fraction in [0, 1).

        val is scaled by N and compared against the cumulative upper
        bounds.  Raises AssertionError when val is outside [0, 1), or when
        no range contains it (for example when N is 0).
        """
        assert 0 <= val < 1
        val = val * self.N
        for x, c in self.rev_ranges:
            if val < x:
                return c
        raise AssertionError("out of range")


class StateTable:
    """Collection of named States, each created on first access.

    START and END are created up front and exposed as .start and .end.
    """

    def __init__(self):
        self.data = {}
        self.start = self["START"]
        self.end = self["END"]

    def __str__(self):
        """Render every state: START first, END last, the rest sorted by name."""
        names = self.data.keys()
        names.sort()
        names.remove(self.start.name)
        names.remove(self.end.name)
        names.insert(0, self.start.name)
        names.append(self.end.name)
        return "\n".join([str(self.data[name]) for name in names])

    def __getitem__(self, name):
        """Return the State called name, creating and registering it if needed."""
        try:
            return self.data[name]
        except KeyError:
            x = self.data[name] = State(name)
            return x

    def add_edge(self, from_name, c, to_name):
        """Add an edge labelled c (None = end of text) between two named states."""
        from_node = self[from_name]
        to_node = self[to_name]
        from_node.add_edge(c, to_node)

    def get_probs(self):
        """Return {state name: Probabilities} built from each state's counts.

        N and the ranges reflect the counts at the time of this call.
        """
        states = {}
        for name, state in self.data.items():
            states[name] = Probabilities(state.counts)
        return states


class State:
    """A node of the grammar.

    edges maps each symbol to the next State.  counts maps the same symbols
    to the number of times traverse() has taken that edge.
    """

    def __init__(self, name):
        self.name = name
        self.edges = {}
        self.counts = {}

    def __str__(self):
        lines = ["%s:" % (self.name,)]
        for c, state in sorted(self.edges.items()):
            lines.append(" %r -> %s (%d)" % (c, state.name, self.counts[c]))
        return "\n".join(lines)

    def add_edge(self, c, to_node):
        """Add an outgoing edge for symbol c; each symbol may appear only once."""
        assert c not in self.counts, (c, self.name)
        self.edges[c] = to_node
        # Counts start at zero and are filled in by traverse().
        self.counts[c] = 0


def load(infile):
    """Build a StateTable from an iterable of "FROM -> CHARS -> TO" lines.

    Blank lines are skipped.  Every character of CHARS becomes a separate
    edge.  An empty CHARS field becomes the end-of-text edge (key None).
    Asserts that every non-blank line has exactly two "->" separators.
    """
    states = StateTable()
    for line in infile:
        line = line.strip()
        if not line:
            continue
        fields = line.split("->")
        assert len(fields) == 3
        from_name = fields[0].strip()
        chars = fields[1].strip()
        to_name = fields[2].strip()
        if not chars:
            states.add_edge(from_name, None, to_name)
        else:
            for c in chars:
                states.add_edge(from_name, c, to_name)
    return states


def encode(states, text):
    """Arithmetic-encode text using the per-state edge counts in states.

    Starting at START, each character of text, followed by the end-of-text
    symbol None, narrows [minval, maxval) to that symbol's share of the
    current state's counts.  Returns (k, nbits) such that k / 2**nbits lies
    inside the final interval.

    Raises KeyError if text does not follow the grammar's edges.
    """
    probs = states.get_probs()
    node = states.start
    state_probs = probs[node.name]
    minval = Rational.rational(0)
    maxval = Rational.rational(1)
    # NOTE(review): a symbol whose count is 0 gets an empty range, which
    # makes maxval == minval.  delta below would then stay 0 and the while
    # loop would never end.  Presumably every string is traverse()d before
    # it is encoded -- TODO confirm.
    for c in list(text) + [None]:
        prob_range = state_probs.ranges[c]
        # Width of one count unit: (maxval - minval) / N.
        delta = Rational.rational(maxval - minval, state_probs.N)
        maxval = minval + prob_range[1] * delta
        minval = minval + prob_range[0] * delta
        node = node.edges[c]
        state_probs = probs[node.name]
    # Find the smallest nbits with (maxval - minval) * 2**(nbits-1) >= 1.
    # Then floor(midpoint * 2**nbits) / 2**nbits is guaranteed to fall
    # inside the interval.  This assumes Rational's << multiplies by a
    # power of two.
    delta = (maxval - minval)/2
    nbits = 0L
    while delta < 1:
        nbits = nbits + 1
        delta = delta << 1
    # Appears unreachable: the interval only ever shrinks inside [0, 1), so
    # (maxval - minval)/2 < 1 and the loop above runs at least once.
    if nbits == 0:
        return 0, 0
    # (maxval + minval) << (nbits-1) == midpoint * 2**nbits; .n // .d takes
    # its floor (presumably numerator and denominator).
    avg = (maxval + minval)<<(nbits-1)
    return avg.n//avg.d, nbits


def traverse(states, text):
    """Train the model by walking text through the state machine from START.

    Increments the count of every edge taken, then the end-of-text edge,
    and asserts that the walk ends at END.  On a character with no edge,
    prints the current state and the character, then re-raises the KeyError.
    """
    node = states.start
    for c in text:
        try:
            node.counts[c] += 1
        except KeyError:
            print "Problem at", node.name, "with", c
            raise
        node = node.edges[c]
    node.counts[None] += 1
    node = node.edges[None]
    assert node is states.end


# NOTE(review): this copy of the file is damaged from here on.  Everything
# between "1L<" and ">debug" is missing.  The "1L<" was presumably the start
# of "1L<<nbits)" in decode(); the ">debug" presumably belongs to a
# "print >>debug, ..." inside main().  It looks as if text between a "<" and
# a later ">" was stripped like a markup tag.  That removed the rest of
# decode() and the head of main(), including its "def" line.  The surviving
# text is kept verbatim on the single line below.  Restore it from version
# control before running this module.
def decode(states, longval, nbits):
    val = Rational.rational(longval, 1L<>debug, i, print i, smi = line.strip() x = encode(states, smi) smi2 = decode(states, *x) print >>debug, x print x assert smi == smi2, (smi, smi2)


if __name__ == "__main__":
    main()