diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir
@@ -9,10 +9,17 @@
 // Each sparse vector is represented by an index memref (A or C) and by a data
 // memref (B or D), containing M or N elements.
 //
-// There are two implementations:
+// There are four different implementations:
 // * `memref_dot_simple`: Simple O(N*M) implementation with two for loops.
 // * `memref_dot_optimized`: An optimized O(N*M) version of the previous
 //   implementation, where the second for loop skips over some elements.
+// * `memref_dot_while`: An optimized O(N+M) implementation that utilizes a
+//   single while loop, coiterating over both vectors.
+// * `memref_dot_while_branchless`: An optimized O(N+M) implementation that
+//   consists of a single while loop and has no branches within the loop.
+//
+// Output of llvm-mca:
+// https://gist.github.com/matthias-springer/72e7ee1b3c467e7aefb6e1fd862e4841
 
 #contraction_accesses = [
   affine_map<(i) -> (i)>,
@@ -224,6 +231,166 @@
   return %r0 : f64
 }
 
+// Vector dot product with a while loop. Implemented as follows:
+//
+// r = 0.0, a = 0, b = 0
+// while (a < M && b < N) {
+//   segA = A[a:a+8], segB = B[b:b+8]
+//   if (segB[7] < segA[0]) b += 8
+//   elif (segA[7] < segB[0]) a += 8
+//   else {
+//     r += vector_dot(...)
+//     if (segA[7] < segB[7]) a += 8
+//     elif (segB[7] < segA[7]) b += 8
+//     else a += 8, b += 8
+//   }
+// }
+func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
+                       %m_C : memref<?xi64>, %m_D : memref<?xf64>,
+                       %M : index, %N : index)
+    -> f64 {
+  // Helper constants for loops.
+  %c0 = constant 0 : index
+  %i0 = constant 0 : i32
+  %i7 = constant 7 : i32
+  %c8 = constant 8 : index
+
+  %data_zero = constant 0.0 : f64
+  %index_padding = constant 9223372036854775807 : i64
+
+  %r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
+      : (f64, index, index) -> (f64, index, index) {
+    %cond_i = cmpi "slt", %a1, %M : index
+    %cond_j = cmpi "slt", %b1, %N : index
+    %cond = and %cond_i, %cond_j : i1
+    scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
+  } do {
+  ^bb0(%r1 : f64, %a1 : index, %b1 : index):
+    // v_A, v_B, seg*_* could be part of the loop state to avoid a few
+    // redundant reads.
+    %v_A = vector.transfer_read %m_A[%a1], %index_padding
+        : memref<?xi64>, vector<8xi64>
+    %v_C = vector.transfer_read %m_C[%b1], %index_padding
+        : memref<?xi64>, vector<8xi64>
+
+    %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>
+    %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
+    %segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64>
+    %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
+
+    %seg1_done = cmpi "slt", %segB_max, %segA_min : i64
+    %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) {
+      %b3 = addi %b1, %c8 : index
+      scf.yield %r1, %a1, %b3 : f64, index, index
+    } else {
+      %seg0_done = cmpi "slt", %segA_max, %segB_min : i64
+      %r4, %a4, %b4 = scf.if %seg0_done -> (f64, index, index) {
+        %a5 = addi %a1, %c8 : index
+        scf.yield %r1, %a5, %b1 : f64, index, index
+      } else {
+        %v_B = vector.transfer_read %m_B[%a1], %data_zero
+            : memref<?xf64>, vector<8xf64>
+        %v_D = vector.transfer_read %m_D[%b1], %data_zero
+            : memref<?xf64>, vector<8xf64>
+
+        %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
+            : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
+                -> f64
+        %r6 = addf %r1, %subresult : f64
+
+        %incr_a = cmpi "slt", %segA_max, %segB_max : i64
+        %a6, %b6 = scf.if %incr_a -> (index, index) {
+          %a7 = addi %a1, %c8 : index
+          scf.yield %a7, %b1 : index, index
+        } else {
+          %incr_b = cmpi "slt", %segB_max, %segA_max : i64
+          %a8, %b8 = scf.if %incr_b -> (index, index) {
+            %b9 = addi %b1, %c8 : index
+            scf.yield %a1, %b9 : index, index
+          } else {
+            %a10 = addi %a1, %c8 : index
+            %b10 = addi %b1, %c8 : index
+            scf.yield %a10, %b10 : index, index
+          }
+          scf.yield %a8, %b8 : index, index
+        }
+        scf.yield %r6, %a6, %b6 : f64, index, index
+      }
+      scf.yield %r4, %a4, %b4 : f64, index, index
+    }
+    scf.yield %r2, %a2, %b2 : f64, index, index
+  }
+
+  return %r0 : f64
+}
+
+// Vector dot product with a while loop that has no branches (apart from the
+// while loop itself). Implemented as follows:
+//
+// r = 0.0, a = 0, b = 0
+// while (a < M && b < N) {
+//   segA = A[a:a+8], segB = B[b:b+8]
+//   r += vector_dot(...)
+//   a += (segA[7] <= segB[7]) * 8
+//   b += (segB[7] <= segA[7]) * 8
+// }
+func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
+                                  %m_C : memref<?xi64>, %m_D : memref<?xf64>,
+                                  %M : index, %N : index)
+    -> f64 {
+  // Helper constants for loops.
+  %c0 = constant 0 : index
+  %i7 = constant 7 : i32
+  %c8 = constant 8 : index
+
+  %data_zero = constant 0.0 : f64
+  %index_padding = constant 9223372036854775807 : i64
+
+  %r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
+      : (f64, index, index) -> (f64, index, index) {
+    %cond_i = cmpi "slt", %a1, %M : index
+    %cond_j = cmpi "slt", %b1, %N : index
+    %cond = and %cond_i, %cond_j : i1
+    scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
+  } do {
+  ^bb0(%r1 : f64, %a1 : index, %b1 : index):
+    // v_A, v_B, seg*_* could be part of the loop state to avoid a few
+    // redundant reads.
+    %v_A = vector.transfer_read %m_A[%a1], %index_padding
+        : memref<?xi64>, vector<8xi64>
+    %v_B = vector.transfer_read %m_B[%a1], %data_zero
+        : memref<?xf64>, vector<8xf64>
+    %v_C = vector.transfer_read %m_C[%b1], %index_padding
+        : memref<?xi64>, vector<8xi64>
+    %v_D = vector.transfer_read %m_D[%b1], %data_zero
+        : memref<?xf64>, vector<8xf64>
+
+    %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
+        : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
+            -> f64
+    %r2 = addf %r1, %subresult : f64
+
+    %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
+    %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
+
+    %cond_a = cmpi "sle", %segA_max, %segB_max : i64
+    %cond_a_i64 = zexti %cond_a : i1 to i64
+    %cond_a_idx = index_cast %cond_a_i64 : i64 to index
+    %incr_a = muli %cond_a_idx, %c8 : index
+    %a2 = addi %a1, %incr_a : index
+
+    %cond_b = cmpi "sle", %segB_max, %segA_max : i64
+    %cond_b_i64 = zexti %cond_b : i1 to i64
+    %cond_b_idx = index_cast %cond_b_i64 : i64 to index
+    %incr_b = muli %cond_b_idx, %c8 : index
+    %b2 = addi %b1, %incr_b : index
+
+    scf.yield %r2, %a2, %b2 : f64, index, index
+  }
+
+  return %r0 : f64
+}
+
 func @entry() -> i32 {
   // Initialize large buffers that can be used for multiple test cases of
   // different sizes.
@@ -256,6 +423,18 @@
   vector.print %r1 : f64
   // CHECK: 86
 
+  %r2 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
+      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+         index, index) -> f64
+  vector.print %r2 : f64
+  // CHECK: 86
+
+  %r6 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
+      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+         index, index) -> f64
+  vector.print %r6 : f64
+  // CHECK: 86
+
   // --- Test case 2 ---.
   // M and N must be a multiple of 8 if smaller than 128.
   // (Because padding kicks in only for out-of-bounds accesses.)
@@ -275,6 +454,18 @@
   vector.print %r4 : f64
   // CHECK: 111
 
+  %r5 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
+      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+         index, index) -> f64
+  vector.print %r5 : f64
+  // CHECK: 111
+
+  %r7 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
+      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+         index, index) -> f64
+  vector.print %r7 : f64
+  // CHECK: 111
+
   // Release all resources.
   dealloc %b_A : memref<128xi64>
   dealloc %b_B : memref<128xf64>
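
Note (not part of the patch): the segment-based co-iteration that `memref_dot_while` vectorizes can be restated as a minimal scalar C sketch of the pseudocode comment above. The function name `sparse_dot_while_ref` is invented for this note, and it assumes the buffers behind A/B and C/D are padded out past M and N to a multiple of 8 with an INT64_MAX sentinel index and 0.0 data, mirroring the `%index_padding` and `%data_zero` fallback values that the `vector.transfer_read` ops use for out-of-bounds lanes.

#include <stdint.h>

// Illustrative scalar sketch only. A and C hold sorted indices, padded past
// M and N to a multiple of 8 with an INT64_MAX sentinel; B and D hold the
// matching data values, padded with 0.0.
double sparse_dot_while_ref(const int64_t *A, const double *B, int64_t M,
                            const int64_t *C, const double *D, int64_t N) {
  double r = 0.0;
  int64_t a = 0, b = 0;
  while (a < M && b < N) {
    // Current 8-element segments: indices A[a..a+7] and C[b..b+7].
    int64_t segA_min = A[a], segA_max = A[a + 7];
    int64_t segB_min = C[b], segB_max = C[b + 7];
    if (segB_max < segA_min) {
      b += 8;                       // Segment of C lies entirely before A's.
    } else if (segA_max < segB_min) {
      a += 8;                       // Segment of A lies entirely before C's.
    } else {
      // Overlapping index ranges: accumulate products of entries with equal
      // indices (this mirrors what the @vector_dot helper computes).
      for (int i = 0; i < 8; ++i)
        for (int j = 0; j < 8; ++j)
          if (A[a + i] == C[b + j])
            r += B[a + i] * D[b + j];
      // Advance whichever segment ends first; advance both on a tie.
      if (segA_max < segB_max)
        a += 8;
      else if (segB_max < segA_max)
        b += 8;
      else {
        a += 8;
        b += 8;
      }
    }
  }
  return r;
}

The branchless variant reaches the same sum by always accumulating the segment dot product (non-matching and sentinel indices contribute 0) and deriving both increments from the segA_max/segB_max comparison, i.e. a += (segA_max <= segB_max) * 8 and b += (segB_max <= segA_max) * 8, which is what the zexti/index_cast/muli sequence in `memref_dot_while_branchless` computes.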