diff --git a/README.md b/README.md index ec473b3..f6c2920 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,6 @@ package. ## Status The package exposes almost all functionality of the `fst` crate, except for: -- Combining the results of slicing, `search` and `search_re` with set operations - Using raw transducers @@ -83,6 +82,24 @@ m = Map.from_iter( file_iterator('/your/input/file/'), '/your/mmapped/output.fst # re-open a file you built previously with from_iter() m = Map(path='/path/to/existing.fst') + +# slicing multiple sets efficiently +a = Set.from_iter(["bar", "foo"]) +b = Set.from_iter(["baz", "foo"]) +list(UnionSet(a, b)['ba':'bb']) +['bar', 'baz'] + +# searching multiple sets efficiently +a = Set.from_iter(["bar", "foo"]) +b = Set.from_iter(["baz", "foo"]) +list(UnionSet(a, b).search('ba', 1)) +['bar', 'baz'] + +# searching multiple sets with a regex efficiently +a = Set.from_iter(["bar", "foo"]) +b = Set.from_iter(["baz", "foo"]) +list(UnionSet(a, b).search_re(r'b\w{2}')) +['bar', 'baz'] ``` diff --git a/rust/rust_fst.h b/rust/rust_fst.h index 401a94e..116bfbc 100644 --- a/rust/rust_fst.h +++ b/rust/rust_fst.h @@ -64,6 +64,10 @@ SetStream* fst_set_stream(Set*); SetLevStream* fst_set_levsearch(Set*, Levenshtein*); SetRegexStream* fst_set_regexsearch(Set*, Regex*); SetOpBuilder* fst_set_make_opbuilder(Set*); +SetOpBuilder* fst_set_make_opbuilder_streambuilder(SetStreamBuilder*); +SetOpBuilder* fst_set_make_opbuilder_levstream(SetLevStream*); +SetOpBuilder* fst_set_make_opbuilder_regexstream(SetRegexStream*); +SetOpBuilder* fst_set_make_opbuilder_union(SetUnion*); void fst_set_free(Set*); char* fst_set_stream_next(SetStream*); @@ -76,6 +80,10 @@ char* fst_set_regexstream_next(SetRegexStream*); void fst_set_regexstream_free(SetRegexStream*); void fst_set_opbuilder_push(SetOpBuilder*, Set*); +void fst_set_opbuilder_push_levstream(SetOpBuilder*, SetLevStream*); +void fst_set_opbuilder_push_regexstream(SetOpBuilder*, SetRegexStream*); +void 
fst_set_opbuilder_push_streambuilder(SetOpBuilder*, SetStreamBuilder*); +void fst_set_opbuilder_push_union(SetOpBuilder*, SetUnion*); void fst_set_opbuilder_free(SetOpBuilder*); SetUnion* fst_set_opbuilder_union(SetOpBuilder*); SetIntersection* fst_set_opbuilder_intersection(SetOpBuilder*); @@ -97,6 +105,8 @@ void fst_set_symmetricdifference_free(SetSymmetricDifference*); SetStreamBuilder* fst_set_streambuilder_new(Set*); SetStreamBuilder* fst_set_streambuilder_add_ge(SetStreamBuilder*, char*); +SetStreamBuilder* fst_set_streambuilder_add_gt(SetStreamBuilder*, char*); +SetStreamBuilder* fst_set_streambuilder_add_le(SetStreamBuilder*, char*); SetStreamBuilder* fst_set_streambuilder_add_lt(SetStreamBuilder*, char*); SetStream* fst_set_streambuilder_finish(SetStreamBuilder*); diff --git a/rust/src/set.rs b/rust/src/set.rs index bd28aa1..d3d2c3d 100644 --- a/rust/src/set.rs +++ b/rust/src/set.rs @@ -146,6 +146,38 @@ pub extern "C" fn fst_set_make_opbuilder(ptr: *mut Set) -> *mut set::OpBuilder<' } make_free_fn!(fst_set_opbuilder_free, *mut set::OpBuilder); +#[no_mangle] +pub extern "C" fn fst_set_make_opbuilder_levstream(ptr: *mut SetLevStream) -> *mut set::OpBuilder<'static> { + let sls = val_from_ptr!(ptr); + let mut ob = set::OpBuilder::new(); + ob.push(sls.into_stream()); + to_raw_ptr(ob) +} + +#[no_mangle] +pub extern "C" fn fst_set_make_opbuilder_regexstream(ptr: *mut SetRegexStream) -> *mut set::OpBuilder<'static> { + let srs = val_from_ptr!(ptr); + let mut ob = set::OpBuilder::new(); + ob.push(srs.into_stream()); + to_raw_ptr(ob) +} + +#[no_mangle] +pub extern "C" fn fst_set_make_opbuilder_streambuilder(ptr: *mut set::StreamBuilder<'static>) -> *mut set::OpBuilder<'static> { + let sb = val_from_ptr!(ptr); + let mut ob = set::OpBuilder::new(); + ob.push(sb.into_stream()); + to_raw_ptr(ob) +} + +#[no_mangle] +pub extern "C" fn fst_set_make_opbuilder_union(ptr: *mut set::Union<'static>) -> *mut set::OpBuilder<'static> { + let union = val_from_ptr!(ptr); + let mut 
ob = set::OpBuilder::new(); + ob.push(union.into_stream()); + to_raw_ptr(ob) +} + #[no_mangle] pub extern "C" fn fst_set_opbuilder_push(ptr: *mut set::OpBuilder, set_ptr: *mut Set) { let set = ref_from_ptr!(set_ptr); @@ -153,6 +185,34 @@ pub extern "C" fn fst_set_opbuilder_push(ptr: *mut set::OpBuilder, set_ptr: *mut ob.push(set); } +#[no_mangle] +pub extern "C" fn fst_set_opbuilder_push_levstream(ptr: *mut set::OpBuilder<'static>, sls_ptr: *mut SetLevStream) { + let sls = val_from_ptr!(sls_ptr); + let ob = mutref_from_ptr!(ptr); + ob.push(sls.into_stream()); +} + +#[no_mangle] +pub extern "C" fn fst_set_opbuilder_push_regexstream(ptr: *mut set::OpBuilder<'static>, srs_ptr: *mut SetRegexStream) { + let srs = val_from_ptr!(srs_ptr); + let ob = mutref_from_ptr!(ptr); + ob.push(srs.into_stream()); +} + +#[no_mangle] +pub extern "C" fn fst_set_opbuilder_push_streambuilder(ptr: *mut set::OpBuilder<'static>, sb_ptr: *mut set::StreamBuilder<'static>) { + let sb = val_from_ptr!(sb_ptr); + let ob = mutref_from_ptr!(ptr); + ob.push(sb.into_stream()); +} + +#[no_mangle] +pub extern "C" fn fst_set_opbuilder_push_union(ptr: *mut set::OpBuilder<'static>, union_ptr: *mut set::Union<'static>) { + let union = val_from_ptr!(union_ptr); + let ob = mutref_from_ptr!(ptr); + ob.push(union.into_stream()); +} + #[no_mangle] pub extern "C" fn fst_set_opbuilder_union(ptr: *mut set::OpBuilder) -> *mut set::Union { @@ -205,6 +265,22 @@ pub extern "C" fn fst_set_streambuilder_add_ge(ptr: *mut set::StreamBuilder<'sta to_raw_ptr(sb.ge(cstr_to_str(c_bound))) } +#[no_mangle] +pub extern "C" fn fst_set_streambuilder_add_gt(ptr: *mut set::StreamBuilder<'static>, + c_bound: *mut libc::c_char) + -> *mut set::StreamBuilder<'static> { + let sb = val_from_ptr!(ptr); + to_raw_ptr(sb.gt(cstr_to_str(c_bound))) +} + +#[no_mangle] +pub extern "C" fn fst_set_streambuilder_add_le(ptr: *mut set::StreamBuilder<'static>, + c_bound: *mut libc::c_char) + -> *mut set::StreamBuilder<'static> { + let sb = 
val_from_ptr!(ptr); + to_raw_ptr(sb.le(cstr_to_str(c_bound))) +} + #[no_mangle] pub extern "C" fn fst_set_streambuilder_add_lt(ptr: *mut set::StreamBuilder<'static>, c_bound: *mut libc::c_char) diff --git a/rust_fst/__init__.py b/rust_fst/__init__.py index 5dd3cfb..04de867 100644 --- a/rust_fst/__init__.py +++ b/rust_fst/__init__.py @@ -1,4 +1,4 @@ -from .set import Set +from .set import Set, UnionSet from .map import Map -__all__ = ["Set", "Map"] +__all__ = ["Set", "UnionSet", "Map"] diff --git a/rust_fst/set.py b/rust_fst/set.py index cc08541..059ed66 100644 --- a/rust_fst/set.py +++ b/rust_fst/set.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +from enum import Enum from .common import KeyStreamIterator from .lib import ffi, lib, checked_call @@ -55,14 +56,79 @@ def get_set(self): return Set(None, _pointer=self._set_ptr) +class OpBuilderInputType(Enum): + SET = 1 + STREAM_BUILDER = 2 + UNION = 3 + SEARCH = 4 + SEARCH_RE = 5 + + +def _build_levsearch(fst, term, max_dist): + lev_ptr = checked_call( + lib.fst_levenshtein_new, + fst._ctx, + ffi.new("char[]", term.encode('utf8')), + max_dist) + return lib.fst_set_levsearch(fst._ptr, lev_ptr) + + +def _build_research(fst, pattern): + re_ptr = checked_call( + lib.fst_regex_new, fst._ctx, + ffi.new("char[]", pattern.encode('utf8'))) + return lib.fst_set_regexsearch(fst._ptr, re_ptr) + + class OpBuilder(object): - def __init__(self, set_ptr): + + _BUILDERS = { + OpBuilderInputType.SET: lib.fst_set_make_opbuilder, + OpBuilderInputType.STREAM_BUILDER: lib.fst_set_make_opbuilder_streambuilder, + OpBuilderInputType.UNION: lib.fst_set_make_opbuilder_union, + OpBuilderInputType.SEARCH: lib.fst_set_make_opbuilder_levstream, + OpBuilderInputType.SEARCH_RE: lib.fst_set_make_opbuilder_regexstream, + } + _PUSHERS = { + OpBuilderInputType.SET: lib.fst_set_opbuilder_push, + OpBuilderInputType.STREAM_BUILDER: lib.fst_set_opbuilder_push_streambuilder, + OpBuilderInputType.UNION: lib.fst_set_opbuilder_push_union, + 
OpBuilderInputType.SEARCH: lib.fst_set_opbuilder_push_levstream, + OpBuilderInputType.SEARCH_RE: lib.fst_set_opbuilder_push_regexstream, + } + + @classmethod + def from_search(cls, fst, term, max_dist): + stream_ptr = _build_levsearch(fst, term, max_dist) + opbuilder = OpBuilder(stream_ptr, + input_type=OpBuilderInputType.SEARCH) + return opbuilder + + @classmethod + def from_search_re(cls, fst, pattern): + stream_ptr = _build_research(fst, pattern) + opbuilder = OpBuilder(stream_ptr, + input_type=OpBuilderInputType.SEARCH_RE) + return opbuilder + + @classmethod + def from_slice(cls, set_ptr, s): + sb = StreamBuilder.from_slice(set_ptr, s) + opbuilder = OpBuilder(sb._ptr, + input_type=OpBuilderInputType.STREAM_BUILDER) + return opbuilder + + def __init__(self, ptr, input_type=OpBuilderInputType.SET): + if input_type not in self._BUILDERS: + raise ValueError( + "input_type must be a member of OpBuilderInputType.") + self._input_type = input_type # NOTE: No need for `ffi.gc`, since the struct will be free'd # once we call union/intersection/difference - self._ptr = lib.fst_set_make_opbuilder(set_ptr) + self._ptr = OpBuilder._BUILDERS[self._input_type](ptr) - def push(self, set_ptr): - lib.fst_set_opbuilder_push(self._ptr, set_ptr) + def push(self, ptr): + OpBuilder._PUSHERS[self._input_type](self._ptr, ptr) def union(self): stream_ptr = lib.fst_set_opbuilder_union(self._ptr) @@ -86,6 +152,44 @@ def symmetric_difference(self): lib.fst_set_symmetricdifference_free) +class StreamBuilder(object): + + @classmethod + def from_slice(cls, set_ptr, slice_bounds): + sb = StreamBuilder(set_ptr) + if slice_bounds.start: + sb.ge(slice_bounds.start) + if slice_bounds.stop: + sb.lt(slice_bounds.stop) + return sb + + def __init__(self, set_ptr): + # NOTE: No need for `ffi.gc`, since the struct will be free'd + # once we call union/intersection/difference + self._ptr = lib.fst_set_streambuilder_new(set_ptr) + + def finish(self): + stream_ptr = 
lib.fst_set_streambuilder_finish(self._ptr) + return KeyStreamIterator(stream_ptr, lib.fst_set_stream_next, + lib.fst_set_stream_free) + + def ge(self, bound): + c_start = ffi.new("char[]", bound.encode('utf8')) + self._ptr = lib.fst_set_streambuilder_add_ge(self._ptr, c_start) + + def gt(self, bound): + c_start = ffi.new("char[]", bound.encode('utf8')) + self._ptr = lib.fst_set_streambuilder_add_gt(self._ptr, c_start) + + def le(self, bound): + c_end = ffi.new("char[]", bound.encode('utf8')) + self._ptr = lib.fst_set_streambuilder_add_le(self._ptr, c_end) + + def lt(self, bound): + c_end = ffi.new("char[]", bound.encode('utf8')) + self._ptr = lib.fst_set_streambuilder_add_lt(self._ptr, c_end) + + class Set(object): """ An immutable ordered string set backed by a finite state transducer. @@ -203,19 +307,11 @@ def __getitem__(self, s): if s.start and s.stop and s.start > s.stop: raise ValueError( "Start key must be lexicographically smaller than stop.") - sb_ptr = lib.fst_set_streambuilder_new(self._ptr) - if s.start: - c_start = ffi.new("char[]", s.start.encode('utf8')) - sb_ptr = lib.fst_set_streambuilder_add_ge(sb_ptr, c_start) - if s.stop: - c_stop = ffi.new("char[]", s.stop.encode('utf8')) - sb_ptr = lib.fst_set_streambuilder_add_lt(sb_ptr, c_stop) - stream_ptr = lib.fst_set_streambuilder_finish(sb_ptr) - return KeyStreamIterator(stream_ptr, lib.fst_set_stream_next, - lib.fst_set_stream_free) + sb = StreamBuilder.from_slice(self._ptr, s) + return sb.finish() def _make_opbuilder(self, *others): - opbuilder = OpBuilder(self._ptr) + opbuilder = OpBuilder(self._ptr, input_type=OpBuilderInputType.SET) for oth in others: opbuilder.push(oth._ptr) return opbuilder @@ -333,3 +429,171 @@ def search(self, term, max_dist): return KeyStreamIterator(stream_ptr, lib.fst_set_levstream_next, lib.fst_set_levstream_free, lev_ptr, lib.fst_levenshtein_free) + + +class UnionSet(object): + """ A collection of Set objects that offer efficient operations across all + members. 
+ """ + def __init__(self, *sets): + self.sets = list(sets) + + def __contains__(self, val): + """ Check if the set contains the value. """ + return any([ + lib.fst_set_contains(fst._ptr, + ffi.new("char[]", + val.encode('utf8'))) + for fst in self.sets + ]) + + def __getitem__(self, s): + """ Get an iterator over a range of set contents. + + Start and stop indices of the slice must be unicode strings. + + .. important:: + Slicing follows the semantics for numerical indices, i.e. the + `stop` value is **exclusive**. For example, given the set + `s = Set.from_iter(["bar", "baz", "foo", "moo"])`, `s['b': 'f']` + will only return `"bar"` and `"baz"`. + + :param s: A slice that specifies the range of the set to retrieve + :type s: :py:class:`slice` + """ + if not isinstance(s, slice): + raise ValueError( + "Value must be a string slice (e.g. `['foo':]`)") + if s.start and s.stop and s.start > s.stop: + raise ValueError( + "Start key must be lexicographically smaller than stop.") + + if not self.sets: + return + opbuilder = OpBuilder.from_slice(self.sets[0]._ptr, s) + streams = [] + for fst in self.sets[1:]: + sb = StreamBuilder.from_slice(fst._ptr, s) + streams.append(sb) + for sb in streams: + opbuilder.push(sb._ptr) + return opbuilder.union() + + def __iter__(self): + """ Get an iterator over all keys in all sets in lexicographical order. 
+ """ + if not self.sets: + return + opbuilder = OpBuilder(self.sets[0]._ptr, + input_type=OpBuilderInputType.SET) + for fst in self.sets[1:]: + opbuilder.push(fst._ptr) + return opbuilder.union() + + def _make_opbuilder(self, *others): + others = list(others) + if not others: + raise ValueError( + "Must have at least one set to compare against.") + if not self.sets: + return + our_opbuilder = OpBuilder(self.sets[0]._ptr, + input_type=OpBuilderInputType.SET) + for fst in self.sets[1:]: + our_opbuilder.push(fst._ptr) + our_stream = lib.fst_set_opbuilder_union(our_opbuilder._ptr) + + their_opbuilder = OpBuilder(others.pop()._ptr, + input_type=OpBuilderInputType.SET) + for fst in others: + their_opbuilder.push(fst._ptr) + their_stream = lib.fst_set_opbuilder_union(their_opbuilder._ptr) + + opbuilder = OpBuilder(our_stream, input_type=OpBuilderInputType.UNION) + opbuilder.push(their_stream) + return opbuilder + + def difference(self, *others): + """ Get an iterator over the keys in the difference of this set and + others. + + :param others: List of :py:class:`Set` objects + :returns: Iterator over all keys that exists in this set, but in + none of the other sets, in lexicographical order + """ + return self._make_opbuilder(*others).difference() + + def intersection(self, *others): + """ Get an iterator over the keys in the intersection of this set and + others. + + :param others: List of :py:class:`Set` objects + :returns: Iterator over all keys that exists in all of the passed + sets in lexicographical order + """ + return self._make_opbuilder(*others).intersection() + + def search(self, term, max_dist): + """ Search the set with a Levenshtein automaton. 
+ + :param term: The search term + :param max_dist: The maximum edit distance for search results + :returns: Iterator over matching values in the set + :rtype: :py:class:`KeyStreamIterator` + """ + if not self.sets: + return + opbuilder = OpBuilder.from_search(self.sets[0], term, max_dist) + for fst in self.sets[1:]: + opbuilder.push(_build_levsearch(fst, term, max_dist)) + return opbuilder.union() + + def search_re(self, pattern): + """ Search the set with a regular expression. + + Note that the regular expression syntax is not Python's, but the one + supported by the `regex` Rust crate, which is almost identical + to the engine of the RE2 engine. + + For a documentation of the syntax, see: + http://doc.rust-lang.org/regex/regex/index.html#syntax + + Due to limitations of the underlying FST, only a subset of this syntax + is supported. Most notably absent are: + + * Lazy quantifiers (``r'*?'``, ``r'+?'``) + * Word boundaries (``r'\\b'``) + * Other zero-width assertions (``r'^'``, ``r'$'``) + + For background on these limitations, consult the documentation of + the Rust crate: http://burntsushi.net/rustdoc/fst/struct.Regex.html + + :param pattern: A regular expression + :returns: An iterator over all matching keys in the set + :rtype: :py:class:`KeyStreamIterator` + """ + if not self.sets: + return + opbuilder = OpBuilder.from_search_re(self.sets[0], pattern) + for fst in self.sets[1:]: + opbuilder.push(_build_research(fst, pattern)) + return opbuilder.union() + + def symmetric_difference(self, *others): + """ Get an iterator over the keys in the symmetric difference of this + set and others. + + :param others: List of :py:class:`Set` objects + :returns: Iterator over all keys that exists in only one of the + sets in lexicographical order + """ + return self._make_opbuilder(*others).symmetric_difference() + + def union(self, *others): + """ Get an iterator over the keys in the union of this set and others. 
+ + :param others: List of :py:class:`Set` objects + :returns: Iterator over all keys in all sets in lexicographical + order + """ + return self._make_opbuilder(*others).union() diff --git a/tests/test_set.py b/tests/test_set.py index 509b412..af9fc17 100644 --- a/tests/test_set.py +++ b/tests/test_set.py @@ -2,10 +2,11 @@ import pytest import rust_fst.lib as lib -from rust_fst import Set +from rust_fst import Set, UnionSet TEST_KEYS = [u"möö", "bar", "baz", "foo"] +TEST_KEYS2 = ["bing", "baz", "bap", "foo"] def do_build(path, keys=TEST_KEYS, sorted_=True): @@ -21,6 +22,17 @@ def fst_set(tmpdir): return Set(fst_path) +@pytest.fixture +def fst_unionset(tmpdir): + fst_path1 = str(tmpdir.join('test1.fst')) + fst_path2 = str(tmpdir.join('test2.fst')) + do_build(fst_path1, keys=TEST_KEYS) + do_build(fst_path2, keys=TEST_KEYS2) + a = Set(fst_path1) + b = Set(fst_path2) + return UnionSet(a, b) + + def test_build(tmpdir): fst_path = tmpdir.join('test.fst') do_build(str(fst_path)) @@ -147,3 +159,61 @@ def test_range(fst_set): fst_set['c':'a'] with pytest.raises(ValueError): fst_set['c'] + + +def test_unionset_contains(fst_unionset): + for key in TEST_KEYS+TEST_KEYS2: + assert key in fst_unionset + + +def test_unionset_difference(): + a = Set.from_iter(["bar", "foo"]) + b = Set.from_iter(["baz", "foo"]) + c = Set.from_iter(["bonk", "foo"]) + assert list(UnionSet(a, b).difference(c)) == ["bar", "baz"] + + +def test_unionset_intersection(): + a = Set.from_iter(["bar", "foo"]) + b = Set.from_iter(["baz", "foo"]) + c = Set.from_iter(["bonk", "foo"]) + assert list(UnionSet(a, b).intersection(c)) == ["foo"] + + +def test_unionset_iter(fst_unionset): + stored_keys = list(fst_unionset) + assert stored_keys == sorted(set(TEST_KEYS+TEST_KEYS2)) + + +def test_unionset_range(fst_unionset): + assert list(fst_unionset['f':]) == ['foo', u'möö'] + assert list(fst_unionset[:'m']) == ['bap', 'bar', 'baz', 'bing', 'foo'] + assert list(fst_unionset['baz':'m']) == ['baz', 'bing', 'foo'] + with 
pytest.raises(ValueError): + fst_unionset['c':'a'] + with pytest.raises(ValueError): + fst_unionset['c'] + + +def test_unionset_search(fst_unionset): + matches = list(fst_unionset.search("bam", 1)) + assert matches == ["bap", "bar", "baz"] + + +def test_unionset_search_re(fst_unionset): + matches = list(fst_unionset.search_re(r'ba.*')) + assert matches == ["bap", "bar", "baz"] + + +def test_unionset_symmetric_difference(): + a = Set.from_iter(["bar", "foo"]) + b = Set.from_iter(["baz", "foo"]) + c = Set.from_iter(["bonk", "foo"]) + assert list(UnionSet(a, b).symmetric_difference(c)) == ["bar", "baz", "bonk"] + + +def test_unionset_union(): + a = Set.from_iter(["bar", "foo"]) + b = Set.from_iter(["baz", "foo"]) + c = Set.from_iter(["bonk", "foo"]) + assert list(UnionSet(a, b).union(c)) == ["bar", "baz", "bonk", "foo"]