# Written by Andrew Dalke
# See http://dalkescientific.com/writings/diary/archive/2012/06/10/inverted_index_library.html
"""Benchmark three in-memory inverted index implementations.

Records are read from "*.fpc" files (see add_dataset for the layout the
parser expects).  Each index maps a feature to the ids of the records that
contain it and answers "which records contain all of these features?".

The code avoids print statements and other version-specific syntax so that
it runs unchanged under both Python 2 and Python 3.
"""

from collections import defaultdict
import re
import array
import time
import glob
import os

try:
    import psutil
except ImportError:
    # Only the benchmark driver measures memory; the index classes and the
    # parser stay usable when psutil is not installed.
    psutil = None


def get_memory_info():
    """Return psutil's memory information for the current process.

    The result is whatever psutil returns; the benchmark reads its ``vms``
    attribute.

    Raises:
        ImportError: if psutil could not be imported.
    """
    if psutil is None:
        raise ImportError("psutil is required to measure memory use")
    process = psutil.Process(os.getpid())
    # Newer psutil releases spell this memory_info(); older ones only have
    # get_memory_info().  Use whichever one this installation provides.
    reader = getattr(process, "memory_info", None)
    if reader is None:
        reader = process.get_memory_info
    return reader()


# Header lines look like
#   #bit-7=19184 cccc
bitdef_pat = re.compile(r"#bit-(\d+)=\d+ (.*)")
# Matches one "bitno" or "bitno:count" term, with an optional trailing comma.
bitno_pat = re.compile(r"(\d+)(?::\d+)?,?")


def _report(*fields):
    """Print the fields separated by single spaces.

    Builds one string first so the output is the same on Python 2 (where
    print is a statement) and Python 3 (where it is a function).
    """
    print(" ".join(str(field) for field in fields))


class Intern(dict):
    """Dictionary which maps every key to the first equal object it saw.

    Looking up a missing key stores the key as its own value, so later
    lookups of an equal key return that first object.
    """

    def __missing__(self, key):
        self[key] = key
        return key


class _PatternIds(object):
    """Mixin which numbers patterns; subclasses provide self.pattern_to_id."""

    def get_pattern_id(self, pattern):
        """Return the integer id for pattern, assigning the next one if new.

        Ids are assigned in order of first appearance, starting from 0.
        """
        try:
            return self.pattern_to_id[pattern]
        except KeyError:
            n = len(self.pattern_to_id)
            self.pattern_to_id[pattern] = n
            return n


class InvertedIndex(_PatternIds):
    """Inverted index which stores one set of record ids per feature."""

    def __init__(self):
        self.pattern_to_id = {}
        self.inverted_indices = defaultdict(set)
        self._intern = Intern()

    def add_record(self, id, features):
        """Add the record id to the posting set of each of its features."""
        # Interning makes equal ids/features share a single object.
        id = self._intern[id]
        for feature in features:
            feature = self._intern[feature]
            self.inverted_indices[feature].add(id)

    def search(self, features):
        """Return the set of record ids which contain every feature.

        Raises:
            IndexError: if features is empty.
        """
        empty = ()
        # .get() avoids creating an entry in the defaultdict for an
        # unknown feature.
        terms = [self.inverted_indices.get(feature, empty)
                 for feature in features]
        # Shortest first: an empty term ends up in front and can be
        # detected without doing any intersection.
        terms.sort(key=len)
        if not terms[0]:
            return set()
        return set.intersection(*terms)


class InvertedIndexSingleton(_PatternIds):
    """Inverted index which keeps single-record features out of the sets.

    A feature seen in only one add_record call is stored in self.singletons
    as feature -> id.  It is moved to a real set in self.inverted_indices
    when a second record with that feature is added.
    """

    def __init__(self):
        self.singletons = {}
        self.pattern_to_id = {}
        self.inverted_indices = defaultdict(set)
        self._intern = Intern()

    def add_record(self, id, features):
        """Add the record id under each of its features."""
        id = self._intern[id]
        for feature in features:
            feature = self._intern[feature]
            if feature in self.singletons:
                # Second record with this feature: promote it to a set.
                postings = self.inverted_indices[feature]
                postings.add(self.singletons.pop(feature))
                postings.add(id)
            elif feature not in self.inverted_indices:
                self.singletons[feature] = id
            else:
                self.inverted_indices[feature].add(id)

    def search(self, features):
        """Return the set of record ids which contain every feature.

        Raises:
            AssertionError: if features is empty.
        """
        singleton = None
        need_intersect = []
        for feature in features:
            x = self.singletons.get(feature, None)
            if x is not None:
                if singleton is None:
                    singleton = x
                elif singleton != x:
                    # Two singleton features which belong to different
                    # records can never match together.  Compare by value;
                    # an identity test depends on the ids being interned.
                    return set()
            else:
                inv = self.inverted_indices.get(feature, None)
                if inv is None:
                    return set()
                need_intersect.append(inv)

        if singleton is not None:
            # Only one record can match; check it against the other terms.
            for inv in need_intersect:
                if singleton not in inv:
                    return set()
            return set([singleton])

        if not need_intersect:
            # Every feature either set the singleton, returned early, or
            # added a term, so this only happens for an empty query.
            raise AssertionError("search() needs at least one feature: %r"
                                 % (features,))
        need_intersect.sort(key=len)
        return set.intersection(*need_intersect)


def _detect_unsigned_code():
    """Return the array typecode able to store the value 2**31+100.

    Tries "I" first and falls back to "L" if the store overflows.
    """
    probe = array.array("I", (0,))
    try:
        probe[0] = 2**31 + 100
    except OverflowError:
        return "L"
    return "I"


array_code = _detect_unsigned_code()
#print "Using", array_code


def make_unsigned_array():
    """Return a new empty array using the detected unsigned typecode."""
    return array.array(array_code, ())


class InvertedIndexArray(_PatternIds):
    """Like InvertedIndexSingleton, but posting lists are unsigned arrays.

    Ids are appended without a duplicate check and are not interned.
    The arrays are converted to sets when a search has to intersect them.
    """

    def __init__(self):
        self.singletons = {}
        self.pattern_to_id = {}
        self.inverted_indices = defaultdict(make_unsigned_array)

    def add_record(self, id, features):
        """Add the record id under each of its features."""
        for feature in features:
            if feature in self.singletons:
                # Second record with this feature: promote it to an array.
                postings = self.inverted_indices[feature]
                postings.append(self.singletons.pop(feature))
                postings.append(id)
            elif feature not in self.inverted_indices:
                self.singletons[feature] = id
            else:
                self.inverted_indices[feature].append(id)

    def search(self, features):
        """Return the set of record ids which contain every feature.

        Raises:
            AssertionError: if features is empty.
        """
        singleton = None
        need_intersect = []
        for feature in features:
            x = self.singletons.get(feature, None)
            if x is not None:
                if singleton is None:
                    singleton = x
                elif singleton != x:
                    # Ids are not interned here, so equal ids may be
                    # distinct int objects; they must be compared by value.
                    return set()
            else:
                inv = self.inverted_indices.get(feature, None)
                if inv is None:
                    return set()
                need_intersect.append(inv)

        if singleton is not None:
            # Only one record can match; check it against the other terms.
            for inv in need_intersect:
                if singleton not in inv:
                    return set()
            return set([singleton])

        if not need_intersect:
            # Every feature either set the singleton, returned early, or
            # added a term, so this only happens for an empty query.
            raise AssertionError("search() needs at least one feature: %r"
                                 % (features,))
        need_intersect.sort(key=len)
        return set.intersection(*map(set, need_intersect))


def add_dataset(index, infile):
    """Parse an FPC1 stream and add each of its records to index.

    The parser expects:
      - a first line which is exactly "#FPC1\\n";
      - header lines matching bitdef_pat, such as "#bit-7=19184 cccc",
        which define a pattern for a bit number;
      - record lines, which look like
            36888:3,46979,53250,53911,55337,57024,92899<tab>11299999
        that is, the bit terms, a tab, then the integer record id.

    Args:
        index: object with get_pattern_id(pattern) and
            add_record(id, features) methods.
        infile: iterable of lines, such as an open file.

    Raises:
        AssertionError: if the first line is not "#FPC1\\n", if a header
            line is not a bit definition, or if the input ends before the
            first record line.
        StopIteration: if infile yields no lines at all.
    """
    it = iter(infile)
    first_line = next(it)
    if first_line != "#FPC1\n":
        raise AssertionError(first_line)

    bitno_to_pattern_id = {}
    for line in it:
        if line[:1] != "#":
            break
        m = bitdef_pat.match(line)
        if m is None:
            raise AssertionError("not a bit definition: %r" % (line,))
        bitno, pattern = m.groups()
        bitno_to_pattern_id[bitno] = index.get_pattern_id(pattern)
    else:
        # The input ended without any record line.
        raise AssertionError("premature end of file")

    def add_record_line(record_line):
        bitno_string, _, id = record_line.rpartition("\t")
        id = int(id)
        features = [bitno_to_pattern_id[bitno]
                    for bitno in bitno_pat.findall(bitno_string)]
        index.add_record(id, features)

    # The line which ended the header loop is itself the first record;
    # it has already been consumed from the iterator, so handle it here
    # instead of silently dropping it.
    add_record_line(line)
    for line in it:
        add_record_line(line)


class GetQueries(object):
    """Index stand-in which keeps every every-th record as a query.

    It can be passed to add_dataset in place of an index.  Pattern ids are
    obtained from the wrapped index; the sampled (id, features) pairs are
    collected in self.records.
    """

    def __init__(self, index, every=100):
        if not every >= 1:
            raise AssertionError("every must be >= 1: %r" % (every,))
        self.index = index
        self.records = []
        self.count = 0
        # Number of records to skip before each record which is kept.
        self.every = every - 1

    def get_pattern_id(self, pattern):
        """Return the wrapped index's id for pattern."""
        return self.index.get_pattern_id(pattern)

    def add_record(self, id, features):
        """Keep the record if it is the every-th one since the last kept."""
        if self.count == self.every:
            self.count = 0
            self.records.append((id, features))
        else:
            self.count += 1


def main(pattern="*.fpc", num_index_files=10):
    """Load files into an index, then time queries taken from one more file.

    The first num_index_files files matching pattern are indexed.  The next
    matching file supplies the queries.

    Raises:
        SystemExit: if there is no file left to take the queries from.
    """
    # Results recorded by the original author, one group per index type.
    # They predate the fix which stopped add_dataset from dropping the
    # first record of each file, so counts from a new run may differ.
    #
    # Done loading in 63.6s meminfo(rss=2112045056, vms=4605886464) meminfo(rss=7446528, vms=2500407296) 2 105 479 168
    # Counts 4048 0.11482000351
    #index = InvertedIndex()

    # Done loading in 85.4s meminfo(rss=2101399552, vms=4595073024) meminfo(rss=7446528, vms=2500407296) 2 094 665 728
    # Counts 4048 0.132526874542
    index = InvertedIndexSingleton()

    # Done loading in 58.9s meminfo(rss=255737856, vms=2774700032) meminfo(rss=7446528, vms=2500407296) 274 292 736
    # Counts 4048 215.295588017
    #index = InvertedIndexArray()

    start_size = get_memory_info()
    t1 = time.time()
    query_filename = None
    for i, filename in enumerate(glob.glob(pattern)):
        if i == num_index_files:
            query_filename = filename
            break
        _report("Load", i, filename)
        with open(filename) as infile:
            add_dataset(index, infile)
    t2 = time.time()
    end_size = get_memory_info()
    _report("Done loading in %.1fs" % (t2 - t1), end_size, start_size,
            end_size.vms - start_size.vms)
    #raise SystemExit

    if query_filename is None:
        raise SystemExit(
            "Need more than %d files matching %r: the one after those "
            "provides the queries" % (num_index_files, pattern))

    # semi-faked; I add the features even if not needed
    # 23014 queries in the data set; most are empty
    queries = GetQueries(index)
    with open(query_filename) as infile:
        add_dataset(queries, infile)

    counts = 0
    t1 = time.time()
    for i, (id, features) in enumerate(queries.records):
        hits = index.search(features)
        counts += len(hits)
        if i % 100 == 99:
            _report(i + 1, counts,
                    "%.1f per second" % ((i + 1) / (time.time() - t1)))
    t2 = time.time()
    _report("Counts", counts, t2 - t1)


if __name__ == "__main__":
    main()