Deep Understanding of Spark 2.1 Core: TimSort Principle and Source Code Analysis

In the earlier post, Deep Understanding of Spark 2.1 Core (X): Principles and Source Code Analysis of the Shuffle Map Side, we mentioned that:

Sort and similar classes are used to sort the data, and they rely on TimSort.
In this post, let's take a deeper look at TimSort.

Understanding TimSort

After watching the video, you may find that TimSort and MergeSort are very similar. Indeed, TimSort is essentially a series of improvements to merge sort. Some of them are very clever, while others are quite straightforward. Together, these large and small improvements make the algorithm very efficient.

Spark TimSort Source Analysis

In fact, OpenJDK has used TimSort in java.util.Arrays since Java SE 7 to sort arrays of Object elements, and the Java TimSort in Spark's org.apache.spark.util.collection package differs little from the Java SE 7 version.
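As a small illustration of the Java SE 7 behavior mentioned above, the sketch below sorts an Object array with Arrays.sort and a comparator. The class name StableSortDemo and the sample words are made up for this example; it relies only on the documented guarantee that Arrays.sort on object arrays is stable.

```java
import java.util.Arrays;
import java.util.Comparator;

class StableSortDemo {
    // Sorts words by length only. Words of equal length keep their input
    // order, because Arrays.sort on object arrays is documented to be stable.
    static String[] sortByLength(String[] words) {
        String[] copy = words.clone();
        Arrays.sort(copy, Comparator.comparingInt(String::length));
        return copy;
    }

    public static void main(String[] args) {
        String[] words = {"pear", "fig", "kiwi", "plum", "nut"};
        // prints [fig, nut, pear, kiwi, plum]
        System.out.println(Arrays.toString(sortByLength(words)));
    }
}
```

Stability is the property that lets TimSort be used safely when records with equal keys must keep their relative order.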

  public void sort(Buffer a, int lo, int hi, Comparator<? super K> c) {
    assert c != null;
    // Length of the range that is still unsorted
    int nRemaining  = hi - lo;
    // A range of size 0 or 1
    // is already sorted.
    if (nRemaining < 2)
      return;

    // If the range is small,
    // skip merging altogether
    if (nRemaining < MIN_MERGE) {
      // Length of the initial ascending run
      int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
      // Binary insertion sort for the remaining elements
      binarySort(a, lo, hi, lo + initRunLen, c);
      return;
    }
    // Holds the stack of pending runs
    SortState sortState = new SortState(a, c, hi - lo);
    // Compute the minimum run length
    int minRun = minRunLength(nRemaining);
    do {
      // Length of the next ascending run
      int runLen = countRunAndMakeAscending(a, lo, hi, c);

      // If the run is too short, extend it to
      // min(minRun, nRemaining) with binary insertion sort
      if (runLen < minRun) {
        int force = nRemaining <= minRun ? nRemaining : minRun;
        binarySort(a, lo, lo + force, lo + runLen, c);
        runLen = force;
      }

      // Push the run onto the pending-run stack
      sortState.pushRun(lo, runLen);
      // Merge runs if the stack invariant requires it
      sortState.mergeCollapse();

      // Advance to find the next run
      lo += runLen;
      nRemaining -= runLen;
    } while (nRemaining != 0);

    // Merge all the remaining runs to complete the sort
    assert lo == hi;
    sortState.mergeForceCollapse();
    assert sortState.stackSize == 1;
  }

Let's go through these methods one by one.

countRunAndMakeAscending

  private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? super K> c) {
    assert lo < hi;
    int runHi = lo + 1;
    if (runHi == hi)
      return 1;

    K key0 = s.newKey();
    K key1 = s.newKey();

    // Find the end of the run
    if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) {
      // Strictly descending: find the end of the run, then reverse it in place
      while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0)
        runHi++;
      reverseRange(a, lo, runHi);
    } else {
      // Ascending (non-descending): just find the end of the run
      while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0)
        runHi++;
    }
    // Return the length of run
    return runHi - lo;
  }
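To make the run detection easier to follow, here is a minimal sketch of the same logic on a plain int[] instead of Spark's Buffer and key abstraction. The class name RunDemo is made up for this example. Note that only strictly descending runs are reversed, which keeps equal elements in their original order.

```java
class RunDemo {
    // Returns the length of the run starting at lo. If the run is strictly
    // descending it is reversed in place, so it is ascending on return.
    static int countRunAndMakeAscending(int[] a, int lo, int hi) {
        int runHi = lo + 1;
        if (runHi == hi) return 1;
        if (a[runHi++] < a[lo]) {
            // Strictly descending run
            while (runHi < hi && a[runHi] < a[runHi - 1]) runHi++;
            reverseRange(a, lo, runHi);
        } else {
            // Non-descending run
            while (runHi < hi && a[runHi] >= a[runHi - 1]) runHi++;
        }
        return runHi - lo;
    }

    // Reverses a[lo, hi) in place.
    static void reverseRange(int[] a, int lo, int hi) {
        hi--;
        while (lo < hi) {
            int t = a[lo];
            a[lo++] = a[hi];
            a[hi--] = t;
        }
    }

    public static void main(String[] args) {
        int[] a = {5, 4, 3, 7, 8};
        // The first run is 5, 4, 3 (length 3); it is reversed to 3, 4, 5.
        System.out.println(countRunAndMakeAscending(a, 0, a.length));
        System.out.println(java.util.Arrays.toString(a));
    }
}
```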

binarySort

  private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super K> c) {
    assert lo <= start && start <= hi;
    if (start == lo)
      start++;

    K key0 = s.newKey();
    K key1 = s.newKey();

    Buffer pivotStore = s.allocate(1);
    // Binary-insert each element of [start, hi) into the already sorted range [lo, start)
    for ( ; start < hi; start++) {
      s.copyElement(a, start, pivotStore, 0);
      K pivot = s.getKey(pivotStore, 0, key0);

      int left = lo;
      int right = start;
      assert left <= right;
      while (left < right) {
        int mid = (left + right) >>> 1;
        if (c.compare(pivot, s.getKey(a, mid, key1)) < 0)
          right = mid;
        else
          left = mid + 1;
      }
      assert left == right;

      int n = start - left;  
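      // Shift a[left, start) right by one slot. The cases n == 1 and n == 2
      // avoid a range copy (case 2 deliberately falls through to case 1).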
      switch (n) {
        case 2:  s.copyElement(a, left + 1, a, left + 2);
        case 1:  s.copyElement(a, left, a, left + 1);
          break;
        default: s.copyRange(a, left, a, left + 1, n);
      }
      s.copyElement(pivotStore, 0, a, left);
    }
  }
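The same binary insertion sort can be sketched on a plain int[]. This is a simplified version, not Spark's code: the class name BinarySortDemo is made up, and System.arraycopy stands in for the copyElement/copyRange switch.

```java
class BinarySortDemo {
    // Sorts a[lo, hi), assuming a[lo, start) is already sorted.
    static void binarySort(int[] a, int lo, int hi, int start) {
        if (start == lo) start++;
        for (; start < hi; start++) {
            int pivot = a[start];
            // Binary search for the insertion point of pivot in a[lo, start).
            // Equal elements stay to the left of pivot, so the sort is stable.
            int left = lo;
            int right = start;
            while (left < right) {
                int mid = (left + right) >>> 1;
                if (pivot < a[mid]) right = mid;
                else left = mid + 1;
            }
            // Shift a[left, start) one slot to the right, then insert pivot.
            System.arraycopy(a, left, a, left + 1, start - left);
            a[left] = pivot;
        }
    }

    public static void main(String[] args) {
        int[] a = {1, 4, 9, 3, 7, 2};
        binarySort(a, 0, a.length, 3); // a[0, 3) = 1, 4, 9 is already sorted
        System.out.println(java.util.Arrays.toString(a));
    }
}
```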

minRunLength

  private int minRunLength(int n) {
    assert n >= 0;
    int r = 0;     
    // MIN_MERGE is a power of 2.
    // If n < MIN_MERGE, the loop is skipped and n is returned directly.
    // If n >= MIN_MERGE and n is an exact power of 2, every bit shifted out
    // is 0, so r stays 0 and MIN_MERGE/2 is returned.
    // Otherwise r becomes 1 as soon as any shifted-out bit is 1, and the
    // return value k satisfies MIN_MERGE/2 <= k <= MIN_MERGE.
    while (n >= MIN_MERGE) {
      r |= (n & 1);
      n >>= 1;
    }
    return n + r;
  }
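A few concrete values make the bit manipulation easier to see. The sketch below copies the method into a standalone class; the class name MinRunDemo is made up, and MIN_MERGE = 32 is an assumption here (it is the value used by the Java SE 7 TimSort).

```java
class MinRunDemo {
    // Assumed value, as in the Java SE 7 TimSort.
    static final int MIN_MERGE = 32;

    static int minRunLength(int n) {
        int r = 0; // becomes 1 if any 1 bit is shifted out
        while (n >= MIN_MERGE) {
            r |= (n & 1);
            n >>= 1;
        }
        return n + r;
    }

    public static void main(String[] args) {
        System.out.println(minRunLength(31));   // 31: below MIN_MERGE, returned as is
        System.out.println(minRunLength(64));   // 16: power of 2, MIN_MERGE / 2
        System.out.println(minRunLength(65));   // 17: one low bit was shifted out
        System.out.println(minRunLength(1000)); // 32: 1000 / 32 is just under 32
    }
}
```

With these values, n divided by the returned run length is equal to or slightly below a power of 2, which keeps the later merges balanced.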

SortState.pushRun

Pushes a run (its base index and its length) onto the stack:

    private void pushRun(int runBase, int runLen) {
      this.runBase[stackSize] = runBase;
      this.runLen[stackSize] = runLen;
      stackSize++;
    }

SortState.mergeCollapse

OpenJDK's version of this code contains a bug. Let's first look at how Java SE 7 implements it:

private void mergeCollapse() {
    while (stackSize > 1) {
        int n = stackSize - 2;
        if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
            if (runLen[n - 1] < runLen[n + 1])
                n--;
            mergeAt(n);
        } else if (runLen[n] <= runLen[n + 1]) {
            mergeAt(n);
        } else {
            break; 
        }
    }
}

Let's look at an example.
Suppose the run lengths on the stack are:

120, 80, 25, 20

We push a run of length 30. Because 25 < 20 + 30 and 25 < 30, the runs of length 25 and 20 are merged, and we get:

120, 80, 45, 30

Now, because 80 > 45 + 30 and 45 > 30, merging stops. But the invariant has not been fully restored, because 120 < 80 + 45!

Spark fixes this bug. The corrected code is as follows:

    private void mergeCollapse() {
      while (stackSize > 1) {
        int n = stackSize - 2;
        if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1])
          || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) {
          if (runLen[n - 1] < runLen[n + 1])
            n--;
        } else if (runLen[n] > runLen[n + 1]) {
          break; 
        }
        mergeAt(n);
      }
    }
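The example above can be replayed with a small simulation that tracks only run lengths. This is a sketch, not Spark's code: the class CollapseDemo and its method names are made up, and merging two runs is modeled by adding their lengths.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class CollapseDemo {
    // Merging runs n and n + 1 just adds their lengths in this simulation.
    static void mergeAt(List<Integer> r, int n) {
        r.set(n, r.get(n) + r.get(n + 1));
        r.remove(n + 1);
    }

    // Mirrors the Java SE 7 mergeCollapse shown above.
    static List<Integer> collapseJava7(List<Integer> input) {
        List<Integer> r = new ArrayList<>(input);
        while (r.size() > 1) {
            int n = r.size() - 2;
            if (n > 0 && r.get(n - 1) <= r.get(n) + r.get(n + 1)) {
                if (r.get(n - 1) < r.get(n + 1)) n--;
                mergeAt(r, n);
            } else if (r.get(n) <= r.get(n + 1)) {
                mergeAt(r, n);
            } else {
                break;
            }
        }
        return r;
    }

    // Mirrors the corrected mergeCollapse shown above.
    static List<Integer> collapseFixed(List<Integer> input) {
        List<Integer> r = new ArrayList<>(input);
        while (r.size() > 1) {
            int n = r.size() - 2;
            if ((n >= 1 && r.get(n - 1) <= r.get(n) + r.get(n + 1))
                || (n >= 2 && r.get(n - 2) <= r.get(n) + r.get(n - 1))) {
                if (r.get(n - 1) < r.get(n + 1)) n--;
            } else if (r.get(n) > r.get(n + 1)) {
                break;
            }
            mergeAt(r, n);
        }
        return r;
    }

    public static void main(String[] args) {
        List<Integer> stack = Arrays.asList(120, 80, 25, 20, 30);
        System.out.println(collapseJava7(stack)); // [120, 80, 45, 30]
        System.out.println(collapseFixed(stack)); // [275]
    }
}
```

The Java SE 7 loop stops at 120, 80, 45, 30 even though 120 < 80 + 45. The corrected loop also checks the run one level deeper, keeps merging, and in this particular example ends with a single run of length 275.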

SortState.mergeAt

    private void mergeAt(int i) {
      assert stackSize >= 2;
      assert i >= 0;
      assert i == stackSize - 2 || i == stackSize - 3;

      int base1 = runBase[i];
      int len1 = runLen[i];
      int base2 = runBase[i + 1];
      int len2 = runLen[i + 1];
      assert len1 > 0 && len2 > 0;
      assert base1 + len1 == base2;

      // Record the combined length of the two runs. If i is the third run
      // from the top of the stack, slide the top run down one slot.
      runLen[i] = len1 + len2;
      if (i == stackSize - 3) {
        runBase[i + 1] = runBase[i + 2];
        runLen[i + 1] = runLen[i + 2];
      }
      stackSize--;

      K key0 = s.newKey();

      // Find where the first element of run2 belongs in run1.
      // The run1 elements before that point are already in place and can be ignored.
      int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c);
      assert k >= 0;
      base1 += k;
      len1 -= k;
      if (len1 == 0)
        return;

      // Find where the last element of run1 belongs in run2.
      // The run2 elements after that point are already in place and can be ignored.
      len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c);
      assert len2 >= 0;
      if (len2 == 0)
        return;

      // Merge the remaining parts of the two runs,
      // using a temporary buffer of min(len1, len2) elements
      if (len1 <= len2)
        mergeLo(base1, len1, base2, len2);
      else
        mergeHi(base1, len1, base2, len2);
    }
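The trimming step at the start of the merge can be illustrated with a simplified sketch on an int[]. The class TrimDemo and the method trim are made up for this example, and plain linear scans stand in for gallopRight and gallopLeft; only the result of the trimming is the same.

```java
class TrimDemo {
    // Given two adjacent sorted runs a[base1, base1 + len1) and
    // a[base2, base2 + len2), returns {base1, len1, base2, len2} for the
    // parts that still need to be merged.
    static int[] trim(int[] a, int base1, int len1, int base2, int len2) {
        // Skip run1 elements <= first element of run2 (role of gallopRight).
        int k = 0;
        while (k < len1 && a[base1 + k] <= a[base2]) k++;
        base1 += k;
        len1 -= k;
        if (len1 == 0) return new int[] {base1, 0, base2, len2}; // already in order

        // Drop run2 elements >= last element of run1 (role of gallopLeft).
        int last = a[base1 + len1 - 1];
        while (len2 > 0 && a[base2 + len2 - 1] >= last) len2--;
        return new int[] {base1, len1, base2, len2};
    }

    public static void main(String[] args) {
        // run1 = 1, 2, 3, 10, 12 and run2 = 4, 5, 11, 20, 30
        int[] a = {1, 2, 3, 10, 12, 4, 5, 11, 20, 30};
        // Only 10, 12 and 4, 5, 11 need merging: prints [3, 2, 5, 3]
        System.out.println(java.util.Arrays.toString(trim(a, 0, 5, 5, 5)));
    }
}
```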

SortState.gallopRight

    // key:  the first value of run2
    // a:    the array
    // base: index at which run1 starts
    // len:  length of run1
    // hint: index within run1 at which to start the search; a value of 0 is passed in here
    private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) {
      assert len > 0 && hint >= 0 && hint < len;
      
      // Optimized binary search:
      // first gallop through run1 with growing offsets (1, 3, 7, 15, ...),
      // where each new offset is ofs = 2 * lastOfs + 1,
      // until run1[lastOfs] <= key < run1[ofs].
      // Then do a binary search within (lastOfs, ofs].
      int ofs = 1;
      int lastOfs = 0;
      K key1 = s.newKey();
      
      // key is less than a[base + hint]. With hint == 0 this means the first
      // value of run2 is less than the first value of run1,
      // so 0 could be returned directly,
      // but the code still runs the general algorithm.
      if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) {
        // With hint == 0: maxOfs = 1
        int maxOfs = hint + 1;
        // With hint == 0 the loop body is never entered
        while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) {
          lastOfs = ofs;
          ofs = (ofs << 1) + 1;
          if (ofs <= 0)   
            ofs = maxOfs;
        }
        // Not taken when hint == 0
        if (ofs > maxOfs)
          ofs = maxOfs;

        // With hint == 0: tmp = 0
        int tmp = lastOfs;
        // lastOfs = -1
        lastOfs = hint - ofs;
        // ofs = 0
        ofs = hint - tmp;
      } else { 
        // This is the case where galloping does real work.
        // With hint == 0: maxOfs = len
        int maxOfs = len - hint;
        while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) {
          // Update lastOfs and ofs
          lastOfs = ofs;
          ofs = (ofs << 1) + 1;
          // Guard against int overflow
          if (ofs <= 0)
            ofs = maxOfs;
        }
        if (ofs > maxOfs)
          ofs = maxOfs;

        // hint is 0 here, so these values do not change.
        lastOfs += hint;
        ofs += hint;
      }
      assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;

      // Binary search
      lastOfs++;
      while (lastOfs < ofs) {
        int m = lastOfs + ((ofs - lastOfs) >>> 1);

        if (c.compare(key, s.getKey(a, base + m, key1)) < 0)
        // key < a[b + m]
          ofs = m;          
        else
        // a[b + m] <= key
          lastOfs = m + 1;  
      }
      assert lastOfs == ofs;    
      return ofs;
    }

gallopLeft is similar to the code above, so it is not explained here.
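Stripped of the hint handling, the galloping idea is an exponential search followed by a binary search. The following is a simplified sketch on an int[] that always starts at offset 0; the class name GallopDemo is made up and this is not Spark's actual method.

```java
class GallopDemo {
    // Returns k such that a[base + k - 1] <= key < a[base + k], that is, the
    // number of elements in the sorted range a[base, base + len) that are <= key.
    static int gallopRight(int key, int[] a, int base, int len) {
        if (key < a[base]) return 0;

        // Gallop: offsets 1, 3, 7, 15, ... until a[base + ofs] > key.
        int lastOfs = 0;
        int ofs = 1;
        while (ofs < len && key >= a[base + ofs]) {
            lastOfs = ofs;
            ofs = (ofs << 1) + 1;
            if (ofs <= 0) ofs = len; // int overflow
        }
        if (ofs > len) ofs = len;

        // Now a[base + lastOfs] <= key, and key < a[base + ofs] if ofs < len.
        // Binary search in (lastOfs, ofs].
        lastOfs++;
        while (lastOfs < ofs) {
            int m = lastOfs + ((ofs - lastOfs) >>> 1);
            if (key < a[base + m]) ofs = m;
            else lastOfs = m + 1;
        }
        return ofs;
    }

    public static void main(String[] args) {
        int[] run1 = {1, 3, 5, 7, 9, 11, 13};
        System.out.println(gallopRight(6, run1, 0, run1.length));  // 3
        System.out.println(gallopRight(0, run1, 0, run1.length));  // 0
        System.out.println(gallopRight(13, run1, 0, run1.length)); // 7
    }
}
```

Galloping costs a logarithmic number of comparisons in the distance to the answer, so it is cheap when the answer is near the start and never much worse than a plain binary search.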

SortState.mergeLo

    private void mergeLo(int base1, int len1, int base2, int len2) {
      assert len1 > 0 && len2 > 0 && base1 + len1 == base2;

      // Use a temporary buffer of min(len1, len2) elements.
      // Here len1 is the smaller one, so run1 is copied into tmp.
      Buffer a = this.a;
      Buffer tmp = ensureCapacity(len1);
      s.copyRange(a, base1, tmp, 0, len1);

      // Index into tmp (run1)
      int cursor1 = 0;
      // Index into a (run2)
      int cursor2 = base2;
      // Index into a where the merged result is written
      int dest = base1;

      // Move the first element of the second run and deal with degenerate cases.
      // Optimization:
      // Note: after the trimming in mergeAt, the first element of run2 is smaller
      //       than the first element of run1, and the last element of run1 is
      //       larger than the last element of run2.
      // So copy the first element of run2 to the first position of the result.
      s.copyElement(a, cursor2++, a, dest++);
      if (--len2 == 0) {
        // run2 had only one element:
        // copy run1 directly into the final result
        s.copyRange(tmp, cursor1, a, dest, len1);
        return;
      }
      if (len1 == 1) {
        // run1 has only one element:
        // copy the rest of run2 into the final result,
        // then append the single run1 element
        s.copyRange(a, cursor2, a, dest, len2);
        s.copyElement(tmp, cursor1, a, dest + len2); 
        return;
      }

      K key0 = s.newKey();
      K key1 = s.newKey();

      Comparator<? super K> c = this.c;
      // Merge with adaptive galloping:
      int minGallop = this.minGallop;
      outer:
      while (true) {
        // Number of times in a row that run1 / run2 supplied the next element
        int count1 = 0;
        int count2 = 0;

        do {
          // Merge one element at a time
          assert len1 > 1 && len2 > 0;
          if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) {
            s.copyElement(a, cursor2++, a, dest++);
            count2++;
            count1 = 0;
            if (--len2 == 0)
              break outer;
          } else {
            s.copyElement(tmp, cursor1++, a, dest++);
            count1++;
            count2 = 0;
            if (--len1 == 1)
              break outer;
          }
          // Leave this loop once one run has supplied
          // minGallop elements in a row
        } while ((count1 | count2) < minGallop);

        // One run has supplied minGallop elements in a row, so it is
        // likely to keep doing so. Switch to galloping mode:
        // as in mergeAt, find whole segments with gallopRight / gallopLeft
        // and copy them in one go,
        // until both count1 and count2 fall below MIN_GALLOP
        do {
          assert len1 > 1 && len2 > 0;
          count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c);
          if (count1 != 0) {
            s.copyRange(tmp, cursor1, a, dest, count1);
            dest += count1;
            cursor1 += count1;
            len1 -= count1;
            if (len1 <= 1) // len1 == 1 || len1 == 0
              break outer;
          }
          s.copyElement(a, cursor2++, a, dest++);
          if (--len2 == 0)
            break outer;

          count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c);
          if (count2 != 0) {
            s.copyRange(a, cursor2, a, dest, count2);
            dest += count2;
            cursor2 += count2;
            len2 -= count2;
            if (len2 == 0)
              break outer;
          }
          s.copyElement(tmp, cursor1++, a, dest++);
          if (--len1 == 1)
            break outer;
          minGallop--;
        } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
        // Adjust minGallop: penalize leaving galloping mode
        if (minGallop < 0)
          minGallop = 0;
        minGallop += 2;
      }
      // Reached via "break outer"
      this.minGallop = minGallop < 1 ? 1 : minGallop;

      // Copy the remaining tail into the final result
      if (len1 == 1) {
        assert len2 > 0;
        s.copyRange(a, cursor2, a, dest, len2);
        s.copyElement(tmp, cursor1, a, dest + len2); 
      } else if (len1 == 0) {
        throw new IllegalArgumentException(
            "Comparison method violates its general contract!");
      } else {
        assert len2 == 0;
        assert len1 > 1;
        s.copyRange(tmp, cursor1, a, dest, len1);
      }
    }

mergeHi is similar to the code above, so it is not explained here.
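Without the galloping mode and the degenerate-case shortcuts, the core of mergeLo is a forward merge that buffers only the first run. Here is a simplified sketch on an int[]; the class name MergeLoDemo is made up and this is not Spark's code.

```java
import java.util.Arrays;

class MergeLoDemo {
    // Merges the adjacent sorted runs a[base1, base1 + len1) and
    // a[base2, base2 + len2) in place, buffering only run1.
    static void mergeLo(int[] a, int base1, int len1, int base2, int len2) {
        int[] tmp = Arrays.copyOfRange(a, base1, base1 + len1);
        int cursor1 = 0;       // index into tmp (run1)
        int cursor2 = base2;   // index into a (run2)
        int dest = base1;      // index into a where the result is written
        int end2 = base2 + len2;

        // dest never overtakes cursor2, so unread run2 elements are not overwritten.
        while (cursor1 < len1 && cursor2 < end2) {
            if (a[cursor2] < tmp[cursor1]) {
                a[dest++] = a[cursor2++];
            } else {
                a[dest++] = tmp[cursor1++]; // ties go to run1, which keeps the merge stable
            }
        }
        // Copy what is left of run1. Leftover run2 elements are already in place.
        while (cursor1 < len1) {
            a[dest++] = tmp[cursor1++];
        }
    }

    public static void main(String[] args) {
        int[] a = {1, 4, 7, 9, 2, 3, 8, 10, 11};
        mergeLo(a, 0, 4, 4, 5);
        System.out.println(Arrays.toString(a)); // [1, 2, 3, 4, 7, 8, 9, 10, 11]
    }
}
```

Buffering the shorter run is what keeps the extra memory at min(len1, len2) elements; mergeHi does the mirror image from the right when run2 is the shorter one.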

SortState.mergeForceCollapse

    private void mergeForceCollapse() {
      // Merge all remaining runs
      while (stackSize > 1) {
        int n = stackSize - 2;
        // If the third run from the top is shorter than the run on top,
        // merge the second and third runs first
        if (n > 0 && runLen[n - 1] < runLen[n + 1])
          n--;
        mergeAt(n);
      }
    }
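The order in which the final merges happen can be traced with a simulation on run lengths only. This is a sketch with made-up names (ForceCollapseDemo, trace), where merging two runs is modeled by adding their lengths.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class ForceCollapseDemo {
    // Replays mergeForceCollapse on a stack of run lengths and returns the
    // state of the stack after each merge.
    static List<List<Integer>> trace(List<Integer> input) {
        List<Integer> r = new ArrayList<>(input);
        List<List<Integer>> states = new ArrayList<>();
        while (r.size() > 1) {
            int n = r.size() - 2;
            // Prefer merging the second and third runs if the third is shorter than the top.
            if (n > 0 && r.get(n - 1) < r.get(n + 1)) n--;
            r.set(n, r.get(n) + r.get(n + 1));
            r.remove(n + 1);
            states.add(new ArrayList<>(r));
        }
        return states;
    }

    public static void main(String[] args) {
        // prints [[10, 50, 50], [60, 50], [110]]
        System.out.println(trace(Arrays.asList(10, 50, 20, 30)));
    }
}
```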

Summary

Compared with MergeSort, Spark's TimSort makes the following main improvements:

  • Merge unit: MergeSort starts from units of length 1 and only gets larger units by merging. TimSort's merge unit is the run, a segment of the input that is already ascending (or strictly descending, in which case it is reversed).
  • Insertion sort: If a run is too short, TimSort extends it with binary insertion sort, which has a few small optimizations of its own, instead of merging.
  • Merge timing: MergeSort's merge schedule is fixed, while TimSort merges when (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1]) holds, or when runLen[n] <= runLen[n+1]. In addition, if the third run from the top of the stack is shorter than the run on top, the second and third runs are merged first.
  • Trimming before a merge: The head of run1 and the tail of run2 may contain elements that are already in place and do not need to be merged. For example, TimSort gallops through run1 until run1[lastOfs] <= key < run1[ofs], where ofs = 2 * lastOfs + 1, and then runs a binary search on that segment to find the position in run1 where merging starts.
  • Merge optimization: There are small optimizations for runs of length 1, and merging switches between copying one element at a time and copying whole segments (galloping).


Posted on Fri, 06 Sep 2019 23:04:59 -0700 by HokieTracks