diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index b4fa893a47d..aa21b727d42 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -3860,6 +3860,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter pub fn vortex_array::arrays::Filter::take(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::filter(vortex_array::ArrayView<'_, Self>, &vortex_mask::Mask) -> vortex_error::VortexResult> + pub struct vortex_array::arrays::filter::FilterData impl vortex_array::arrays::filter::FilterData @@ -3970,6 +3974,10 @@ impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Extens pub fn vortex_array::arrays::Extension::filter(vortex_array::ArrayView<'_, vortex_array::arrays::Extension>, &vortex_mask::Mask) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::filter(vortex_array::ArrayView<'_, Self>, &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Masked pub fn vortex_array::arrays::Masked::filter(vortex_array::ArrayView<'_, vortex_array::arrays::Masked>, &vortex_mask::Mask) -> vortex_error::VortexResult> @@ -7156,6 +7164,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter pub fn vortex_array::arrays::Filter::take(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::filter(vortex_array::ArrayView<'_, Self>, &vortex_mask::Mask) -> vortex_error::VortexResult> + pub struct vortex_array::arrays::FixedSizeList impl core::clone::Clone for vortex_array::arrays::FixedSizeList diff --git a/vortex-array/src/arrays/bool/compute/filter.rs b/vortex-array/src/arrays/bool/compute/filter.rs index c7b45498957..315c1db2474 100644 --- a/vortex-array/src/arrays/bool/compute/filter.rs +++ b/vortex-array/src/arrays/bool/compute/filter.rs @@ -7,7 +7,7 @@ use vortex_buffer::get_bit; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_mask::Mask; -use vortex_mask::MaskIter; +use vortex_mask::MaskValues; use crate::ArrayRef; use crate::IntoArray; @@ -17,8 +17,9 @@ use crate::arrays::BoolArray; use crate::arrays::bool::BoolArrayExt; use crate::arrays::filter::FilterReduce; -/// If the filter density is above 80%, we use slices to filter the array instead of indices. -const FILTER_SLICES_DENSITY_THRESHOLD: f64 = 0.8; +/// Below this density threshold, use the sparse path which iterates only set +/// bits in the mask. Above it, the word-level PEXT approach is faster. +const SPARSE_DENSITY_THRESHOLD: f64 = 0.05; impl FilterReduce for Bool { fn filter(array: ArrayView<'_, Bool>, mask: &Mask) -> VortexResult> { @@ -28,47 +29,276 @@ impl FilterReduce for Bool { .values() .vortex_expect("AllTrue and AllFalse are handled by filter fn"); - let buffer = match mask_values.threshold_iter(FILTER_SLICES_DENSITY_THRESHOLD) { - MaskIter::Indices(indices) => filter_indices(&array.to_bit_buffer(), indices), - MaskIter::Slices(slices) => filter_slices( - &array.to_bit_buffer(), - mask.true_count(), - slices.iter().copied(), - ), + let src = array.to_bit_buffer(); + let density = mask_values.density(); + let buffer = if density < SPARSE_DENSITY_THRESHOLD { + filter_sparse(&src, mask_values, mask.true_count()) + } else { + filter_bitbuffer_by_mask(&src, mask_values.bit_buffer(), mask.true_count()) }; Ok(Some(BoolArray::new(buffer, validity).into_array())) } } -fn filter_indices(bools: &BitBuffer, indices: &[usize]) -> BitBuffer { - let buffer = bools.inner().as_ref(); - let offset = bools.offset(); - BitBuffer::collect_bool(indices.len(), |idx| { - // Safety: - // We iterate over the slice's length. - let idx = unsafe { indices.get_unchecked(idx) } + offset; - get_bit(buffer, idx) - }) +fn filter_sparse(src: &BitBuffer, mask_values: &MaskValues, true_count: usize) -> BitBuffer { + if let Some(slices) = mask_values.cached_slices() { + filter_slices(src, true_count, slices.iter().copied()) + } else if let Some(indices) = mask_values.cached_indices() { + let buffer = src.inner().as_ref(); + let offset = src.offset(); + BitBuffer::collect_bool(indices.len(), |idx| { + // SAFETY: `collect_bool` calls the closure exactly `indices.len()` times. + let idx = unsafe { *indices.get_unchecked(idx) }; + get_bit(buffer, offset + idx) + }) + } else { + filter_set_bits(src, mask_values.bit_buffer(), true_count) + } } fn filter_slices( - buffer: &BitBuffer, - indices_len: usize, + src: &BitBuffer, + output_len: usize, slices: impl Iterator, ) -> BitBuffer { - let mut builder = BitBufferMut::with_capacity(indices_len); + let mut builder = BitBufferMut::with_capacity(output_len); for (start, end) in slices { - // TODO(ngates): we probably want a borrowed slice for things like this. - builder.append_buffer(&buffer.slice(start..end)); + builder.append_buffer(&src.slice(start..end)); } builder.freeze() } -#[cfg(test)] -mod test { +fn filter_set_bits(src: &BitBuffer, mask_buf: &BitBuffer, true_count: usize) -> BitBuffer { + let buffer = src.inner().as_ref(); + let offset = src.offset(); + let mut indices = mask_buf.set_indices(); + BitBuffer::collect_bool(true_count, |_| { + // SAFETY: the iterator yields exactly true_count indices. + let idx = unsafe { indices.next().unwrap_unchecked() }; + get_bit(buffer, offset + idx) + }) +} + +/// Extract bits from `src` where corresponding bits in `mask_buf` are set. +/// +/// Uses a software PEXT (parallel bit extract) to compact selected bits from +/// each 64-bit word, with a u128 accumulator to simplify overflow handling. +/// Fast paths skip PEXT entirely for all-ones and all-zeros mask words. +pub fn filter_bitbuffer_by_mask( + src: &BitBuffer, + mask_buf: &BitBuffer, + true_count: usize, +) -> BitBuffer { + #[cfg(target_arch = "x86_64")] + { + if std::arch::is_x86_feature_detected!("bmi2") { + // SAFETY: BMI2 confirmed available; the inner function is compiled with BMI2. + return unsafe { filter_pext_bmi2(src, mask_buf, true_count) }; + } + } + filter_pext_fallback(src, mask_buf, true_count) +} + +/// BMI2-native filter: entire function compiled with BMI2+POPCNT enabled. +/// +/// The compiler generates PEXT for bit extraction, SHLX/SHRX for flag-free +/// shifts, and POPCNT for population count — no runtime feature checks in +/// the hot loop. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "bmi2,popcnt")] +unsafe fn filter_pext_bmi2(src: &BitBuffer, mask_buf: &BitBuffer, true_count: usize) -> BitBuffer { + use std::arch::x86_64::_pext_u64; + + filter_inner(src, mask_buf, true_count, |src, mask| _pext_u64(src, mask)) +} + +/// Software fallback filter using byte-LUT PEXT. +fn filter_pext_fallback(src: &BitBuffer, mask_buf: &BitBuffer, true_count: usize) -> BitBuffer { + filter_inner(src, mask_buf, true_count, pext_fallback) +} + +/// Core filter loop parameterized by the PEXT implementation. +/// +/// Extracted so the same logic is shared between the software and hardware paths. +/// Uses raw pointer writes instead of Vec::push to eliminate bounds checks +/// in the hot loop — we know the exact output size from true_count. +#[inline(always)] +#[allow(clippy::cast_possible_truncation)] +fn filter_inner( + src: &BitBuffer, + mask_buf: &BitBuffer, + true_count: usize, + pext_fn: impl Fn(u64, u64) -> u64, +) -> BitBuffer { + debug_assert_eq!(src.len(), mask_buf.len()); + + let src_chunks = src.chunks(); + let mask_chunks = mask_buf.chunks(); + + let out_u64s = true_count.div_ceil(64); + let mut output: Vec = Vec::with_capacity(out_u64s + 1); + let out_ptr = output.as_mut_ptr(); + let mut out_idx: usize = 0; + // u128 accumulator: overflow naturally held in upper 64 bits, eliminating + // the tricky `extracted >> (popcount - accum_bits)` re-derivation. Just + // flush low 64 bits and shift down when full. + let mut accum: u128 = 0; + let mut accum_bits: u32 = 0; + + for (src_word, mask_word) in src_chunks.iter().zip(mask_chunks.iter()) { + if mask_word == u64::MAX { + // All 64 bits selected — copy source word directly, no PEXT needed. + accum |= (src_word as u128) << accum_bits; + accum_bits += 64; + if accum_bits >= 64 { + unsafe { out_ptr.add(out_idx).write(accum as u64) }; + out_idx += 1; + accum >>= 64; + accum_bits -= 64; + } + continue; + } + + let popcount = mask_word.count_ones(); + if popcount == 0 { + continue; + } + + let extracted = pext_fn(src_word, mask_word); + + accum |= (extracted as u128) << accum_bits; + accum_bits += popcount; + + if accum_bits >= 64 { + unsafe { out_ptr.add(out_idx).write(accum as u64) }; + out_idx += 1; + accum >>= 64; + accum_bits -= 64; + } + } + + let remainder = mask_chunks.remainder_bits(); + if remainder != 0 { + let src_rem = src_chunks.remainder_bits(); + let popcount = remainder.count_ones(); + if popcount > 0 { + let extracted = pext_fn(src_rem, remainder); + accum |= (extracted as u128) << accum_bits; + accum_bits += popcount; + if accum_bits >= 64 { + unsafe { out_ptr.add(out_idx).write(accum as u64) }; + out_idx += 1; + accum >>= 64; + accum_bits -= 64; + } + } + } + + if accum_bits > 0 { + unsafe { out_ptr.add(out_idx).write(accum as u64) }; + out_idx += 1; + } + + // SAFETY: we wrote exactly out_idx words, which is <= out_u64s + 1 = capacity. + unsafe { output.set_len(out_idx) }; + + let byte_len = true_count.div_ceil(8); + let bytes: Vec = unsafe { + let mut v = std::mem::ManuallyDrop::new(output); + let ptr = v.as_mut_ptr() as *mut u8; + let cap = v.capacity() * 8; + Vec::from_raw_parts(ptr, byte_len, cap) + }; + + BitBuffer::new(bytes.into(), true_count) +} + +/// Byte-level LUT PEXT fallback. +/// +/// Processes each byte of the u64 independently using a precomputed 256-entry +/// lookup table per mask byte. Each byte PEXT is a single table lookup with no +/// data dependencies between bytes, making this faster than the parallel-prefix +/// approach (~12ns vs ~18ns per word). +#[inline(always)] +pub fn pext_fallback(src: u64, mask: u64) -> u64 { + pext_byte_lut(src, mask) +} + +/// Precomputed lookup table for 8-bit PEXT. +/// +/// `BYTE_PEXT_LUT[mask_byte]` is a 256-byte table mapping `src_byte` to the +/// extracted bits. Total size: 256 * 256 = 64KB, fits in L1 cache. +#[allow(clippy::cast_possible_truncation)] +static BYTE_PEXT_LUT: &[u8; 256 * 256] = &{ + let mut lut = [0u8; 256 * 256]; + let mut mask: usize = 0; + while mask < 256 { + let mut src: usize = 0; + while src < 256 { + let mut result = 0u8; + let mut bit = 0u8; + // mask and src are always < 256, so truncation to u8 is safe. + let mut m = mask as u8; + let s = src as u8; + let mut pos: u8 = 0; + while m != 0 { + if m & 1 != 0 { + if s & (1 << pos) != 0 { + result |= 1 << bit; + } + bit += 1; + } + m >>= 1; + pos += 1; + } + lut[mask * 256 + src] = result; + src += 1; + } + mask += 1; + } + lut +}; + +/// Byte-level PEXT using precomputed lookup table. +#[inline(always)] +fn pext_byte_lut(src: u64, mask: u64) -> u64 { + let src_bytes = src.to_le_bytes(); + let mask_bytes = mask.to_le_bytes(); + + let mut result: u64 = 0; + let mut bit_offset: u32 = 0; + + // Unroll the byte loop for performance. + macro_rules! process_byte { + ($i:expr) => { + let m = mask_bytes[$i]; + if m != 0 { + let extracted = BYTE_PEXT_LUT[(m as usize) * 256 + (src_bytes[$i] as usize)]; + result |= (extracted as u64) << bit_offset; + bit_offset += m.count_ones(); + } + }; + } + + process_byte!(0); + process_byte!(1); + process_byte!(2); + process_byte!(3); + process_byte!(4); + process_byte!(5); + process_byte!(6); + process_byte!(7); + + let _ = bit_offset; + result +} + +#[cfg(test)] +mod tests { use itertools::Itertools; + use rstest::rstest; use vortex_mask::Mask; use super::*; @@ -87,22 +317,40 @@ mod test { } #[test] - fn filter_bool_by_slice_test() { + fn filter_bool_sparse_index_mask() { let arr = BoolArray::from_iter([true, true, false]); + let mask = Mask::from_indices(3, [0, 2]); - let filtered = filter_slices(&arr.to_bit_buffer(), 2, [(0, 1), (2, 3)].into_iter()); - assert_eq!(vec![true, false], filtered.iter().collect_vec()) + let filtered = arr.filter(mask).unwrap(); + assert_arrays_eq!(filtered, BoolArray::from_iter([true, false])); } #[test] - fn filter_bool_by_index_test() { + fn filter_bool_sparse_slice_mask() { let arr = BoolArray::from_iter([true, true, false]); + let mask = Mask::from_slices(3, vec![(0, 1), (2, 3)]); - let filtered = filter_indices(&arr.to_bit_buffer(), &[0, 2]); - assert_eq!(vec![true, false], filtered.iter().collect_vec()) + let filtered = arr.filter(mask).unwrap(); + assert_arrays_eq!(filtered, BoolArray::from_iter([true, false])); } - use rstest::rstest; + #[test] + fn filter_bool_sparse_buffer_mask() { + let arr = BoolArray::from_iter([true, true, false]); + let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true])); + + let filtered = arr.filter(mask).unwrap(); + assert_arrays_eq!(filtered, BoolArray::from_iter([true, false])); + } + + #[test] + fn filter_bool_by_buffer() { + let arr = BoolArray::from_iter([true, true, false]); + + let filtered = + filter_bitbuffer_by_mask(&arr.to_bit_buffer(), &BitBuffer::from_indices(3, [0, 2]), 2); + assert_eq!(vec![true, false], filtered.iter().collect_vec()) + } #[rstest] #[case(BoolArray::from_iter([true, false, true, true, false]))] @@ -114,4 +362,48 @@ mod test { fn test_filter_bool_conformance(#[case] array: BoolArray) { test_filter_conformance(&array.into_array()); } + + #[cfg(target_arch = "x86_64")] + #[test] + fn test_pext_fallback_matches_hardware() { + use std::arch::x86_64::_pext_u64; + + use super::pext_fallback; + + if !std::arch::is_x86_feature_detected!("bmi2") { + return; + } + let test_cases: Vec<(u64, u64)> = vec![ + (0, 0), + (u64::MAX, u64::MAX), + (u64::MAX, 0), + (0, u64::MAX), + (0xAAAA_AAAA_AAAA_AAAA, 0x5555_5555_5555_5555), + (0x5555_5555_5555_5555, 0xAAAA_AAAA_AAAA_AAAA), + (0xDEAD_BEEF_CAFE_BABE, 0xFFFF_0000_FFFF_0000), + (0x1234_5678_9ABC_DEF0, 0xF0F0_F0F0_F0F0_F0F0), + (u64::MAX, 1), + (u64::MAX, 1u64 << 63), + (0x8000_0000_0000_0001, 0x8000_0000_0000_0001), + ]; + for (src, mask) in test_cases { + let hw = unsafe { _pext_u64(src, mask) }; + let sw = pext_fallback(src, mask); + assert_eq!(hw, sw, "mismatch for src={src:#018x} mask={mask:#018x}"); + } + let mut rng = 0xDEAD_BEEF_u64; + for _ in 0..1000 { + rng ^= rng << 13; + rng ^= rng >> 7; + rng ^= rng << 17; + let src = rng; + rng ^= rng << 13; + rng ^= rng >> 7; + rng ^= rng << 17; + let mask = rng; + let hw = unsafe { _pext_u64(src, mask) }; + let sw = pext_fallback(src, mask); + assert_eq!(hw, sw, "mismatch for src={src:#018x} mask={mask:#018x}"); + } + } } diff --git a/vortex-array/src/arrays/filter/execute/bitbuffer.rs b/vortex-array/src/arrays/filter/execute/bitbuffer.rs index 1ad07801967..bb7861d21ce 100644 --- a/vortex-array/src/arrays/filter/execute/bitbuffer.rs +++ b/vortex-array/src/arrays/filter/execute/bitbuffer.rs @@ -4,10 +4,10 @@ //! [`BitBuffer`] filtering algorithms. use vortex_buffer::BitBuffer; -use vortex_buffer::BitBufferMut; -use vortex_buffer::get_bit; use vortex_mask::MaskValues; +use crate::arrays::bool::compute::filter::filter_bitbuffer_by_mask; + /// Filter a [`BitBuffer`] by [`MaskValues`], returning a new [`BitBuffer`]. pub(super) fn filter_bit_buffer(bb: &BitBuffer, mask: &MaskValues) -> BitBuffer { assert_eq!( @@ -16,51 +16,22 @@ pub(super) fn filter_bit_buffer(bb: &BitBuffer, mask: &MaskValues) -> BitBuffer "Selection mask length must equal the mask length" ); - // BitBuffer filtering always uses indices for simplicity. - filter_bitbuffer_by_indices(bb, mask.indices()) -} - -fn filter_bitbuffer_by_indices(bb: &BitBuffer, indices: &[usize]) -> BitBuffer { - let bools = bb.inner().as_ref(); - let bit_offset = bb.offset(); - - // FIXME(ngates): this is slower than it could be! - BitBufferMut::collect_bool(indices.len(), |idx| { - let idx = *unsafe { indices.get_unchecked(idx) }; - get_bit(bools, bit_offset + idx) // Panics if out of bounds. - }) - .freeze() -} - -#[expect(unused)] -fn filter_bitbuffer_by_slices(bb: &BitBuffer, slices: &[(usize, usize)]) -> BitBuffer { - let bools = bb.inner().as_ref(); - let bit_offset = bb.offset(); - let output_len: usize = slices.iter().map(|(start, end)| end - start).sum(); - - let mut out = BitBufferMut::with_capacity(output_len); - - // FIXME(ngates): this is slower than it could be! - for &(start, end) in slices { - for idx in start..end { - out.append(get_bit(bools, bit_offset + idx)); // Panics if out of bounds. - } - } - - out.freeze() + filter_bitbuffer_by_mask(bb, mask.bit_buffer(), mask.true_count()) } #[cfg(test)] mod tests { use vortex_buffer::bitbuffer; + use vortex_mask::Mask; - use crate::arrays::filter::execute::bitbuffer::filter_bitbuffer_by_indices; + use super::filter_bit_buffer; #[test] - fn filter_bool_by_index_test() { + fn filter_bool_by_mask_test() { let buf = bitbuffer![1 1 0]; - let indices = [0usize, 2]; - let filtered = filter_bitbuffer_by_indices(&buf, &indices); + let mask = Mask::from_iter([true, false, true]); + let mask_values = mask.values().unwrap(); + let filtered = filter_bit_buffer(&buf, mask_values); assert_eq!(2, filtered.len()); assert_eq!(filtered, bitbuffer![1 0]) } diff --git a/vortex-array/src/arrays/filter/execute/byte_compress.rs b/vortex-array/src/arrays/filter/execute/byte_compress.rs new file mode 100644 index 00000000000..6c58ce4f33b --- /dev/null +++ b/vortex-array/src/arrays/filter/execute/byte_compress.rs @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Byte-level compress for primitive filtering using a `1 << 8 = 256`-entry lookup table. +//! +//! For each byte of the mask (8 bits -> 8 source elements), a precomputed +//! permutation table compacts the selected bytes in a single indexed copy, +//! avoiding the overhead of materializing indices or slices. + +use std::mem::size_of; + +use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; +use vortex_mask::MaskValues; + +const BYTE_COMPRESS_DENSITY_THRESHOLD: f64 = 0.5; + +/// For each mask byte (0..256), stores the element indices to keep and the count. +/// +/// `BYTE_COMPRESS_LUT[mask_byte]` = `([i0, i1, ..., i7], popcount)` where +/// `i0..i_{popcount-1}` are the positions of set bits in `mask_byte`. +/// +/// Total size: 256 * 9 = 2304 bytes, which trivially fits in L1 cache. +static BYTE_COMPRESS_LUT: &[([u8; 8], u8); 256] = &{ + let mut lut = [([0u8; 8], 0u8); 256]; + let mut mask: usize = 0; + while mask < 256 { + let mut indices = [0u8; 8]; + let mut count: u8 = 0; + let mut bit: u8 = 0; + while bit < 8 { + if mask & (1 << bit) != 0 { + indices[count as usize] = bit; + count += 1; + } + bit += 1; + } + lut[mask] = (indices, count); + mask += 1; + } + lut +}; + +/// Filter a `Buffer` using the byte-compress LUT. +/// +/// Processes the mask one byte at a time (8 source elements per byte), +/// using a precomputed permutation to compact selected elements. +pub(super) fn filter_buffer(buffer: Buffer, mask: &MaskValues) -> Buffer { + debug_assert_eq!(buffer.len(), mask.len()); + + let src = buffer.as_slice(); + let true_count = mask.true_count(); + + if true_count == 0 { + return Buffer::empty(); + } + + let mask_buffer = mask.bit_buffer(); + let mask_bytes = mask_buffer.inner().as_ref(); + let mask_offset = mask_buffer.offset(); + + // Fast path: byte-wide values benefit from avoiding index materialization more often. Wider + // values need enough selected values to justify scanning every mask byte directly. + if size_of::() == 1 || mask.density() >= BYTE_COMPRESS_DENSITY_THRESHOLD { + return filter_bitpacked(src, mask_bytes, mask_offset, true_count); + } + + // Slow path: lower-density wide values are better handled by the generic path. + super::slice::filter_slice_by_mask_values(src, mask) +} + +fn filter_bitpacked( + src: &[T], + mask_bytes: &[u8], + mask_offset: usize, + true_count: usize, +) -> Buffer { + let mut out = BufferMut::::with_capacity(true_count); + let mut write_pos: usize = 0; + + if mask_offset == 0 { + filter_aligned_into(src, mask_bytes, &mut out, &mut write_pos); + } else { + let head_len = (8 - mask_offset).min(src.len()); + let head_mask = (mask_bytes[0] >> mask_offset) & low_bits_mask(head_len); + filter_chunk_into(&src[..head_len], head_mask, &mut out, &mut write_pos); + filter_aligned_into(&src[head_len..], &mask_bytes[1..], &mut out, &mut write_pos); + } + + debug_assert_eq!(write_pos, true_count); + // SAFETY: we wrote exactly true_count elements. + unsafe { out.set_len(true_count) }; + out.freeze() +} + +/// Aligned fast path: mask bits start at byte boundary. +fn filter_aligned_into( + src: &[T], + mask_bytes: &[u8], + out: &mut BufferMut, + write_pos: &mut usize, +) { + let full_bytes = src.len() / 8; + let remainder = src.len() % 8; + + for i in 0..full_bytes { + let m = mask_bytes[i]; + if m == 0 { + continue; + } + let chunk = &src[i * 8..i * 8 + 8]; + filter_chunk_into(chunk, m, out, write_pos); + } + + // Handle the final partial chunk. + if remainder > 0 { + let m = mask_bytes[full_bytes] & low_bits_mask(remainder); + if m != 0 { + let chunk = &src[full_bytes * 8..]; + filter_chunk_into(chunk, m, out, write_pos); + } + } +} + +fn filter_chunk_into( + chunk: &[T], + mask_byte: u8, + out: &mut BufferMut, + write_pos: &mut usize, +) { + if mask_byte == 0 { + return; + } + + let out_ptr = out.spare_capacity_mut().as_mut_ptr(); + if chunk.len() == 8 && mask_byte == 0xFF { + // All 8 selected, so bulk copy. + // SAFETY: write_pos + 8 <= capacity. + unsafe { + std::ptr::copy_nonoverlapping(chunk.as_ptr(), out_ptr.add(*write_pos).cast::(), 8); + } + *write_pos += 8; + return; + } + + let (perm, count) = &BYTE_COMPRESS_LUT[mask_byte as usize]; + let count = *count as usize; + debug_assert_eq!(mask_byte & !low_bits_mask(chunk.len()), 0); + // SAFETY: perm indices are all < chunk.len(), write_pos + count <= capacity. + unsafe { + for j in 0..count { + out_ptr + .add(*write_pos + j) + .cast::() + .write(*chunk.get_unchecked(*perm.get_unchecked(j) as usize)); + } + } + *write_pos += count; +} + +fn low_bits_mask(bits: usize) -> u8 { + debug_assert!(bits <= 8); + if bits == 8 { + u8::MAX + } else { + (1u8 << bits) - 1 + } +} + +#[cfg(test)] +#[allow(clippy::unwrap_used, clippy::cast_possible_truncation)] +mod tests { + use vortex_buffer::BitBuffer; + use vortex_buffer::buffer; + use vortex_mask::Mask; + + use super::*; + + fn mask_values(mask: &Mask) -> &MaskValues { + match mask { + Mask::Values(v) => v.as_ref(), + _ => panic!("expected Mask::Values"), + } + } + + fn offset_mask(values: [bool; N], offset: usize) -> Mask { + let bit_buffer = + BitBuffer::from_iter(std::iter::repeat_n(false, offset).chain(values.iter().copied())); + Mask::from_buffer(BitBuffer::new_with_offset( + bit_buffer.inner().clone(), + values.len(), + offset, + )) + } + + #[test] + fn test_filter_all_selected() { + let buf = buffer![1u8, 2, 3, 4, 5, 6, 7, 8, 9]; + let mask = Mask::from_iter([true, true, true, true, true, true, true, true, false]); + let result = filter_buffer(buf, mask_values(&mask)); + assert_eq!(result, buffer![1u8, 2, 3, 4, 5, 6, 7, 8]); + } + + #[test] + fn test_filter_mostly_false() { + let buf = buffer![1u8, 2, 3, 4, 5, 6, 7, 8, 9]; + let mask = Mask::from_iter([false, false, false, false, false, false, false, false, true]); + let result = filter_buffer(buf, mask_values(&mask)); + assert_eq!(result, buffer![9u8]); + } + + #[test] + fn test_filter_alternating() { + let buf = buffer![10u8, 20, 30, 40, 50, 60, 70, 80]; + let mask = Mask::from_iter([true, false, true, false, true, false, true, false]); + let result = filter_buffer(buf, mask_values(&mask)); + assert_eq!(result, buffer![10u8, 30, 50, 70]); + } + + #[test] + fn test_filter_with_remainder() { + let buf = buffer![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let mask = Mask::from_iter([ + true, false, true, false, true, false, true, false, true, true, + ]); + let result = filter_buffer(buf, mask_values(&mask)); + assert_eq!(result, buffer![1u8, 3, 5, 7, 9, 10]); + } + + #[test] + fn test_filter_large() -> vortex_error::VortexResult<()> { + let data: Vec = (0..1000).map(|i| (i % 256) as u8).collect(); + let buf = Buffer::from(BufferMut::from_iter(data.iter().copied())); + let mask = Mask::from_iter((0..1000).map(|i| i % 3 == 0)); + let result = filter_buffer(buf, mask_values(&mask)); + let expected: Vec = data.iter().copied().step_by(3).collect(); + assert_eq!(result.as_slice(), &expected[..]); + Ok(()) + } + + #[test] + fn test_filter_signed_and_wider_integers() { + let mask = Mask::from_iter([true, false, true, true, false, true, false, true, true]); + + let i8_result = filter_buffer( + buffer![-5i8, -4, -3, -2, -1, 0, 1, 2, 3], + mask_values(&mask), + ); + assert_eq!(i8_result, buffer![-5i8, -3, -2, 0, 2, 3]); + + let u16_result = filter_buffer( + buffer![10u16, 20, 30, 40, 50, 60, 70, 80, 90], + mask_values(&mask), + ); + assert_eq!(u16_result, buffer![10u16, 30, 40, 60, 80, 90]); + + let i32_result = filter_buffer( + buffer![-100i32, -50, 0, 50, 100, 150, 200, 250, 300], + mask_values(&mask), + ); + assert_eq!(i32_result, buffer![-100i32, 0, 50, 150, 250, 300]); + } + + #[test] + fn test_filter_unaligned_byte_mask() { + let mask = offset_mask( + [ + false, false, true, false, false, false, false, false, true, false, false, + ], + 3, + ); + + let result = filter_buffer( + buffer![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + mask_values(&mask), + ); + assert_eq!(result, buffer![3u8, 9]); + } + + #[test] + fn test_filter_unaligned_wide_mask() { + let mask = offset_mask( + [ + true, false, true, true, false, true, false, true, true, false, false, true, + ], + 5, + ); + + let result = filter_buffer( + buffer![10u16, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120], + mask_values(&mask), + ); + assert_eq!(result, buffer![10u16, 30, 40, 60, 80, 90, 120]); + } +} diff --git a/vortex-array/src/arrays/filter/execute/mod.rs b/vortex-array/src/arrays/filter/execute/mod.rs index ba5dffd4cd7..e1db2471780 100644 --- a/vortex-array/src/arrays/filter/execute/mod.rs +++ b/vortex-array/src/arrays/filter/execute/mod.rs @@ -31,6 +31,7 @@ use crate::validity::Validity; mod bitbuffer; mod bool; mod buffer; +pub(crate) mod byte_compress; mod decimal; mod fixed_size_list; mod listview; diff --git a/vortex-array/src/arrays/filter/execute/primitive.rs b/vortex-array/src/arrays/filter/execute/primitive.rs index 0aae5738da7..276e08b5526 100644 --- a/vortex-array/src/arrays/filter/execute/primitive.rs +++ b/vortex-array/src/arrays/filter/execute/primitive.rs @@ -8,8 +8,12 @@ use vortex_mask::MaskValues; use crate::arrays::PrimitiveArray; use crate::arrays::filter::execute::buffer; +use crate::arrays::filter::execute::byte_compress; use crate::arrays::filter::execute::filter_validity; +use crate::dtype::NativePType; +use crate::dtype::PType; use crate::match_each_native_ptype; +use crate::validity::Validity; pub fn filter_primitive(array: &PrimitiveArray, mask: &Arc) -> PrimitiveArray { let validity = array @@ -17,13 +21,34 @@ pub fn filter_primitive(array: &PrimitiveArray, mask: &Arc) -> Primi .vortex_expect("primitive validity should be derivable"); let filtered_validity = filter_validity(validity, mask); - match_each_native_ptype!(array.ptype(), |T| { - let filtered_buffer = buffer::filter_buffer(array.to_buffer::(), mask.as_ref()); + match array.ptype() { + // Byte-compress avoids materializing indices/slices and processes 8 elements per mask byte. + PType::U8 => filter_byte_compress::(array, filtered_validity, mask), + PType::I8 => filter_byte_compress::(array, filtered_validity, mask), + PType::U16 => filter_byte_compress::(array, filtered_validity, mask), + PType::I16 => filter_byte_compress::(array, filtered_validity, mask), + PType::U32 => filter_byte_compress::(array, filtered_validity, mask), + PType::I32 => filter_byte_compress::(array, filtered_validity, mask), + _ => match_each_native_ptype!(array.ptype(), |T| { + let filtered_buffer = buffer::filter_buffer(array.to_buffer::(), mask.as_ref()); - // SAFETY: We filter both the validity and the buffer with the same mask, so they must have - // the same length. - unsafe { PrimitiveArray::new_unchecked(filtered_buffer, filtered_validity) } - }) + // SAFETY: We filter both the validity and the buffer with the same mask, so they must + // have the same length. + unsafe { PrimitiveArray::new_unchecked(filtered_buffer, filtered_validity) } + }), + } +} + +fn filter_byte_compress( + array: &PrimitiveArray, + filtered_validity: Validity, + mask: &Arc, +) -> PrimitiveArray { + let filtered_buffer = byte_compress::filter_buffer(array.to_buffer::(), mask.as_ref()); + + // SAFETY: We filter both the validity and the buffer with the same mask, so they must have the + // same length. + unsafe { PrimitiveArray::new_unchecked(filtered_buffer, filtered_validity) } } #[cfg(test)] @@ -65,6 +90,9 @@ mod test { } #[rstest] + #[case(PrimitiveArray::from_iter([-2i8, -1, 0, 1, 2]))] + #[case(PrimitiveArray::from_iter([1u16, 2, 3, 4, 5]))] + #[case(PrimitiveArray::from_iter([-2i16, -1, 0, 1, 2]))] #[case(PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]))] #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]))] #[case(PrimitiveArray::from_iter([42u64]))] diff --git a/vortex-array/src/arrays/filter/rules.rs b/vortex-array/src/arrays/filter/rules.rs index 45e938d7768..ffa5c64bd61 100644 --- a/vortex-array/src/arrays/filter/rules.rs +++ b/vortex-array/src/arrays/filter/rules.rs @@ -12,34 +12,23 @@ use crate::arrays::Filter; use crate::arrays::Struct; use crate::arrays::StructArray; use crate::arrays::filter::FilterArrayExt; +use crate::arrays::filter::FilterReduce; +use crate::arrays::filter::FilterReduceAdaptor; use crate::arrays::struct_::StructDataParts; -use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ArrayReduceRule; use crate::optimizer::rules::ParentRuleSet; use crate::optimizer::rules::ReduceRuleSet; pub(super) const PARENT_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&FilterFilterRule)]); + ParentRuleSet::new(&[ParentRuleSet::lift(&FilterReduceAdaptor(Filter))]); pub(super) const RULES: ReduceRuleSet = ReduceRuleSet::new(&[&TrivialFilterRule, &FilterStructRule]); -/// A simple redecution rule that simplifies a [`FilterArray`] whose child is also a -/// [`FilterArray`]. -#[derive(Debug)] -struct FilterFilterRule; - -impl ArrayParentReduceRule for FilterFilterRule { - type Parent = Filter; - - fn reduce_parent( - &self, - child: ArrayView<'_, Filter>, - parent: ArrayView<'_, Filter>, - _child_idx: usize, - ) -> VortexResult> { - let combined_mask = child.mask.intersect_by_rank(&parent.mask); - let new_array = child.child().filter(combined_mask)?; +impl FilterReduce for Filter { + fn filter(array: ArrayView<'_, Self>, mask: &Mask) -> VortexResult> { + let combined_mask = array.mask.intersect_by_rank(mask); + let new_array = array.child().filter(combined_mask)?; Ok(Some(new_array)) } diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock index 006bdafaa64..8c8c4991042 100644 --- a/vortex-mask/public-api.lock +++ b/vortex-mask/public-api.lock @@ -172,6 +172,10 @@ impl vortex_mask::MaskValues pub fn vortex_mask::MaskValues::bit_buffer(&self) -> &vortex_buffer::bit::buf::BitBuffer +pub fn vortex_mask::MaskValues::cached_indices(&self) -> core::option::Option<&[usize]> + +pub fn vortex_mask::MaskValues::cached_slices(&self) -> core::option::Option<&[(usize, usize)]> + pub fn vortex_mask::MaskValues::density(&self) -> f64 pub fn vortex_mask::MaskValues::indices(&self) -> &[usize] diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index 30c05f8f410..41901ba2ba3 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -765,6 +765,15 @@ impl MaskValues { }) } + /// Returns cached index positions when this mask already has them materialized. + /// + /// Unlike [`Self::indices`], this does not build the index vector from another + /// representation. + #[inline] + pub fn cached_indices(&self) -> Option<&[usize]> { + self.indices.get().map(Vec::as_slice) + } + /// Constructs a slices vector from one of the other representations. #[inline] pub fn slices(&self) -> &[(usize, usize)] { @@ -777,6 +786,15 @@ impl MaskValues { }) } + /// Returns cached true-value ranges when this mask already has them materialized. + /// + /// Unlike [`Self::slices`], this does not build the slice vector from another + /// representation. + #[inline] + pub fn cached_slices(&self) -> Option<&[(usize, usize)]> { + self.slices.get().map(Vec::as_slice) + } + /// Return an iterator over either indices or slices of the mask based on a density threshold. #[inline] pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> { diff --git a/vortex-mask/src/tests.rs b/vortex-mask/src/tests.rs index d1496fcb30f..65687afef15 100644 --- a/vortex-mask/src/tests.rs +++ b/vortex-mask/src/tests.rs @@ -270,6 +270,24 @@ fn test_mask_values_threshold_iter() { } } +#[test] +fn test_mask_values_cached_representations() { + let from_buffer = Mask::from_buffer(BitBuffer::from_iter([true, false, true])); + let values = from_buffer.values().unwrap(); + assert!(values.cached_indices().is_none()); + assert!(values.cached_slices().is_none()); + + let from_indices = Mask::from_indices(5, [1, 3]); + let values = from_indices.values().unwrap(); + assert_eq!(values.cached_indices(), Some([1, 3].as_slice())); + assert!(values.cached_slices().is_none()); + + let from_slices = Mask::from_slices(6, vec![(1, 3), (5, 6)]); + let values = from_slices.values().unwrap(); + assert!(values.cached_indices().is_none()); + assert_eq!(values.cached_slices(), Some([(1, 3), (5, 6)].as_slice())); +} + #[test] fn test_mask_values_is_empty() { let empty_mask = Mask::from_buffer(BitBuffer::new_unset(0));