diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -128,6 +128,19 @@ }; return *this; } + + /// Function returning the traversal order (in "for loop order", i.e. slowest + /// to fastest varying dimension) used when unrolling the given operation + /// into native-size vectors. It must return a permutation of the unrolled + /// dimensions; `None` or a null callback selects the identity order. + using UnrollTraversalOrderFnType = + std::function>(Operation *op)>; + UnrollTraversalOrderFnType traversalOrderCallback = nullptr; + UnrollVectorOptions & + setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) { + traversalOrderCallback = std::move(traversalOrderFn); + return *this; + } }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp @@ -15,8 +15,11 @@ #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/VectorInterfaces.h" +#include "mlir/Support/MathExtras.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" +#include #define DEBUG_TYPE "vector-unrolling" @@ -36,20 +39,78 @@ return elementOffsets; } +/// A helper that computes the same element offsets as `getVectorOffset` but +/// allows reordering the traversal of the dimensions. The order of traversal +/// is given in "for loop order" (outer to inner).
+namespace { +class DecomposeShapeIterator { +private: + SmallVector vectorShape; + SmallVector loopOrder; + SmallVector sliceStrides; + int64_t maxIndexVal{1}; + +public: + /// `loopOrder` must be a permutation of the dimensions [0, rank) of + /// `originalShape`; anything else would index `sliceStrides` out of bounds + /// or leave some slices unvisited, so it is asserted below. + DecomposeShapeIterator(ArrayRef originalShape, + ArrayRef targetShape, + ArrayRef loopOrder) + : vectorShape(targetShape.begin(), targetShape.end()), + loopOrder(loopOrder.begin(), loopOrder.end()), + sliceStrides(originalShape.size()) { + assert(loopOrder.size() == originalShape.size() && + "loop order must list every unrolled dimension"); + // Compute the count for each dimension. + SmallVector sliceDimCounts(originalShape.size()); + for (unsigned r = 0; r < originalShape.size(); ++r) { + sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]); + maxIndexVal *= sliceDimCounts[r]; + } + + // Reversing "loop order" gives dimensions from fastest varying to slowest + // varying (smallest stride to largest stride). + int64_t accum = 1; + for (auto idx : llvm::reverse(loopOrder)) { + // Strides start at zero and each assigned stride is a product of slice + // counts (>= 1 for non-empty shapes), so a non-zero entry means `idx` + // appears twice in the loop order. + assert(idx >= 0 && idx < static_cast<int64_t>(originalShape.size()) && + sliceStrides[idx] == 0 && + "loop order must be a permutation of the unrolled dimensions"); + sliceStrides[idx] = accum; + accum *= sliceDimCounts[idx]; + } + } + + /// Turn the linear index into a d-tuple based on units of vectors of size + /// `vectorShape`. The linear index is assumed to represent traversal of the + /// dimensions based on `loopOrder`. + SmallVector delinearize(int64_t index) const { + assert(index >= 0 && index < maxIndexVal && "slice index out of range"); + // Traverse in for loop order (largest stride to smallest stride). + SmallVector vectorOffsets(sliceStrides.size()); + for (auto idx : loopOrder) { + vectorOffsets[idx] = index / sliceStrides[idx]; + index %= sliceStrides[idx]; + } + return vectorOffsets; + } + + /// Number of slices, i.e. the exclusive upper bound of valid linear indices. + int64_t maxIndex() const { return maxIndexVal; } + + /// Return the element offsets of the slice with linear index `index`, where + /// slices are enumerated in the order given by `loopOrder`. + SmallVector getVectorOffset(int64_t index) const { + SmallVector vectorOffsets = delinearize(index); + SmallVector elementOffsets = + computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets); + return elementOffsets; + } +}; +} // namespace + -/// Compute the indices of the slice `index` for a tranfer op. +/// Compute the indices of the slice at `elementOffsets` for a transfer op.
-static SmallVector -sliceTransferIndices(int64_t index, ArrayRef originalShape, - ArrayRef targetShape, ArrayRef indices, - AffineMap permutationMap, Location loc, - OpBuilder &builder) { +static SmallVector sliceTransferIndices(ArrayRef elementOffsets, + ArrayRef indices, + AffineMap permutationMap, + Location loc, + OpBuilder &builder) { MLIRContext *ctx = builder.getContext(); auto isBroadcast = [](AffineExpr expr) { if (auto constExpr = expr.dyn_cast()) return constExpr.getValue() == 0; return false; }; - SmallVector elementOffsets = - getVectorOffset(originalShape, targetShape, index); // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. SmallVector slicedIndices(indices.begin(), indices.end()); for (const auto &dim : llvm::enumerate(permutationMap.getResults())) { @@ -99,6 +160,20 @@ return targetShape; } +/// Returns the order in which the `numLoops` unrolled dimensions of `op` are +/// traversed ("for loop order", outermost first): the order produced by +/// `options.traversalOrderCallback` if it yields one, otherwise the identity +/// order [0, numLoops). +static SmallVector +getUnrollOrder(unsigned numLoops, Operation *op, + const vector::UnrollVectorOptions &options) { + SmallVector loopOrder = + llvm::to_vector(llvm::seq(0, static_cast(numLoops))); + if (options.traversalOrderCallback != nullptr) { + Optional> order = options.traversalOrderCallback(op); + if (order.hasValue()) { + // Reject a callback order that does not match this op's unroll rank + // here, instead of producing wrong or out-of-bounds slice offsets later. + assert(order->size() == numLoops && + "traversal order must list every unrolled dimension"); + loopOrder = std::move(*order); + } + } + return loopOrder; +} + namespace { struct UnrollTransferReadPattern @@ -122,8 +197,7 @@ Location loc = readOp.getLoc(); ArrayRef originalSize = readOp.getVectorType().getShape(); SmallVector ratio = *shapeRatio(originalSize, *targetShape); - // Compute shape ratio of 'shape' and 'sizes'.
- int64_t sliceCount = computeMaxLinearIndex(ratio); + // Prepare the result vector. Value result = rewriter.create( loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); @@ -131,17 +205,22 @@ VectorType::get(*targetShape, sourceVectorType.getElementType()); SmallVector originalIndices(readOp.getIndices().begin(), readOp.getIndices().end()); - for (int64_t i = 0; i < sliceCount; i++) { + + // Visit the slices in the configured traversal order; a slice's element + // offsets give both where it is read from and where it is inserted. + SmallVector loopOrder = + getUnrollOrder(ratio.size(), readOp, options); + DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, + loopOrder); + for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) { + SmallVector elementOffsets = + indexToOffsets.getVectorOffset(i); SmallVector indices = - sliceTransferIndices(i, originalSize, *targetShape, originalIndices, + sliceTransferIndices(elementOffsets, originalIndices, readOp.getPermutationMap(), loc, rewriter); auto slicedRead = rewriter.create( loc, targetType, readOp.getSource(), indices, readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(), readOp.getInBoundsAttr()); - SmallVector elementOffsets = - getVectorOffset(originalSize, *targetShape, i); result = rewriter.create( loc, slicedRead, result, elementOffsets, strides); } @@ -174,20 +253,21 @@ SmallVector strides(targetShape->size(), 1); Location loc = writeOp.getLoc(); ArrayRef originalSize = sourceVectorType.getShape(); - SmallVector ratio = *shapeRatio(originalSize, *targetShape); - // Compute shape ratio of 'shape' and 'sizes'.
- int64_t sliceCount = computeMaxLinearIndex(ratio); SmallVector originalIndices(writeOp.getIndices().begin(), writeOp.getIndices().end()); + + // Unroll over the dimensions of the written vector. `originalIndices` + // indexes the (possibly higher-rank) source, so its size is not the + // number of unrolled dimensions. + SmallVector loopOrder = + getUnrollOrder(originalSize.size(), writeOp, options); + DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, + loopOrder); Value resultTensor; - for (int64_t i = 0; i < sliceCount; i++) { + for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) { SmallVector elementOffsets = - getVectorOffset(originalSize, *targetShape, i); + indexToOffsets.getVectorOffset(i); Value slicedVector = rewriter.create( loc, writeOp.getVector(), elementOffsets, *targetShape, strides); - SmallVector indices = - sliceTransferIndices(i, originalSize, *targetShape, originalIndices, + sliceTransferIndices(elementOffsets, originalIndices, writeOp.getPermutationMap(), loc, rewriter); Operation *slicedWrite = rewriter.create( loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(), @@ -238,8 +318,6 @@ SmallVector originalSize = *contractOp.getShapeForUnroll(); SmallVector ratio = *shapeRatio(originalSize, *targetShape); - // Compute shape ratio of 'shape' and 'sizes'.
- int64_t sliceCount = computeMaxLinearIndex(ratio); Location loc = contractOp.getLoc(); unsigned accIndex = vector::ContractionOp::getAccOperandIndex(); AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex]; @@ -247,9 +325,14 @@ SmallVector, Value, llvm::DenseMap, unsigned, OffsetMapInfo>> accCache; + + // The loop order ranges over the iteration space of the contraction + // (`originalSize`), not over its indexing maps: there is one map per + // operand (lhs, rhs, acc) regardless of how many loops the contraction has. + SmallVector loopOrder = + getUnrollOrder(originalSize.size(), contractOp, options); + DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, + loopOrder); + const int64_t sliceCount = indexToOffsets.maxIndex(); for (int64_t i = 0; i < sliceCount; i++) { - SmallVector offsets = - getVectorOffset(originalSize, *targetShape, i); + SmallVector offsets = indexToOffsets.getVectorOffset(i); SmallVector slicesOperands(contractOp.getNumOperands()); // Helper to coompute the new shape of each operand and extract the slice. diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns=reverse-unroll-order --split-input-file | FileCheck %s --check-prefix=ORDER // CHECK-LABEL: func @transfer_read_unroll // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index @@ -13,6 +14,19 @@ // CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> // CHECK-NEXT: return %[[VEC3]] : vector<4x4xf32> +// ORDER-LABEL: func @transfer_read_unroll +// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index +// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index +// ORDER: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// ORDER-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice
%[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// ORDER-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// ORDER-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// ORDER-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// ORDER-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// ORDER-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// ORDER-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// ORDER-NEXT: return %[[VEC3]] : vector<4x4xf32> + func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> { %c0 = arith.constant 0 : index %cf0 = arith.constant 0.0 : f32 @@ -33,6 +47,19 @@ // CHECK-NEXT: vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> // CHECK-NEXT: return +// ORDER-LABEL: func @transfer_write_unroll +// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index +// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index +// ORDER: %[[S0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// ORDER-NEXT: vector.transfer_write %[[S0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// ORDER-NEXT: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// ORDER-NEXT: vector.transfer_write %[[S1]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// ORDER-NEXT: %[[S2:.*]] =
vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// ORDER-NEXT: vector.transfer_write %[[S2]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// ORDER-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> +// ORDER-NEXT: vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// ORDER-NEXT: return + func.func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) { %c0 = arith.constant 0 : index vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32> @@ -222,6 +249,25 @@ // CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref, vector<2x2xf32> // CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> // CHECK-NEXT: return %[[VEC5]] : vector<6x4xf32> + +// reverse-unroll-order makes vector dimension 0 the innermost unroll loop, so +// consecutive slices advance along it (memref dimension 2 under #map0). +// ORDER-LABEL: func @transfer_read_unroll_different_rank +// ORDER-DAG: %[[C4:.*]] = arith.constant 4 : index +// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index +// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index +// ORDER: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref, vector<2x2xf32> +// ORDER-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref, vector<2x2xf32> +// ORDER-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref, vector<2x2xf32> +// ORDER-NEXT: %[[VEC2:.*]] =
vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref, vector<2x2xf32> +// ORDER-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref, vector<2x2xf32> +// ORDER-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref, vector<2x2xf32> +// ORDER-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> +// ORDER-NEXT: return %[[VEC5]] : vector<6x4xf32> + #map0 = affine_map<(d0, d1, d2) -> (d2, d0)> func.func @transfer_read_unroll_different_rank(%arg0 : memref) -> vector<6x4xf32> { %c0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -1,50 +1,156 @@ // RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s +// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=2,0,1" --split-input-file | FileCheck %s --check-prefix=ORDER -func.func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>, +func.func @vector_contract_f32(%lhs : vector<8x4xf32>, %rhs : vector<8x4xf32>, %init : vector<8x8xf32>) -> vector<8x8xf32> { %0 = vector.contract {indexing_maps = [affine_map<(i, j, k) -> (i, k)>, affine_map<(i, j, k) -> (j, k)>, affine_map<(i, j, k) -> (i,
j)>], iterator_types = ["parallel", "parallel", "reduction"]} - %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32> + %lhs, %rhs, %init : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32> return %0 : vector<8x8xf32> } // CHECK-LABEL: func @vector_contract_f32 -// CHECK: vector.contract { -// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { -// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { -// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { -// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { -// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { -// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { -// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { -// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { +// CHECK-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32> + +// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// CHECK-SAME: offsets = [0, 0] +// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// CHECK-SAME: offsets = [0, 0] +// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]] +// CHECK-SAME: offsets = [0, 0] +// CHECK: [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]] // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { + +// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// CHECK-SAME: offsets = [0, 2] +// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// CHECK-SAME: offsets = [0, 2] +// CHECK: [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]] // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into
vector<4x4xf32> -// CHECK: vector.contract { + +// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// CHECK-SAME: offsets = [0, 0] +// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// CHECK-SAME: offsets = [4, 0] +// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]] +// CHECK-SAME: offsets = [0, 4] +// CHECK: [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]] // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { + +// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// CHECK-SAME: offsets = [0, 2] +// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// CHECK-SAME: offsets = [4, 2] +// CHECK: [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]] // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { + +// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// CHECK-SAME: offsets = [4, 0] +// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// CHECK-SAME: offsets = [0, 0] +// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]] +// CHECK-SAME: offsets = [4, 0] +// CHECK: [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]] // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { + +// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// CHECK-SAME: offsets = [4, 2] +// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// CHECK-SAME: offsets = [0, 2] +// CHECK: [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum5]] // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { + +// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// CHECK-SAME: offsets = [4, 0] +// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// CHECK-SAME: offsets = [4, 0] +// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]] +// CHECK-SAME: offsets = [4, 4] +// CHECK: [[accum7:%.+]] =
vector.contract {{{.*}}} [[a]], [[b]], [[c]] // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> -// CHECK: vector.contract { + +// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// CHECK-SAME: offsets = [4, 2] +// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// CHECK-SAME: offsets = [4, 2] +// CHECK: [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum7]] // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> + // CHECK: return +// With unroll-order=2,0,1 the reduction dimension k is the outermost unroll +// loop: every k=0 slice is contracted against a fresh accumulator slice, +// then the k=1 slices accumulate onto those partial results. +// ORDER-LABEL: func @vector_contract_f32 +// ORDER-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32> + +// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// ORDER-SAME: offsets = [0, 0] +// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// ORDER-SAME: offsets = [0, 0] +// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]] +// ORDER-SAME: offsets = [0, 0] +// ORDER: [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]] +// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> + +// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// ORDER-SAME: offsets = [0, 0] +// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// ORDER-SAME: offsets = [4, 0] +// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]] +// ORDER-SAME: offsets = [0, 4] +// ORDER: [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]] +// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> + +// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// ORDER-SAME: offsets = [4, 0] +// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// ORDER-SAME: offsets = [0, 0] +// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]] +// ORDER-SAME: offsets = [4, 0] +// ORDER: [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]] +// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> + +// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +//
ORDER-SAME: offsets = [4, 0] +// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// ORDER-SAME: offsets = [4, 0] +// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]] +// ORDER-SAME: offsets = [4, 4] +// ORDER: [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]] +// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> + +// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// ORDER-SAME: offsets = [0, 2] +// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// ORDER-SAME: offsets = [0, 2] +// ORDER: [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]] +// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> + +// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// ORDER-SAME: offsets = [0, 2] +// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// ORDER-SAME: offsets = [4, 2] +// ORDER: [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum2]] +// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> + +// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// ORDER-SAME: offsets = [4, 2] +// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// ORDER-SAME: offsets = [0, 2] +// ORDER: [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]] +// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> + +// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]] +// ORDER-SAME: offsets = [4, 2] +// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]] +// ORDER-SAME: offsets = [4, 2] +// ORDER: [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum4]] +// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> + +// ORDER: return + + + func.func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>, %init : vector<8x8xf16>) -> vector<8x8xf16> { %0 = vector.contract @@ -158,3 +264,4 @@ // CHECK: %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to
vector<1x3x4x2xf32> // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> // CHECK: return %[[V7]] : vector<2x3x8x4xf32> + diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Pass/Pass.h" @@ -322,12 +323,18 @@ } return nativeShape; }; - populateVectorUnrollPatterns(patterns, - UnrollVectorOptions() - .setNativeShapeFn(nativeShapeFn) - .setFilterConstraint([](Operation *op) { - return success(isa(op)); - })); + + UnrollVectorOptions opts; + opts.setNativeShapeFn(nativeShapeFn) + .setFilterConstraint( + [](Operation *op) { return success(isa(op)); }); + // With unroll-order set, the same fixed order is returned for every op the + // filter accepts; it is expected to have one entry per unrolled dimension. + // Capturing `this` is fine as long as the patterns do not outlive the + // pass; they are moved into applyPatternsAndFoldGreedily below. + if (!unrollOrder.empty()) { + opts.setUnrollTraversalOrderFn([this](Operation *op) + -> Optional> { + return SmallVector{unrollOrder.begin(), unrollOrder.end()}; + }); + } + populateVectorUnrollPatterns(patterns, opts); } else { populateVectorUnrollPatterns( patterns, UnrollVectorOptions() @@ -340,6 +347,10 @@ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } + // Traversal order in "for loop order" (outermost loop first), e.g. 2,0,1. + ListOption unrollOrder{*this, "unroll-order", + llvm::cl::desc("set the unroll order"), + llvm::cl::ZeroOrMore}; + Option unrollBasedOnType{ *this, "unroll-based-on-type", llvm::cl::desc("Set the unroll factor based on type of the operation"), @@ -472,6 +483,11 @@ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestVectorTransferUnrollingPatterns) + TestVectorTransferUnrollingPatterns() = default; + // Copy constructor spelled out, presumably because the pass option member + // is not copyable; that member is re-created from its default initializer. +
TestVectorTransferUnrollingPatterns( + const TestVectorTransferUnrollingPatterns &pass) + : PassWrapper(pass) {} + void getDependentDialects(DialectRegistry &registry) const override { registry.insert(); } @@ -485,17 +501,36 @@ void runOnOperation() override { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); - populateVectorUnrollPatterns( - patterns, - UnrollVectorOptions() - .setNativeShape(ArrayRef{2, 2}) - .setFilterConstraint([](Operation *op) { - return success( - isa(op)); - })); + UnrollVectorOptions opts; + opts.setNativeShape(ArrayRef{2, 2}) + .setFilterConstraint([](Operation *op) { + return success( + isa(op)); + }); + if (reverseUnrollOrder.getValue()) { + // Reverse the identity order: the last vector dimension becomes the + // outermost unroll loop and dimension 0 the innermost one. + opts.setUnrollTraversalOrderFn( + [](Operation *op) -> Optional> { + int64_t numLoops = 0; + if (auto readOp = dyn_cast(op)) + numLoops = readOp.getVectorType().getRank(); + else if (auto writeOp = dyn_cast(op)) + numLoops = writeOp.getVectorType().getRank(); + else + return None; + auto order = llvm::reverse(llvm::seq(0, numLoops)); + return llvm::to_vector(order); + }); + } + populateVectorUnrollPatterns(patterns, opts); populateVectorToVectorCanonicalizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } + + Option reverseUnrollOrder{ + *this, "reverse-unroll-order", + llvm::cl::desc( + "reverse the order of unrolling of vector transfer operations"), + llvm::cl::init(false)}; }; struct TestVectorTransferFullPartialSplitPatterns