vqsort-inl.h
1// Copyright 2021 Google LLC
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// Normal include guard for target-independent parts
17#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
18#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
19
20#ifndef VQSORT_PRINT
21#define VQSORT_PRINT 0
22#endif
23
24// Makes it harder for adversaries to predict our sampling locations, at the
25// cost of 1-2% increased runtime.
26#ifndef VQSORT_SECURE_RNG
27#define VQSORT_SECURE_RNG 0
28#endif
29
30#if VQSORT_SECURE_RNG
31#include "third_party/absl/random/random.h"
32#endif
33
34#include <stdio.h> // unconditional #include so we can use if(VQSORT_PRINT).
35#include <string.h> // memcpy
36
37#include "hwy/cache_control.h" // Prefetch
38#include "hwy/contrib/sort/vqsort.h" // Fill24Bytes
39
40#if HWY_IS_MSAN
41#include <sanitizer/msan_interface.h>
42#endif
43
44#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
45
46// Per-target
47#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
48 defined(HWY_TARGET_TOGGLE)
49#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
50#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
51#else
52#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
53#endif
54
55#if VQSORT_PRINT
56#include "hwy/print-inl.h"
57#endif
58
60#include "hwy/contrib/sort/shared-inl.h"  // SortConstants
61#include "hwy/contrib/sort/sorting_networks-inl.h"  // SortingNetwork
62#include "hwy/highway.h"
63
64HWY_BEFORE_NAMESPACE();
65namespace hwy {
66namespace HWY_NAMESPACE {
67namespace detail {
68
69using Constants = hwy::SortConstants;
70
71// Wrappers to avoid #if in user code (interferes with code folding)
72
73HWY_INLINE void UnpoisonIfMemorySanitizer(void* p, size_t bytes) {
74#if HWY_IS_MSAN
75 __msan_unpoison(p, bytes);
76#else
77 (void)p;
78 (void)bytes;
79#endif
80}
81
82template <class D>
83HWY_INLINE void MaybePrintVector(D d, const char* label, Vec<D> v,
84 size_t start = 0, size_t max_lanes = 16) {
85#if VQSORT_PRINT >= 2 // Print is only defined #if VQSORT_PRINT
86 Print(d, label, v, start, max_lanes);
87#else
88 (void)d;
89 (void)label;
90 (void)v;
91 (void)start;
92 (void)max_lanes;
93#endif
94}
95
96// ------------------------------ HeapSort
97
98template <class Traits, typename T>
99void SiftDown(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
100 size_t start) {
101 constexpr size_t N1 = st.LanesPerKey();
102 const FixedTag<T, N1> d;
103
104 while (start < num_lanes) {
105 const size_t left = 2 * start + N1;
106 const size_t right = 2 * start + 2 * N1;
107 if (left >= num_lanes) break;
108 size_t idx_larger = start;
109 const auto key_j = st.SetKey(d, lanes + start);
110 if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, lanes + left)))) {
111 idx_larger = left;
112 }
113 if (right < num_lanes &&
114 AllTrue(d, st.Compare(d, st.SetKey(d, lanes + idx_larger),
115 st.SetKey(d, lanes + right)))) {
116 idx_larger = right;
117 }
118 if (idx_larger == start) break;
119 st.Swap(lanes + start, lanes + idx_larger);
120 start = idx_larger;
121 }
122}
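
// Example of the child-index arithmetic above: for single-lane keys (N1 = 1),
// the children of the node at lane i are 2*i+1 and 2*i+2, the usual zero-based
// binary heap. For 128-bit keys (N1 = 2), a key starting at lane i has
// children starting at lanes 2*i+2 and 2*i+4, i.e. the same heap expressed in
// key indices i/N1.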
123
124// Heapsort: O(1) space, O(N*logN) worst-case comparisons.
125// Based on LLVM sanitizer_common.h, licensed under Apache-2.0.
126template <class Traits, typename T>
127void HeapSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes) {
128 constexpr size_t N1 = st.LanesPerKey();
129
130 if (num_lanes < 2 * N1) return;
131
132 // Build heap.
133 for (size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) {
134 SiftDown(st, lanes, num_lanes, i);
135 }
136
137 for (size_t i = num_lanes - N1; i != 0; i -= N1) {
138 // Swap root with last
139 st.Swap(lanes + 0, lanes + i);
140
141 // Sift down the new root.
142 SiftDown(st, lanes, i, 0);
143 }
144}
145
146#if VQSORT_ENABLED || HWY_IDE
147
148// ------------------------------ BaseCase
149
150// Sorts `keys` within the range [0, num) via sorting network.
151template <class D, class Traits, typename T>
152HWY_INLINE void BaseCase(D d, Traits st, T* HWY_RESTRICT keys,
153 T* HWY_RESTRICT keys_end, size_t num,
154 T* HWY_RESTRICT buf) {
155 const size_t N = Lanes(d);
156 using V = decltype(Zero(d));
157
158 // _Nonzero32 requires num - 1 != 0.
159 if (HWY_UNLIKELY(num <= 1)) return;
160
161 // Reshape into a matrix with kMaxRows rows, and columns limited by the
162 // 1D `num`, which is upper-bounded by the vector width (see BaseCaseNum).
163 const size_t num_pow2 = size_t{1}
164 << (32 - Num0BitsAboveMS1Bit_Nonzero32(
165 static_cast<uint32_t>(num - 1)));
166 HWY_DASSERT(num <= num_pow2 && num_pow2 <= Constants::BaseCaseNum(N));
167 const size_t cols =
168 HWY_MAX(st.LanesPerKey(), num_pow2 >> Constants::kMaxRowsLog2);
169 HWY_DASSERT(cols <= N);
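
 // Worked example (assuming kMaxRows == 16, hence kMaxRowsLog2 == 4, as in
 // shared-inl.h): num = 100 single-lane keys gives num_pow2 = 128 and
 // cols = HWY_MAX(1, 128 >> 4) = 8, so the network sorts a 16 x 8 matrix of
 // 128 slots; slots [100, 128) plus one extra vector are padded below.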
170
171 // We can avoid padding and load/store directly to `keys` after checking the
172 // original input array has enough space. Except at the right border, it's OK
173 // to sort more than the current sub-array. Even if we sort across a previous
174 // partition point, we know that keys will not migrate across it. However, we
175 // must use the maximum size of the sorting network, because the StoreU of its
176 // last vector would otherwise write invalid data starting at kMaxRows * cols.
177 const size_t N_sn = Lanes(CappedTag<T, Constants::kMaxCols>());
178 if (HWY_LIKELY(keys + N_sn * Constants::kMaxRows <= keys_end)) {
179 SortingNetwork(st, keys, N_sn);
180 return;
181 }
182
183 // Copy `keys` to `buf`.
184 size_t i;
185 for (i = 0; i + N <= num; i += N) {
186 Store(LoadU(d, keys + i), d, buf + i);
187 }
188 SafeCopyN(num - i, d, keys + i, buf + i);
189 i = num;
190
191 // Fill with padding - last in sort order, not copied to keys.
192 const V kPadding = st.LastValue(d);
193 // Initialize an extra vector because SortingNetwork loads full vectors,
194 // which may exceed cols*kMaxRows.
195 for (; i < (cols * Constants::kMaxRows + N); i += N) {
196 StoreU(kPadding, d, buf + i);
197 }
198
199 SortingNetwork(st, buf, cols);
200
201 for (i = 0; i + N <= num; i += N) {
202 StoreU(Load(d, buf + i), d, keys + i);
203 }
204 SafeCopyN(num - i, d, buf + i, keys + i);
205}
206
207// ------------------------------ Partition
208
209// Consumes from `keys` until a multiple of kUnroll*N remains.
210// Temporarily stores the right side into `buf`, then moves behind `num`.
211// Returns the number of keys consumed from the left side.
212template <class D, class Traits, class T>
213HWY_INLINE size_t PartitionToMultipleOfUnroll(D d, Traits st,
214 T* HWY_RESTRICT keys, size_t& num,
215 const Vec<D> pivot,
216 T* HWY_RESTRICT buf) {
217 constexpr size_t kUnroll = Constants::kPartitionUnroll;
218 const size_t N = Lanes(d);
219 size_t readL = 0;
220 T* HWY_RESTRICT posL = keys;
221 size_t bufR = 0;
222 // Partition requires both a multiple of kUnroll*N and at least
223 // 2*kUnroll*N for the initial loads. If less, consume all here.
224 const size_t num_rem =
225 (num < 2 * kUnroll * N) ? num : (num & (kUnroll * N - 1));
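
 // Example (assuming kPartitionUnroll == 4): with N = 8 lanes, kUnroll * N is
 // 32. For num = 70 >= 64, num_rem = 70 & 31 = 6, so 6 keys are consumed here
 // and the remaining 64 keys are a multiple of 32 and at least 2 * 32. For
 // num = 50 < 64, num_rem = 50 and this function consumes everything.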
226 size_t i = 0;
227 for (; i + N <= num_rem; i += N) {
228 const Vec<D> vL = LoadU(d, keys + readL);
229 readL += N;
230
231 const auto comp = st.Compare(d, pivot, vL);
232 posL += CompressBlendedStore(vL, Not(comp), d, posL);
233 bufR += CompressStore(vL, comp, d, buf + bufR);
234 }
235 // Last iteration: only use valid lanes.
236 if (HWY_LIKELY(i != num_rem)) {
237 const auto mask = FirstN(d, num_rem - i);
238 const Vec<D> vL = LoadU(d, keys + readL);
239
240 const auto comp = st.Compare(d, pivot, vL);
241 posL += CompressBlendedStore(vL, AndNot(comp, mask), d, posL);
242 bufR += CompressStore(vL, And(comp, mask), d, buf + bufR);
243 }
244
245 // MSAN seems not to understand CompressStore. buf[0, bufR) are valid.
246 UnpoisonIfMemorySanitizer(buf, bufR * sizeof(T));
247
248 // Everything we loaded was put into buf, or behind the current `posL`, after
249 // which there is space for bufR items. First move items from `keys + num` to
250 // `posL` to free up space, then copy `buf` into the vacated `keys + num`.
251 // A loop with masked loads from `buf` is insufficient - we would also need to
252 // mask from `keys + num`. Combining a loop with memcpy for the remainders is
253 // slower than just memcpy, so we use that for simplicity.
254 num -= bufR;
255 memcpy(posL, keys + num, bufR * sizeof(T));
256 memcpy(keys + num, buf, bufR * sizeof(T));
257 return static_cast<size_t>(posL - keys); // caller will shrink num by this.
258}
259
260template <class V>
261V OrXor(const V o, const V x1, const V x2) {
262 // TODO(janwas): add op so we can benefit from AVX-512 ternlog?
263 return Or(o, Xor(x1, x2));
264}
265
266// Note: we could track the OrXor of v and pivot to see if the entire left
267// partition is equal, but that happens rarely and thus is a net loss.
268template <class D, class Traits, typename T>
269HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec<D> v,
270 const Vec<D> pivot, T* HWY_RESTRICT keys,
271 size_t& writeL, size_t& remaining) {
272 const size_t N = Lanes(d);
273
274 const auto comp = st.Compare(d, pivot, v);
275
276 remaining -= N;
277 if (CompressIsPartition<T>::value ||
278 (HWY_MAX_BYTES == 16 && st.Is128())) {
279 // Non-native Compress (e.g. AVX2): we are able to partition a vector using
280 // a single Compress+two StoreU instead of two Compress[Blended]Store. The
281 // latter are more expensive. Because we store entire vectors, the contents
282 // between the updated writeL and writeR are ignored and will be overwritten
283 // by subsequent calls. This works because writeL and writeR are at least
284 // two vectors apart.
285 const auto lr = st.CompressKeys(v, comp);
286 const size_t num_left = N - CountTrue(d, comp);
287 StoreU(lr, d, keys + writeL);
288 // Now write the right-side elements (if any), such that the previous writeR
289 // is one past the end of the newly written right elements, then advance.
290 StoreU(lr, d, keys + remaining + writeL);
291 writeL += num_left;
292 } else {
293 // Native Compress[Store] (e.g. AVX3), which only keep the left or right
294 // side, not both, hence we require two calls.
295 const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL);
296 writeL += num_left;
297
298 (void)CompressBlendedStore(v, comp, d, keys + remaining + writeL);
299 }
300}
301
302template <class D, class Traits, typename T>
303HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec<D> v0,
304 const Vec<D> v1, const Vec<D> v2,
305 const Vec<D> v3, const Vec<D> pivot,
306 T* HWY_RESTRICT keys, size_t& writeL,
307 size_t& remaining) {
308 StoreLeftRight(d, st, v0, pivot, keys, writeL, remaining);
309 StoreLeftRight(d, st, v1, pivot, keys, writeL, remaining);
310 StoreLeftRight(d, st, v2, pivot, keys, writeL, remaining);
311 StoreLeftRight(d, st, v3, pivot, keys, writeL, remaining);
312}
313
314// Moves "<= pivot" keys to the front, and others to the back. pivot is
315// broadcasted. Time-critical!
316//
317// Aligned loads do not seem to be worthwhile (not bottlenecked by load ports).
318template <class D, class Traits, typename T>
319HWY_INLINE size_t Partition(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
320 const Vec<D> pivot, T* HWY_RESTRICT buf) {
321 using V = decltype(Zero(d));
322 const size_t N = Lanes(d);
323
324 // StoreLeftRight will CompressBlendedStore ending at `writeR`. Unless all
325 // lanes happen to be in the right-side partition, this will overrun `keys`,
326 // which triggers asan errors. Avoid by special-casing the last vector.
327 HWY_DASSERT(num > 2 * N); // ensured by HandleSpecialCases
328 num -= N;
329 size_t last = num;
330 const V vlast = LoadU(d, keys + last);
331
332 const size_t consumedL =
333 PartitionToMultipleOfUnroll(d, st, keys, num, pivot, buf);
334 keys += consumedL;
335 last -= consumedL;
336 num -= consumedL;
337 constexpr size_t kUnroll = Constants::kPartitionUnroll;
338
339 // Partition splits the vector into 3 sections, left to right: Elements
340 // smaller or equal to the pivot, unpartitioned elements and elements larger
341 // than the pivot. To write elements unconditionally on the loop body without
342 // overwriting existing data, we maintain two regions of the loop where all
344 // elements have been copied elsewhere (e.g. vector registers). I call these
344 // bufferL and bufferR, for left and right respectively.
345 //
346 // These regions are tracked by the indices (writeL, writeR, left, right) as
347 // presented in the diagram below.
348 //
349 // writeL writeR
350 // \/ \/
351 // | <= pivot | bufferL | unpartitioned | bufferR | > pivot |
352 // \/ \/
353 // left right
354 //
355 // In the main loop body below we choose a side, load some elements out of the
356 // vector and move either `left` or `right`. Next we call into StoreLeftRight
357 // to partition the data, and the partitioned elements will be written either
358 // to writeR or writeL and the corresponding index will be moved accordingly.
359 //
360 // Note that writeR is not explicitly tracked as an optimization for platforms
361 // with conditional operations. Instead we track writeL and the number of
362 // elements left to process (`remaining`). From the diagram above we can see
363 // that:
364 // writeR - writeL = remaining => writeR = remaining + writeL
365 //
366 // Tracking `remaining` is advantageous because each iteration reduces the
367 // number of unpartitioned elements by a fixed amount, so we can compute
368 // `remaining` without data dependencies.
369 //
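 // Illustrative step (hypothetical numbers): with N = 8, a vector containing
 // 5 keys <= pivot and 3 keys > pivot causes StoreLeftRight to advance writeL
 // by 5 and decrease `remaining` by 8, so writeR = writeL + remaining drops by
 // exactly the 3 right-side keys written, preserving the invariant above.
 //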
370 size_t writeL = 0;
371 size_t remaining = num;
372
373 const T* HWY_RESTRICT readL = keys;
374 const T* HWY_RESTRICT readR = keys + num;
375 // Cannot load if there were fewer than 2 * kUnroll * N.
376 if (HWY_LIKELY(num != 0)) {
377 HWY_DASSERT(num >= 2 * kUnroll * N);
378 HWY_DASSERT((num & (kUnroll * N - 1)) == 0);
379
380 // Make space for writing in-place by reading from readL/readR.
381 const V vL0 = LoadU(d, readL + 0 * N);
382 const V vL1 = LoadU(d, readL + 1 * N);
383 const V vL2 = LoadU(d, readL + 2 * N);
384 const V vL3 = LoadU(d, readL + 3 * N);
385 readL += kUnroll * N;
386 readR -= kUnroll * N;
387 const V vR0 = LoadU(d, readR + 0 * N);
388 const V vR1 = LoadU(d, readR + 1 * N);
389 const V vR2 = LoadU(d, readR + 2 * N);
390 const V vR3 = LoadU(d, readR + 3 * N);
391
392 // readL/readR changed above, so check again before the loop.
393 while (readL != readR) {
394 V v0, v1, v2, v3;
395
396 // Data-dependent but branching is faster than forcing branch-free.
397 const size_t capacityL =
398 static_cast<size_t>((readL - keys) - static_cast<ptrdiff_t>(writeL));
399 HWY_DASSERT(capacityL <= num); // >= 0
400 // Load data from the end of the vector with less data (front or back).
401 // The next paragraphs explain how this works.
402 //
403 // let block_size = (kUnroll * N)
404 // On the loop prelude we load block_size elements from the front of the
405 // vector and an additional block_size elements from the back. On each
406 // iteration k elements are written to the front of the vector and
407 // (block_size - k) to the back.
408 //
409 // This creates a loop invariant where the capacity on the front
410 // (capacityL) and on the back (capacityR) always add to 2 * block_size.
411 // In other words:
412 // capacityL + capacityR = 2 * block_size
413 // capacityR = 2 * block_size - capacityL
414 //
415 // This means that:
416 // capacityL < capacityR <=>
417 // capacityL < 2 * block_size - capacityL <=>
418 // 2 * capacityL < 2 * block_size <=>
419 // capacityL < block_size
420 //
421 // Thus the check on the next line is equivalent to capacityL > capacityR.
422 //
423 if (kUnroll * N < capacityL) {
424 readR -= kUnroll * N;
425 v0 = LoadU(d, readR + 0 * N);
426 v1 = LoadU(d, readR + 1 * N);
427 v2 = LoadU(d, readR + 2 * N);
428 v3 = LoadU(d, readR + 3 * N);
429 hwy::Prefetch(readR - 3 * kUnroll * N);
430 } else {
431 v0 = LoadU(d, readL + 0 * N);
432 v1 = LoadU(d, readL + 1 * N);
433 v2 = LoadU(d, readL + 2 * N);
434 v3 = LoadU(d, readL + 3 * N);
435 readL += kUnroll * N;
436 hwy::Prefetch(readL + 3 * kUnroll * N);
437 }
438
439 StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, remaining);
440 }
441
442 // Now finish writing the saved vectors to the middle.
443 StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, remaining);
444 StoreLeftRight4(d, st, vR0, vR1, vR2, vR3, pivot, keys, writeL, remaining);
445 }
446
447 // We have partitioned [left, right) such that writeL is the boundary.
448 HWY_DASSERT(remaining == 0);
449 // Make space for inserting vlast: move up to N of the first right-side keys
450 // into the unused space starting at last. If we have fewer, ensure they are
451 // the last items in that vector by subtracting from the *load* address,
452 // which is safe because we have at least two vectors (checked above).
453 const size_t totalR = last - writeL;
454 const size_t startR = totalR < N ? writeL + totalR - N : writeL;
455 StoreU(LoadU(d, keys + startR), d, keys + last);
456
457 // Partition vlast: write L, then R, into the single-vector gap at writeL.
458 const auto comp = st.Compare(d, pivot, vlast);
459 writeL += CompressBlendedStore(vlast, Not(comp), d, keys + writeL);
460 (void)CompressBlendedStore(vlast, comp, d, keys + writeL);
461
462 return consumedL + writeL;
463}
464
465// Returns true and partitions if [keys, keys + num) contains only {valueL,
466// valueR}. Otherwise, sets third to the first differing value; keys may have
467// been reordered and a regular Partition is still necessary.
468// Called from two locations, hence NOINLINE.
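//
// Illustrative walk-through (hypothetical input): with keys containing only
// two values A < B (valueL = A, valueR = B), each whole vector is counted via
// eqL and the scanned prefix is overwritten with A, so writeL ends at the
// total number of A keys and [writeL, num) is then filled with B. If a third
// value is encountered, the prefix already overwritten with A is repaired by
// storing B over [writeL, i) and false is returned, so the caller falls back
// to a regular Partition.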
469template <class D, class Traits, typename T>
470HWY_NOINLINE bool MaybePartitionTwoValue(D d, Traits st, T* HWY_RESTRICT keys,
471 size_t num, const Vec<D> valueL,
472 const Vec<D> valueR, Vec<D>& third,
473 T* HWY_RESTRICT buf) {
474 const size_t N = Lanes(d);
475
476 size_t i = 0;
477 size_t writeL = 0;
478
479 // As long as all lanes are equal to L or R, we can overwrite with valueL.
480 // This is faster than first counting, then backtracking to fill L and R.
481 for (; i + N <= num; i += N) {
482 const Vec<D> v = LoadU(d, keys + i);
483 // It is not clear how to apply OrXor here - that can check if *both*
484 // comparisons are true, but here we want *either*. Comparing the unsigned
485 // min of differences to zero works, but is expensive for u64 prior to AVX3.
486 const Mask<D> eqL = st.EqualKeys(d, v, valueL);
487 const Mask<D> eqR = st.EqualKeys(d, v, valueR);
488 // At least one other value present; will require a regular partition.
489 // On AVX-512, Or + AllTrue are folded into a single kortest if we are
490 // careful with the FindKnownFirstTrue argument, see below.
491 if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) {
492 // If we repeat Or(eqL, eqR) here, the compiler will hoist it into the
493 // loop, which is a pessimization because this if-true branch is cold.
494 // We can defeat this via Not(Xor), which is equivalent because eqL and
495 // eqR cannot be true at the same time. Can we elide the additional Not?
496 // FindFirstFalse instructions are generally unavailable, but we can
497 // fuse Not and Xor/Or into one ExclusiveNeither.
498 const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR));
499 third = st.SetKey(d, keys + i + lane);
500 if (VQSORT_PRINT >= 2) {
501 fprintf(stderr, "found 3rd value at vec %zu; writeL %zu\n", i, writeL);
502 }
503 // 'Undo' what we did by filling the remainder of what we read with R.
504 for (; writeL + N <= i; writeL += N) {
505 StoreU(valueR, d, keys + writeL);
506 }
507 BlendedStore(valueR, FirstN(d, i - writeL), d, keys + writeL);
508 return false;
509 }
510 StoreU(valueL, d, keys + writeL);
511 writeL += CountTrue(d, eqL);
512 }
513
514 // Final vector, masked comparison (no effect if i == num)
515 const size_t remaining = num - i;
516 SafeCopyN(remaining, d, keys + i, buf);
517 const Vec<D> v = Load(d, buf);
518 const Mask<D> valid = FirstN(d, remaining);
519 const Mask<D> eqL = And(st.EqualKeys(d, v, valueL), valid);
520 const Mask<D> eqR = st.EqualKeys(d, v, valueR);
521 // Invalid lanes are considered equal.
522 const Mask<D> eq = Or(Or(eqL, eqR), Not(valid));
523 // At least one other value present; will require a regular partition.
524 if (HWY_UNLIKELY(!AllTrue(d, eq))) {
525 const size_t lane = FindKnownFirstTrue(d, Not(eq));
526 third = st.SetKey(d, keys + i + lane);
527 if (VQSORT_PRINT >= 2) {
528 fprintf(stderr, "found 3rd value at partial vec %zu; writeL %zu\n", i,
529 writeL);
530 }
531 // 'Undo' what we did by filling the remainder of what we read with R.
532 for (; writeL + N <= i; writeL += N) {
533 StoreU(valueR, d, keys + writeL);
534 }
535 BlendedStore(valueR, FirstN(d, i - writeL), d, keys + writeL);
536 return false;
537 }
538 BlendedStore(valueL, valid, d, keys + writeL);
539 writeL += CountTrue(d, eqL);
540
541 // Fill right side
542 i = writeL;
543 for (; i + N <= num; i += N) {
544 StoreU(valueR, d, keys + i);
545 }
546 BlendedStore(valueR, FirstN(d, num - i), d, keys + i);
547
548 if (VQSORT_PRINT >= 2) {
549 fprintf(stderr, "Successful MaybePartitionTwoValue\n");
550 }
551 return true;
552}
553
554// Same as above, except that the pivot equals valueR, so scan right to left.
555template <class D, class Traits, typename T>
556HWY_INLINE bool MaybePartitionTwoValueR(D d, Traits st, T* HWY_RESTRICT keys,
557 size_t num, const Vec<D> valueL,
558 const Vec<D> valueR, Vec<D>& third,
559 T* HWY_RESTRICT buf) {
560 const size_t N = Lanes(d);
561
562 HWY_DASSERT(num >= N);
563 size_t pos = num - N; // current read/write position
564 size_t countR = 0; // number of valueR found
565
566 // For whole vectors, in descending address order: as long as all lanes are
567 // equal to L or R, overwrite with valueR. This is faster than counting, then
568 // filling both L and R. Loop terminates after unsigned wraparound.
569 for (; pos < num; pos -= N) {
570 const Vec<D> v = LoadU(d, keys + pos);
571 // It is not clear how to apply OrXor here - that can check if *both*
572 // comparisons are true, but here we want *either*. Comparing the unsigned
573 // min of differences to zero works, but is expensive for u64 prior to AVX3.
574 const Mask<D> eqL = st.EqualKeys(d, v, valueL);
575 const Mask<D> eqR = st.EqualKeys(d, v, valueR);
576 // If there is a third value, stop and undo what we've done. On AVX-512,
577 // Or + AllTrue are folded into a single kortest, but only if we are
578 // careful with the FindKnownFirstTrue argument - see prior comment on that.
579 if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) {
580 const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR));
581 third = st.SetKey(d, keys + pos + lane);
582 if (VQSORT_PRINT >= 2) {
583 fprintf(stderr, "found 3rd value at vec %zu; countR %zu\n", pos,
584 countR);
585 MaybePrintVector(d, "third", third, 0, st.LanesPerKey());
586 }
587 pos += N; // rewind: we haven't yet committed changes in this iteration.
588 // We have filled [pos, num) with R, but only countR of them should have
589 // been written. Rewrite [pos, num - countR) to L.
590 HWY_DASSERT(countR <= num - pos);
591 const size_t endL = num - countR;
592 for (; pos + N <= endL; pos += N) {
593 StoreU(valueL, d, keys + pos);
594 }
595 BlendedStore(valueL, FirstN(d, endL - pos), d, keys + pos);
596 return false;
597 }
598 StoreU(valueR, d, keys + pos);
599 countR += CountTrue(d, eqR);
600 }
601
602 // Final partial (or empty) vector, masked comparison.
603 const size_t remaining = pos + N;
604 HWY_DASSERT(remaining <= N);
605 const Vec<D> v = LoadU(d, keys); // Safe because num >= N.
606 const Mask<D> valid = FirstN(d, remaining);
607 const Mask<D> eqL = st.EqualKeys(d, v, valueL);
608 const Mask<D> eqR = And(st.EqualKeys(d, v, valueR), valid);
609 // Invalid lanes are considered equal.
610 const Mask<D> eq = Or(Or(eqL, eqR), Not(valid));
611 // At least one other value present; will require a regular partition.
612 if (HWY_UNLIKELY(!AllTrue(d, eq))) {
613 const size_t lane = FindKnownFirstTrue(d, Not(eq));
614 third = st.SetKey(d, keys + lane);
615 if (VQSORT_PRINT >= 2) {
616 fprintf(stderr, "found 3rd value at partial vec %zu; writeR %zu\n", pos,
617 countR);
618 MaybePrintVector(d, "third", third, 0, st.LanesPerKey());
619 }
620 pos += N; // rewind: we haven't yet committed changes in this iteration.
621 // We have filled [pos, num) with R, but only countR of them should have
622 // been written. Rewrite [pos, num - countR) to L.
623 HWY_DASSERT(countR <= num - pos);
624 const size_t endL = num - countR;
625 for (; pos + N <= endL; pos += N) {
626 StoreU(valueL, d, keys + pos);
627 }
628 BlendedStore(valueL, FirstN(d, endL - pos), d, keys + pos);
629 return false;
630 }
631 const size_t lastR = CountTrue(d, eqR);
632 countR += lastR;
633
634 // First finish writing valueR - [0, N) lanes were not yet written.
635 StoreU(valueR, d, keys); // Safe because num >= N.
636
637 // Fill left side (ascending order for clarity)
638 const size_t endL = num - countR;
639 size_t i = 0;
640 for (; i + N <= endL; i += N) {
641 StoreU(valueL, d, keys + i);
642 }
643 Store(valueL, d, buf);
644 SafeCopyN(endL - i, d, buf, keys + i); // avoids asan overrun
645
646 if (VQSORT_PRINT >= 2) {
647 fprintf(stderr,
648 "MaybePartitionTwoValueR countR %zu pos %zu i %zu endL %zu\n",
649 countR, pos, i, endL);
650 }
651
652 return true;
653}
654
655// `idx_second` is `first_mismatch` from `AllEqual` and thus the index of the
656// second key. This is the first path into `MaybePartitionTwoValue`, called
657// when all samples are equal. Returns false if there is at least a third
658// value, and sets `third`. Otherwise, partitions the array and returns true.
659template <class D, class Traits, typename T>
660HWY_INLINE bool PartitionIfTwoKeys(D d, Traits st, const Vec<D> pivot,
661 T* HWY_RESTRICT keys, size_t num,
662 const size_t idx_second, const Vec<D> second,
663 Vec<D>& third, T* HWY_RESTRICT buf) {
664 // True if second comes before pivot.
665 const bool is_pivotR = AllFalse(d, st.Compare(d, pivot, second));
666 if (VQSORT_PRINT >= 1) {
667 fprintf(stderr, "Samples all equal, diff at %zu, isPivotR %d\n", idx_second,
668 is_pivotR);
669 }
670 HWY_DASSERT(AllFalse(d, st.EqualKeys(d, second, pivot)));
671
672 // If pivot is R, we scan backwards over the entire array. Otherwise,
673 // we already scanned up to idx_second and can leave those in place.
674 return is_pivotR ? MaybePartitionTwoValueR(d, st, keys, num, second, pivot,
675 third, buf)
676 : MaybePartitionTwoValue(d, st, keys + idx_second,
677 num - idx_second, pivot, second,
678 third, buf);
679}
680
681// Second path into `MaybePartitionTwoValue`, called when not all samples are
682// equal. `samples` is sorted.
683template <class D, class Traits, typename T>
684HWY_INLINE bool PartitionIfTwoSamples(D d, Traits st, T* HWY_RESTRICT keys,
685 size_t num, T* HWY_RESTRICT samples) {
686 constexpr size_t kSampleLanes = 3 * 64 / sizeof(T);
687 constexpr size_t N1 = st.LanesPerKey();
688 const Vec<D> valueL = st.SetKey(d, samples);
689 const Vec<D> valueR = st.SetKey(d, samples + kSampleLanes - N1);
690 HWY_DASSERT(AllTrue(d, st.Compare(d, valueL, valueR)));
691 HWY_DASSERT(AllFalse(d, st.EqualKeys(d, valueL, valueR)));
692 const Vec<D> prev = st.PrevValue(d, valueR);
693 // If the sample has more than two values, then the keys have at least that
694 // many, and thus this special case is inapplicable.
695 if (HWY_UNLIKELY(!AllTrue(d, st.EqualKeys(d, valueL, prev)))) {
696 return false;
697 }
698
699 // Must not overwrite samples because if this returns false, caller wants to
700 // read the original samples again.
701 T* HWY_RESTRICT buf = samples + kSampleLanes;
702 Vec<D> third; // unused
703 return MaybePartitionTwoValue(d, st, keys, num, valueL, valueR, third, buf);
704}
705
706// ------------------------------ Pivot sampling
707
708template <class Traits, class V>
709HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) {
710 const DFromV<V> d;
711 // Slightly faster for 128-bit, apparently because not serially dependent.
712 if (st.Is128()) {
713 // Median = XOR-sum 'minus' the first and last. Calling First twice is
714 // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR.
715 const auto sum = Xor(Xor(v0, v1), v2);
716 const auto first = st.First(d, st.First(d, v0, v1), v2);
717 const auto last = st.Last(d, st.Last(d, v0, v1), v2);
718 return Xor(Xor(sum, first), last);
719 }
720 st.Sort2(d, v0, v2);
721 v1 = st.Last(d, v0, v1);
722 v1 = st.First(d, v1, v2);
723 return v1;
724}
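
// Numeric check of the XOR trick (ascending order): v0 = 5, v1 = 9, v2 = 7
// gives sum = 5^9^7 = 11, first = 5, last = 9, and 11^5^9 = 7, the median.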
725
726#if VQSORT_SECURE_RNG
727using Generator = absl::BitGen;
728#else
729// Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028
730#pragma pack(push, 1)
731class Generator {
732 public:
733 Generator(const void* heap, size_t num) {
734 Sorter::Fill24Bytes(heap, num, &a_);
735 k_ = 1; // stream index: must be odd
736 }
737
738 explicit Generator(uint64_t seed) {
739 a_ = b_ = w_ = seed;
740 k_ = 1;
741 }
742
743 uint64_t operator()() {
744 const uint64_t b = b_;
745 w_ += k_;
746 const uint64_t next = a_ ^ w_;
747 a_ = (b + (b << 3)) ^ (b >> 11);
748 const uint64_t rot = (b << 24) | (b >> 40);
749 b_ = rot + next;
750 return next;
751 }
752
753 private:
754 uint64_t a_;
755 uint64_t b_;
756 uint64_t w_;
757 uint64_t k_; // increment
758};
759#pragma pack(pop)
760
761#endif // !VQSORT_SECURE_RNG
762
763// Returns slightly biased random index of a chunk in [0, num_chunks).
764// See https://www.pcg-random.org/posts/bounded-rands.html.
765HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) {
766 const uint64_t chunk_index = (static_cast<uint64_t>(bits) * num_chunks) >> 32;
767 HWY_DASSERT(chunk_index < num_chunks);
768 return static_cast<size_t>(chunk_index);
769}
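
// Example of the multiply-shift mapping: with num_chunks = 100 and
// bits = 0x80000000 (half of 2^32), chunk_index = (2^31 * 100) >> 32 = 50;
// bits = 0xFFFFFFFF yields 99, so the result is always < num_chunks.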
770
771// Writes samples from `keys[0, num)` into `buf`.
772template <class D, class Traits, typename T>
773HWY_INLINE void DrawSamples(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
774 T* HWY_RESTRICT buf, Generator& rng) {
775 using V = decltype(Zero(d));
776 const size_t N = Lanes(d);
777
778 // Power of two
779 constexpr size_t kLanesPerChunk = Constants::LanesPerChunk(sizeof(T));
780
781 // Align start of keys to chunks. We always have at least 2 chunks because the
782 // base case would have handled anything up to 16 vectors, i.e. >= 4 chunks.
783 HWY_DASSERT(num >= 2 * kLanesPerChunk);
784 const size_t misalign =
785 (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (kLanesPerChunk - 1);
786 if (misalign != 0) {
787 const size_t consume = kLanesPerChunk - misalign;
788 keys += consume;
789 num -= consume;
790 }
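 // Example (assuming a chunk is one 64-byte cache line, i.e. 16 lanes for
 // 4-byte T): if `keys` starts 5 lanes into a chunk, misalign = 5 and we skip
 // consume = 16 - 5 = 11 lanes so that subsequent chunk loads are aligned.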
791
792 // Generate enough random bits for 9 uint32
793 uint64_t* bits64 = reinterpret_cast<uint64_t*>(buf);
794 for (size_t i = 0; i < 5; ++i) {
795 bits64[i] = rng();
796 }
797 const uint32_t* bits = reinterpret_cast<const uint32_t*>(buf);
798
799 const size_t num_chunks64 = num / kLanesPerChunk;
800 // Clamp to uint32 for RandomChunkIndex
801 const uint32_t num_chunks =
802 static_cast<uint32_t>(HWY_MIN(num_chunks64, 0xFFFFFFFFull));
803
804 const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) * kLanesPerChunk;
805 const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) * kLanesPerChunk;
806 const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) * kLanesPerChunk;
807 const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) * kLanesPerChunk;
808 const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) * kLanesPerChunk;
809 const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) * kLanesPerChunk;
810 const size_t offset6 = RandomChunkIndex(num_chunks, bits[6]) * kLanesPerChunk;
811 const size_t offset7 = RandomChunkIndex(num_chunks, bits[7]) * kLanesPerChunk;
812 const size_t offset8 = RandomChunkIndex(num_chunks, bits[8]) * kLanesPerChunk;
813 for (size_t i = 0; i < kLanesPerChunk; i += N) {
814 const V v0 = Load(d, keys + offset0 + i);
815 const V v1 = Load(d, keys + offset1 + i);
816 const V v2 = Load(d, keys + offset2 + i);
817 const V medians0 = MedianOf3(st, v0, v1, v2);
818 Store(medians0, d, buf + i);
819
820 const V v3 = Load(d, keys + offset3 + i);
821 const V v4 = Load(d, keys + offset4 + i);
822 const V v5 = Load(d, keys + offset5 + i);
823 const V medians1 = MedianOf3(st, v3, v4, v5);
824 Store(medians1, d, buf + i + kLanesPerChunk);
825
826 const V v6 = Load(d, keys + offset6 + i);
827 const V v7 = Load(d, keys + offset7 + i);
828 const V v8 = Load(d, keys + offset8 + i);
829 const V medians2 = MedianOf3(st, v6, v7, v8);
830 Store(medians2, d, buf + i + kLanesPerChunk * 2);
831 }
832}
833
834// For detecting inputs where (almost) all keys are equal.
835template <class D, class Traits>
836HWY_INLINE bool UnsortedSampleEqual(D d, Traits st,
837 const TFromD<D>* HWY_RESTRICT samples) {
838 constexpr size_t kSampleLanes = 3 * 64 / sizeof(TFromD<D>);
839 const size_t N = Lanes(d);
840 using V = Vec<D>;
841
842 const V first = st.SetKey(d, samples);
843 // OR of XOR-difference may be faster than comparison.
844 V diff = Zero(d);
845 size_t i = 0;
846 for (; i + N <= kSampleLanes; i += N) {
847 const V v = Load(d, samples + i);
848 diff = OrXor(diff, first, v);
849 }
850 // Remainder, if any.
851 const V v = Load(d, samples + i);
852 const auto valid = FirstN(d, kSampleLanes - i);
853 diff = IfThenElse(valid, OrXor(diff, first, v), diff);
854
855 return st.NoKeyDifference(d, diff);
856}
857
858template <class D, class Traits, typename T>
859HWY_INLINE void SortSamples(D d, Traits st, T* HWY_RESTRICT buf) {
860 // buf contains 192 bytes of samples; pad to the 16 128-bit rows (256 bytes) used by the sorting network.
861 constexpr size_t kSampleLanes = 3 * 64 / sizeof(T);
862 const CappedTag<T, 16 / sizeof(T)> d128;
863 const size_t N128 = Lanes(d128);
864 constexpr size_t kCols = HWY_MIN(16 / sizeof(T), Constants::kMaxCols);
865 constexpr size_t kBytes = kCols * Constants::kMaxRows * sizeof(T);
866 static_assert(192 <= kBytes, "");
867 // Fill with padding - last in sort order.
868 const auto kPadding = st.LastValue(d128);
869 // Initialize an extra vector because SortingNetwork loads full vectors,
870 // which may exceed cols*kMaxRows.
871 for (size_t i = kSampleLanes; i <= kBytes / sizeof(T); i += N128) {
872 StoreU(kPadding, d128, buf + i);
873 }
874
875 SortingNetwork(st, buf, kCols);
876
877 if (VQSORT_PRINT >= 2) {
878 const size_t N = Lanes(d);
879 fprintf(stderr, "Samples:\n");
880 for (size_t i = 0; i < kSampleLanes; i += N) {
881 MaybePrintVector(d, "", Load(d, buf + i), 0, N);
882 }
883 }
884}
885
886// ------------------------------ Pivot selection
887
888enum class PivotResult {
889 kDone, // stop without partitioning (all equal, or two-value partition)
890 kNormal, // partition and recurse left and right
891 kIsFirst, // partition but skip left recursion
892 kWasLast, // partition but skip right recursion
893};
894
895HWY_INLINE const char* PivotResultString(PivotResult result) {
896 switch (result) {
897 case PivotResult::kDone:
898 return "done";
899 case PivotResult::kNormal:
900 return "normal";
901 case PivotResult::kIsFirst:
902 return "first";
903 case PivotResult::kWasLast:
904 return "last";
905 }
906 return "unknown";
907}
908
909template <class Traits, typename T>
910HWY_INLINE size_t PivotRank(Traits st, const T* HWY_RESTRICT samples) {
911 constexpr size_t kSampleLanes = 3 * 64 / sizeof(T);
912 constexpr size_t N1 = st.LanesPerKey();
913
914 constexpr size_t kRankMid = kSampleLanes / 2;
915 static_assert(kRankMid % N1 == 0, "Mid is not an aligned key");
916
917 // Find the previous value not equal to the median.
918 size_t rank_prev = kRankMid - N1;
919 for (; st.Equal1(samples + rank_prev, samples + kRankMid); rank_prev -= N1) {
920 // All previous samples are equal to the median.
921 if (rank_prev == 0) return 0;
922 }
923
924 size_t rank_next = rank_prev + N1;
925 for (; st.Equal1(samples + rank_next, samples + kRankMid); rank_next += N1) {
926 // The median is also the largest sample. If it is also the largest key,
927 // we'd end up with an empty right partition, so choose the previous key.
928 if (rank_next == kSampleLanes - N1) return rank_prev;
929 }
930
931 // If we choose the median as pivot, the ratio of keys ending in the left
932 // partition will likely be rank_next/kSampleLanes (if the sample is
933 // representative). This is because equal-to-pivot values also land in the
934 // left - it's infeasible to do an in-place vectorized 3-way partition.
935 // Check whether prev would lead to a more balanced partition.
936 const size_t excess_if_median = rank_next - kRankMid;
937 const size_t excess_if_prev = kRankMid - rank_prev;
938 return excess_if_median < excess_if_prev ? kRankMid : rank_prev;
939}
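
// Worked example for u32 keys (kSampleLanes = 48, kRankMid = 24): if
// samples[16..39] all equal the median, then rank_prev = 15 and
// rank_next = 40; excess_if_median = 16 exceeds excess_if_prev = 9, so the
// sample at rank 15 yields the more balanced split and is returned.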
940
941// Returns pivot chosen from `samples`. It will never be the largest key
942// (thus the right partition will never be empty).
943template <class D, class Traits, typename T>
944HWY_INLINE Vec<D> ChoosePivotByRank(D d, Traits st,
945 const T* HWY_RESTRICT samples) {
946 const size_t pivot_rank = PivotRank(st, samples);
947 const Vec<D> pivot = st.SetKey(d, samples + pivot_rank);
948 if (VQSORT_PRINT >= 2) {
949 fprintf(stderr, " Pivot rank %zu = %f\n", pivot_rank,
950 static_cast<double>(GetLane(pivot)));
951 }
952 // Verify pivot is not equal to the last sample.
953 constexpr size_t kSampleLanes = 3 * 64 / sizeof(T);
954 constexpr size_t N1 = st.LanesPerKey();
955 const Vec<D> last = st.SetKey(d, samples + kSampleLanes - N1);
956 const bool all_neq = AllTrue(d, st.NotEqualKeys(d, pivot, last));
957 (void)all_neq;
958 HWY_DASSERT(all_neq);
959 return pivot;
960}
961
962// Returns true if all keys equal `pivot`, otherwise returns false and sets
963// `*first_mismatch' to the index of the first differing key.
964template <class D, class Traits, typename T>
965HWY_INLINE bool AllEqual(D d, Traits st, const Vec<D> pivot,
966 const T* HWY_RESTRICT keys, size_t num,
967 size_t* HWY_RESTRICT first_mismatch) {
968 const size_t N = Lanes(d);
969 // Ensures we can use overlapping loads for the tail; see HandleSpecialCases.
970 HWY_DASSERT(num >= N);
971 const Vec<D> zero = Zero(d);
972
973 // Vector-align keys + i.
974 const size_t misalign =
975 (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (N - 1);
976 HWY_DASSERT(misalign % st.LanesPerKey() == 0);
977 const size_t consume = N - misalign;
978 {
979 const Vec<D> v = LoadU(d, keys);
980 // Only check masked lanes; consider others to be equal.
981 const Mask<D> diff = And(FirstN(d, consume), st.NotEqualKeys(d, v, pivot));
982 if (HWY_UNLIKELY(!AllFalse(d, diff))) {
983 const size_t lane = FindKnownFirstTrue(d, diff);
984 *first_mismatch = lane;
985 return false;
986 }
987 }
988 size_t i = consume;
989 HWY_DASSERT(((reinterpret_cast<uintptr_t>(keys + i) / sizeof(T)) & (N - 1)) ==
990 0);
991
992 // Sticky bits registering any difference between `keys` and the first key.
993 // We use vector XOR because it may be cheaper than comparisons, especially
994 // for 128-bit. 2x unrolled for more ILP.
995 Vec<D> diff0 = zero;
996 Vec<D> diff1 = zero;
997
998 // We want to stop once a difference has been found, but without slowing
999 // down the loop by comparing during each iteration. The compromise is to
1000 // compare after a 'group', which consists of kLoops times two vectors.
1001 constexpr size_t kLoops = 8;
1002 const size_t lanes_per_group = kLoops * 2 * N;
1003
1004 for (; i + lanes_per_group <= num; i += lanes_per_group) {
1005 HWY_DEFAULT_UNROLL
1006 for (size_t loop = 0; loop < kLoops; ++loop) {
1007 const Vec<D> v0 = Load(d, keys + i + loop * 2 * N);
1008 const Vec<D> v1 = Load(d, keys + i + loop * 2 * N + N);
1009 diff0 = OrXor(diff0, v0, pivot);
1010 diff1 = OrXor(diff1, v1, pivot);
1011 }
1012
1013 // If there was a difference in the entire group:
1014 if (HWY_UNLIKELY(!st.NoKeyDifference(d, Or(diff0, diff1)))) {
1015 // .. then loop until the first one, with termination guarantee.
1016 for (;; i += N) {
1017 const Vec<D> v = Load(d, keys + i);
1018 const Mask<D> diff = st.NotEqualKeys(d, v, pivot);
1019 if (HWY_UNLIKELY(!AllFalse(d, diff))) {
1020 const size_t lane = FindKnownFirstTrue(d, diff);
1021 *first_mismatch = i + lane;
1022 return false;
1023 }
1024 }
1025 }
1026 }
1027
1028 // Whole vectors, no unrolling, compare directly
1029 for (; i + N <= num; i += N) {
1030 const Vec<D> v = Load(d, keys + i);
1031 const Mask<D> diff = st.NotEqualKeys(d, v, pivot);
1032 if (HWY_UNLIKELY(!AllFalse(d, diff))) {
1033 const size_t lane = FindKnownFirstTrue(d, diff);
1034 *first_mismatch = i + lane;
1035 return false;
1036 }
1037 }
1038 // Always re-check the last (unaligned) vector to reduce branching.
1039 i = num - N;
1040 const Vec<D> v = LoadU(d, keys + i);
1041 const Mask<D> diff = st.NotEqualKeys(d, v, pivot);
1042 if (HWY_UNLIKELY(!AllFalse(d, diff))) {
1043 const size_t lane = FindKnownFirstTrue(d, diff);
1044 *first_mismatch = i + lane;
1045 return false;
1046 }
1047
1048 if (VQSORT_PRINT >= 1) {
1049 fprintf(stderr, "All keys equal\n");
1050 }
1051 return true; // all equal
1052}
1053
1054// Called from 'two locations', but only one is active (IsKV is constexpr).
1055template <class D, class Traits, typename T>
1056HWY_INLINE bool ExistsAnyBefore(D d, Traits st, const T* HWY_RESTRICT keys,
1057 size_t num, const Vec<D> pivot) {
1058 const size_t N = Lanes(d);
1059 HWY_DASSERT(num >= N); // See HandleSpecialCases
1060
1061 if (VQSORT_PRINT >= 2) {
1062 fprintf(stderr, "Scanning for before\n");
1063 }
1064
1065 size_t i = 0;
1066
1067 constexpr size_t kLoops = 16;
1068 const size_t lanes_per_group = kLoops * N;
1069
1070 Vec<D> first = pivot;
1071
1072 // Whole group, unrolled
1073 for (; i + lanes_per_group <= num; i += lanes_per_group) {
1074 HWY_DEFAULT_UNROLL
1075 for (size_t loop = 0; loop < kLoops; ++loop) {
1076 const Vec<D> curr = LoadU(d, keys + i + loop * N);
1077 first = st.First(d, first, curr);
1078 }
1079
1080 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, first, pivot)))) {
1081 if (VQSORT_PRINT >= 2) {
1082 fprintf(stderr, "Stopped scanning at end of group %zu\n",
1083 i + lanes_per_group);
1084 }
1085 return true;
1086 }
1087 }
1088 // Whole vectors, no unrolling
1089 for (; i + N <= num; i += N) {
1090 const Vec<D> curr = LoadU(d, keys + i);
1091 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) {
1092 if (VQSORT_PRINT >= 2) {
1093 fprintf(stderr, "Stopped scanning at %zu\n", i);
1094 }
1095 return true;
1096 }
1097 }
1098 // If there are remainders, re-check the last whole vector.
1099 if (HWY_LIKELY(i != num)) {
1100 const Vec<D> curr = LoadU(d, keys + num - N);
1101 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) {
1102 if (VQSORT_PRINT >= 2) {
1103 fprintf(stderr, "Stopped scanning at last %zu\n", num - N);
1104 }
1105 return true;
1106 }
1107 }
1108
1109 return false; // pivot is the first
1110}
1111
1112// Called from 'two locations', but only one is active (IsKV is constexpr).
1113template <class D, class Traits, typename T>
1114HWY_INLINE bool ExistsAnyAfter(D d, Traits st, const T* HWY_RESTRICT keys,
1115 size_t num, const Vec<D> pivot) {
1116 const size_t N = Lanes(d);
1117 HWY_DASSERT(num >= N); // See HandleSpecialCases
1118
1119 if (VQSORT_PRINT >= 2) {
1120 fprintf(stderr, "Scanning for after\n");
1121 }
1122
1123 size_t i = 0;
1124
1125 constexpr size_t kLoops = 16;
1126 const size_t lanes_per_group = kLoops * N;
1127
1128 Vec<D> last = pivot;
1129
1130 // Whole group, unrolled
1131 for (; i + lanes_per_group <= num; i += lanes_per_group) {
1132 HWY_DEFAULT_UNROLL
1133 for (size_t loop = 0; loop < kLoops; ++loop) {
1134 const Vec<D> curr = LoadU(d, keys + i + loop * N);
1135 last = st.Last(d, last, curr);
1136 }
1137
1138 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, last)))) {
1139 if (VQSORT_PRINT >= 2) {
1140 fprintf(stderr, "Stopped scanning at end of group %zu\n",
1141 i + lanes_per_group);
1142 }
1143 return true;
1144 }
1145 }
1146 // Whole vectors, no unrolling
1147 for (; i + N <= num; i += N) {
1148 const Vec<D> curr = LoadU(d, keys + i);
1149 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) {
1150 if (VQSORT_PRINT >= 2) {
1151 fprintf(stderr, "Stopped scanning at %zu\n", i);
1152 }
1153 return true;
1154 }
1155 }
1156 // If there are remainders, re-check the last whole vector.
1157 if (HWY_LIKELY(i != num)) {
1158 const Vec<D> curr = LoadU(d, keys + num - N);
1159 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) {
1160 if (VQSORT_PRINT >= 2) {
1161 fprintf(stderr, "Stopped scanning at last %zu\n", num - N);
1162 }
1163 return true;
1164 }
1165 }
1166
1167 return false; // pivot is the last
1168}
1169
1170// Returns pivot chosen from `keys[0, num)`. It will never be the largest key
1171// (thus the right partition will never be empty).
1172template <class D, class Traits, typename T>
1173HWY_INLINE Vec<D> ChoosePivotForEqualSamples(D d, Traits st,
1174 T* HWY_RESTRICT keys, size_t num,
1175 T* HWY_RESTRICT samples,
1176 Vec<D> second, Vec<D> third,
1177 PivotResult& result) {
1178 const Vec<D> pivot = st.SetKey(d, samples); // the single unique sample
1179
1180 // Early out for mostly-0 arrays, where pivot is often FirstValue.
1181 if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.FirstValue(d))))) {
1182 result = PivotResult::kIsFirst;
1183 return pivot;
1184 }
1185 if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.LastValue(d))))) {
1186 result = PivotResult::kWasLast;
1187 return st.PrevValue(d, pivot);
1188 }
1189
1190 // If key-value, we didn't run PartitionIfTwo* and thus `third` is unknown and
1191 // cannot be used.
1192 if (st.IsKV()) {
1193 // If true, pivot is either middle or last.
1194 const bool before = !AllFalse(d, st.Compare(d, second, pivot));
1195 if (HWY_UNLIKELY(before)) {
1196 // Not last, so middle.
1197 if (HWY_UNLIKELY(ExistsAnyAfter(d, st, keys, num, pivot))) {
1198 result = PivotResult::kNormal;
1199 return pivot;
1200 }
1201
1202 // We didn't find anything after pivot, so it is the last. Because keys
1203 // equal to the pivot go to the left partition, the right partition would
1204 // be empty and Partition will not have changed anything. Instead use the
1205 // previous value in sort order, which is not necessarily an actual key.
1206 result = PivotResult::kWasLast;
1207 return st.PrevValue(d, pivot);
1208 }
1209
1210 // Otherwise, pivot is first or middle. Rule out it being first:
1211 if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) {
1212 result = PivotResult::kNormal;
1213 return pivot;
1214 }
1215 // It is first: fall through to shared code below.
1216 } else {
1217 // Check if pivot is between two known values. If so, it is not the first
1218 // nor the last and we can avoid scanning.
1219 st.Sort2(d, second, third);
1220 HWY_DASSERT(AllTrue(d, st.Compare(d, second, third)));
1221 const bool before = !AllFalse(d, st.Compare(d, second, pivot));
1222 const bool after = !AllFalse(d, st.Compare(d, pivot, third));
1223 // Only reached if there are three keys, which means pivot is either first,
1224 // last, or in between. Thus there is another key that comes before or
1225 // after.
1226 HWY_DASSERT(before || after);
1227 if (HWY_UNLIKELY(before)) {
1228 // Neither first nor last.
1229 if (HWY_UNLIKELY(after || ExistsAnyAfter(d, st, keys, num, pivot))) {
1230 result = PivotResult::kNormal;
1231 return pivot;
1232 }
1233
1234 // We didn't find anything after pivot, so it is the last. Because keys
1235 // equal to the pivot go to the left partition, the right partition would
1236 // be empty and Partition will not have changed anything. Instead use the
1237 // previous value in sort order, which is not necessarily an actual key.
1238 result = PivotResult::kWasLast;
1239 return st.PrevValue(d, pivot);
1240 }
1241
1242 // Has after, and we found one before: in the middle.
1243 if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) {
1244 result = PivotResult::kNormal;
1245 return pivot;
1246 }
1247 }
1248
1249 // Pivot is first. We could consider a special partition mode that only
1250 // reads from and writes to the right side, and later fills in the left
1251 // side, which we know is equal to the pivot. However, that leads to more
1252 // cache misses if the array is large, and doesn't save much, hence is a
1253 // net loss.
1254 result = PivotResult::kIsFirst;
1255 return pivot;
1256}
1257
1258// ------------------------------ Quicksort recursion
1259
1260template <class D, class Traits, typename T>
1261HWY_NOINLINE void PrintMinMax(D d, Traits st, const T* HWY_RESTRICT keys,
1262 size_t num, T* HWY_RESTRICT buf) {
1263 if (VQSORT_PRINT >= 2) {
1264 const size_t N = Lanes(d);
1265 if (num < N) return;
1266
1267 Vec<D> first = st.LastValue(d);
1268 Vec<D> last = st.FirstValue(d);
1269
1270 size_t i = 0;
1271 for (; i + N <= num; i += N) {
1272 const Vec<D> v = LoadU(d, keys + i);
1273 first = st.First(d, v, first);
1274 last = st.Last(d, v, last);
1275 }
1276 if (HWY_LIKELY(i != num)) {
1277 HWY_DASSERT(num >= N); // See HandleSpecialCases
1278 const Vec<D> v = LoadU(d, keys + num - N);
1279 first = st.First(d, v, first);
1280 last = st.Last(d, v, last);
1281 }
1282
1283 first = st.FirstOfLanes(d, first, buf);
1284 last = st.LastOfLanes(d, last, buf);
1285 MaybePrintVector(d, "first", first, 0, st.LanesPerKey());
1286 MaybePrintVector(d, "last", last, 0, st.LanesPerKey());
1287 }
1288}
1289
1290// keys_end is the end of the entire user input, not just the current subarray
1291// [keys, keys + num).
1292template <class D, class Traits, typename T>
1293HWY_NOINLINE void Recurse(D d, Traits st, T* HWY_RESTRICT keys,
1294 T* HWY_RESTRICT keys_end, const size_t num,
1295 T* HWY_RESTRICT buf, Generator& rng,
1296 const size_t remaining_levels) {
1297 HWY_DASSERT(num != 0);
1298
1299 if (HWY_UNLIKELY(num <= Constants::BaseCaseNum(Lanes(d)))) {
1300 BaseCase(d, st, keys, keys_end, num, buf);
1301 return;
1302 }
1303
1304 // Move after BaseCase so we skip printing for small subarrays.
1305 if (VQSORT_PRINT >= 1) {
1306 fprintf(stderr, "\n\n=== Recurse depth=%zu len=%zu\n", remaining_levels,
1307 num);
1308 PrintMinMax(d, st, keys, num, buf);
1309 }
1310
1311 DrawSamples(d, st, keys, num, buf, rng);
1312
1313 Vec<D> pivot;
1314 PivotResult result = PivotResult::kNormal;
1315 if (HWY_UNLIKELY(UnsortedSampleEqual(d, st, buf))) {
1316 pivot = st.SetKey(d, buf);
1317 size_t idx_second = 0;
1318 if (HWY_UNLIKELY(AllEqual(d, st, pivot, keys, num, &idx_second))) {
1319 return;
1320 }
1321 HWY_DASSERT(idx_second % st.LanesPerKey() == 0);
1322 // Must capture the value before PartitionIfTwoKeys may overwrite it.
1323 const Vec<D> second = st.SetKey(d, keys + idx_second);
1324 MaybePrintVector(d, "pivot", pivot, 0, st.LanesPerKey());
1325 MaybePrintVector(d, "second", second, 0, st.LanesPerKey());
1326
1327 Vec<D> third;
1328 // Not supported for key-value types because two 'keys' may be equivalent
1329 // but not interchangeable (their values may differ).
1330 if (HWY_UNLIKELY(!st.IsKV() &&
1331 PartitionIfTwoKeys(d, st, pivot, keys, num, idx_second,
1332 second, third, buf))) {
1333 return; // Done, skip recursion because each side has all-equal keys.
1334 }
1335
1336 // We can no longer start scanning from idx_second because
1337 // PartitionIfTwoKeys may have reordered keys.
1338 pivot = ChoosePivotForEqualSamples(d, st, keys, num, buf, second, third,
1339 result);
1340 // If kNormal, `pivot` is very common but not the first/last. It is
1341 // tempting to do a 3-way partition (to avoid moving the =pivot keys a
1342 // second time), but that is a net loss due to the extra comparisons.
1343 } else {
1344 SortSamples(d, st, buf);
1345
1346 // Not supported for key-value types because two 'keys' may be equivalent
1347 // but not interchangeable (their values may differ).
1348 if (HWY_UNLIKELY(!st.IsKV() &&
1349 PartitionIfTwoSamples(d, st, keys, num, buf))) {
1350 return;
1351 }
1352
1353 pivot = ChoosePivotByRank(d, st, buf);
1354 }
1355
1356 // Too many recursions. This is unlikely to happen because we select pivots
1357 // from large (though still O(1)) samples.
1358 if (HWY_UNLIKELY(remaining_levels == 0)) {
1359 if (VQSORT_PRINT >= 1) {
1360 fprintf(stderr, "HeapSort reached, size=%zu\n", num);
1361 }
1362 HeapSort(st, keys, num); // Slow but N*logN.
1363 return;
1364 }
1365
1366 const size_t bound = Partition(d, st, keys, num, pivot, buf);
1367 if (VQSORT_PRINT >= 2) {
1368 fprintf(stderr, "bound %zu num %zu result %s\n", bound, num,
1369 PivotResultString(result));
1370 }
1371 // The left partition is not empty because the pivot is one of the keys
1372 // (unless kWasLast, in which case the pivot is PrevValue, but we still
1373 // have at least one value <= pivot because AllEqual ruled out the case of
1374 // only one unique value, and there is exactly one value after pivot).
1375 HWY_DASSERT(bound != 0);
1376 // ChoosePivot* ensure pivot != last, so the right partition is never empty.
1377 HWY_DASSERT(bound != num);
1378
1379 if (HWY_LIKELY(result != PivotResult::kIsFirst)) {
1380 Recurse(d, st, keys, keys_end, bound, buf, rng, remaining_levels - 1);
1381 }
1382 if (HWY_LIKELY(result != PivotResult::kWasLast)) {
1383 Recurse(d, st, keys + bound, keys_end, num - bound, buf, rng,
1384 remaining_levels - 1);
1385 }
1386}
1387
1388// Returns true if sorting is finished.
1389template <class D, class Traits, typename T>
1390HWY_INLINE bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys,
1391 size_t num) {
1392 const size_t N = Lanes(d);
1393 const size_t base_case_num = Constants::BaseCaseNum(N);
1394
1395 // 128-bit keys require vectors with at least two u64 lanes, which is always
1396 // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the
1397 // hardware vector width is less than 128bit / fraction.
1398 const bool partial_128 = !IsFull(d) && N < 2 && st.Is128();
1399 // Partition assumes its input is at least two vectors. If vectors are huge,
1400 // base_case_num may actually be smaller. If so, which is only possible on
1401 // RVV, pass a capped or partial d (LMUL < 1). Use HWY_MAX_BYTES instead of
1402 // HWY_LANES to account for the largest possible LMUL.
1403 constexpr bool kPotentiallyHuge =
1404 HWY_MAX_BYTES / sizeof(T) > Constants::BaseCaseNum(HWY_LANES(T));
1405 const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num);
1406 if (partial_128 || huge_vec) {
1407 if (VQSORT_PRINT >= 1) {
1408 fprintf(stderr, "WARNING: using slow HeapSort: partial %d huge %d\n",
1409 partial_128, huge_vec);
1410 }
1411 HeapSort(st, keys, num);
1412 return true;
1413 }
1414
1415 // Small arrays are already handled by Recurse.
1416
1417 // We could also check for already sorted/reverse/equal, but that's probably
1418 // counterproductive if vqsort is used as a base case.
1419
1420 return false; // not finished sorting
1421}
1422
1423#endif // VQSORT_ENABLED
1424} // namespace detail
1425
1426// Sorts `keys[0..num-1]` according to the order defined by `st.Compare`.
1427// In-place i.e. O(1) additional storage. Worst-case N*logN comparisons.
1428// Non-stable (order of equal keys may change), except for the common case where
1429// the upper bits of T are the key, and the lower bits are a sequential or at
1430// least unique ID.
1431// There is no upper limit on `num`, but note that pivots may be chosen by
1432// sampling only from the first 256 GiB.
1433//
1434// `d` is typically SortTag<T> (chooses between full and partial vectors).
1435// `st` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges
1436// differences in sort order and single-lane vs 128-bit keys.
1437template <class D, class Traits, typename T>
1438void Sort(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
1439 T* HWY_RESTRICT buf) {
1440 if (VQSORT_PRINT >= 1) {
1441 fprintf(stderr, "=============== Sort num %zu\n", num);
1442 }
1443
1444#if VQSORT_ENABLED || HWY_IDE
1445#if !HWY_HAVE_SCALABLE
1446 // On targets with fixed-size vectors, avoid _using_ the allocated memory.
1447 // We avoid (potentially expensive for small input sizes) allocations on
1448 // platforms where no targets are scalable. For 512-bit vectors, this fits on
1449 // the stack (several KiB).
1450 HWY_ALIGN T storage[SortConstants::BufNum<T>(HWY_LANES(T))] = {};
1451 static_assert(sizeof(storage) <= 8192, "Unexpectedly large, check size");
1452 buf = storage;
1453#endif // !HWY_HAVE_SCALABLE
1454
1455 if (detail::HandleSpecialCases(d, st, keys, num)) return;
1456
1457#if HWY_MAX_BYTES > 64
1458 // sorting_networks-inl and traits assume no more than 512 bit vectors.
1459 if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) {
1460 return Sort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, buf);
1461 }
1462#endif // HWY_MAX_BYTES > 64
1463
1464 detail::Generator rng(keys, num);
1465
1466 // Introspection: switch to worst-case N*logN heapsort after this many.
1467 const size_t max_levels = 2 * hwy::CeilLog2(num) + 4;
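 // For example, num = 1,000,000 gives CeilLog2(num) = 20 and max_levels = 44.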
1468 detail::Recurse(d, st, keys, keys + num, num, buf, rng, max_levels);
1469#else
1470 (void)d;
1471 (void)buf;
1472 if (VQSORT_PRINT >= 1) {
1473 fprintf(stderr, "WARNING: using slow HeapSort because vqsort disabled\n");
1474 }
1475 return detail::HeapSort(st, keys, num);
1476#endif // VQSORT_ENABLED
1477}
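
// Illustrative call site, following the pattern used by the vqsort_*.cc
// translation units (traits names live in traits-inl.h and are approximate
// here; most callers go through hwy::Sorter declared in vqsort.h):
//   SortTag<uint64_t> d;
//   detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<uint64_t>>> st;
//   Sort(d, st, keys, num, buf);  // buf: aligned scratch, unused on fixed-size targets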
1478
1479// NOLINTNEXTLINE(google-readability-namespace-comments)
1480} // namespace HWY_NAMESPACE
1481} // namespace hwy
1482HWY_AFTER_NAMESPACE();
1483
1484#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE