diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -128,6 +128,19 @@
     };
     return *this;
   }
+
+  /// Function that returns the traversal order (in terms of "for loop order",
+  /// i.e. slowest varying dimension to fastest varying dimension) that should
+  /// be used when unrolling the given operation into units of the native vector
+  /// size.
+  using UnrollTraversalOrderFnType =
+      std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
+  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
+  UnrollVectorOptions &
+  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
+    traversalOrderCallback = std::move(traversalOrderFn);
+    return *this;
+  }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -15,8 +15,11 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
+#include <numeric>
 
 #define DEBUG_TYPE "vector-unrolling"
 
@@ -36,6 +39,65 @@
   return elementOffsets;
 }
 
+/// A functor that accomplishes the same thing as `getVectorOffset` but allows
+/// for reordering the traversal of the dimensions. The order of traversal is
+/// given in "for loop order" (outer to inner).
+namespace {
+class DecomposeShapeIterator {
+private:
+  SmallVector<int64_t> vectorShape;
+  SmallVector<int64_t> loopOrder;
+  SmallVector<int64_t> sliceStrides;
+  int64_t maxIndexVal{1};
+
+public:
+  DecomposeShapeIterator(const SmallVector<int64_t> &originalShape,
+                         const SmallVector<int64_t> &targetShape,
+                         const SmallVector<int64_t> &loopOrder)
+      : vectorShape(targetShape), loopOrder(loopOrder),
+        sliceStrides(originalShape.size()) {
+    // Compute the count for each dimension.
+    SmallVector<int64_t> sliceDimCounts(originalShape.size());
+    for (unsigned r = 0; r < originalShape.size(); ++r) {
+      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
+      maxIndexVal *= sliceDimCounts[r];
+    }
+
+    // Reversing "loop order" gives dimensions from fastest varying to slowest
+    // varying (smallest stride to largest stride).
+    int64_t accum = 1;
+    for (auto idx : llvm::reverse(loopOrder)) {
+      sliceStrides[idx] = accum;
+      accum *= sliceDimCounts[idx];
+    }
+  }
+
+  // Turn the linear index into a d-tuple based on units of vectors of size
+  // `vectorShape`. The linear index is assumed to represent traversal of the
+  // dimensions based on `order`.
+  SmallVector<int64_t> delinearize(int64_t index) const {
+    // Traverse in for loop order (largest stride to smallest stride).
+    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
+    for (auto idx : loopOrder) {
+      vectorOffsets[idx] = index / sliceStrides[idx];
+      index %= sliceStrides[idx];
+    }
+    return vectorOffsets;
+  }
+
+  int64_t maxIndex() const { return maxIndexVal; }
+
+  /// Return the offset within d-tuple based on the ordering given by
+  /// `loopOrder`.
+  SmallVector<int64_t> operator()(int64_t index) const {
+    SmallVector<int64_t> vectorOffsets = delinearize(index);
+    SmallVector<int64_t> elementOffsets =
+        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
+    return elementOffsets;
+  }
+};
+} // namespace
+
 /// Compute the indices of the slice `index` for a tranfer op.
 static SmallVector<Value> sliceTransferIndices(int64_t index,
                                                ArrayRef<int64_t> originalShape,
@@ -239,7 +301,6 @@
   SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
 
   // Compute shape ratio of 'shape' and 'sizes'.
-  int64_t sliceCount = computeMaxLinearIndex(ratio);
   Location loc = contractOp.getLoc();
   unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
   AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
@@ -247,9 +308,22 @@
       SmallVector<int64_t>, Value,
       llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
       accCache;
+
+  SmallVector<int64_t> loopOrder = llvm::to_vector(llvm::seq<int64_t>(
+      0, static_cast<int64_t>(contractOp.getIndexingMaps().size())));
+  if (options.traversalOrderCallback != nullptr) {
+    Optional<SmallVector<int64_t>> order =
+        options.traversalOrderCallback(contractOp);
+    if (order.hasValue()) {
+      loopOrder = std::move(*order);
+    }
+  }
+
+  DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+                                        loopOrder);
+  const int64_t sliceCount = indexToOffsets.maxIndex();
   for (int64_t i = 0; i < sliceCount; i++) {
-    SmallVector<int64_t> offsets =
-        getVectorOffset(originalSize, *targetShape, i);
+    SmallVector<int64_t> offsets = indexToOffsets(i);
     SmallVector<Value> slicesOperands(contractOp.getNumOperands());
 
     // Helper to compute the new shape of each operand and extract the slice.
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -1,50 +1,156 @@
 // RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
+// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=2,0,1" --split-input-file | FileCheck %s --check-prefix=ORDER
 
-func.func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
+func.func @vector_contract_f32(%lhs : vector<8x4xf32>, %rhs : vector<8x4xf32>,
                                %init : vector<8x8xf32>) -> vector<8x8xf32> {
   %0 = vector.contract {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                                          affine_map<(i, j, k) -> (j, k)>,
                                          affine_map<(i, j, k) -> (i, j)>],
                         iterator_types = ["parallel", "parallel", "reduction"]}
-       %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+       %lhs, %rhs, %init : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32>
   return %0 : vector<8x8xf32>
 }
 
 // CHECK-LABEL: func @vector_contract_f32
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
-// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+// CHECK-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [0, 0]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [0, 0]
+// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// CHECK-SAME: offsets = [0, 0]
+// CHECK: [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [0, 2]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [0, 2]
+// CHECK: [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [0, 0]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [4, 0]
+// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// CHECK-SAME: offsets = [0, 4]
+// CHECK: [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [0, 2]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [4, 2]
+// CHECK: [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [4, 0]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [0, 0]
+// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// CHECK-SAME: offsets = [4, 0]
+// CHECK: [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [4, 2]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [0, 2]
+// CHECK: [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum5]]
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [4, 0]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [4, 0]
+// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// CHECK-SAME: offsets = [4, 4]
+// CHECK: [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-// CHECK: vector.contract {
+
+// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// CHECK-SAME: offsets = [4, 2]
+// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// CHECK-SAME: offsets = [4, 2]
+// CHECK: [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum7]]
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
 // CHECK: return
 
+// ORDER-LABEL: func @vector_contract_f32
+// ORDER-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [0, 0]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [0, 0]
+// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// ORDER-SAME: offsets = [0, 0]
+// ORDER: [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [0, 0]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [4, 0]
+// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// ORDER-SAME: offsets = [0, 4]
+// ORDER: [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [4, 0]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [0, 0]
+// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// ORDER-SAME: offsets = [4, 0]
+// ORDER: [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [4, 0]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [4, 0]
+// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
+// ORDER-SAME: offsets = [4, 4]
+// ORDER: [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [0, 2]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [0, 2]
+// ORDER: [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [0, 2]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [4, 2]
+// ORDER: [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum2]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [4, 2]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [0, 2]
+// ORDER: [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
+// ORDER-SAME: offsets = [4, 2]
+// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
+// ORDER-SAME: offsets = [4, 2]
+// ORDER: [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum4]]
+// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+// ORDER: return
+
+
+
 func.func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
                                %init : vector<8x8xf16>) -> vector<8x8xf16> {
   %0 = vector.contract
@@ -158,3 +264,4 @@
 // CHECK: %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
 // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
 // CHECK: return %[[V7]] : vector<2x3x8x4xf32>
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -322,12 +322,18 @@
       }
       return nativeShape;
     };
-    populateVectorUnrollPatterns(patterns,
-                                 UnrollVectorOptions()
-                                     .setNativeShapeFn(nativeShapeFn)
-                                     .setFilterConstraint([](Operation *op) {
-                                       return success(isa<ContractionOp>(op));
-                                     }));
+
+    UnrollVectorOptions opts;
+    opts.setNativeShapeFn(nativeShapeFn)
+        .setFilterConstraint(
+            [](Operation *op) { return success(isa<ContractionOp>(op)); });
+    if (!unrollOrder.empty()) {
+      opts.setUnrollTraversalOrderFn([this](Operation *op)
+                                         -> Optional<SmallVector<int64_t>> {
+        return SmallVector<int64_t>{unrollOrder.begin(), unrollOrder.end()};
+      });
+    }
+    populateVectorUnrollPatterns(patterns, opts);
   } else {
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
@@ -340,6 +346,10 @@
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
+  ListOption<int64_t> unrollOrder{*this, "unroll-order",
+                                  llvm::cl::desc("set the unroll order"),
+                                  llvm::cl::ZeroOrMore};
+
   Option<bool> unrollBasedOnType{
       *this, "unroll-based-on-type",
       llvm::cl::desc("Set the unroll factor based on type of the operation"),
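
Usage note (editor's addition, not part of the patch): the sketch below shows how a client pass might combine the new setUnrollTraversalOrderFn hook with the existing setNativeShapeFn and setFilterConstraint options. The helper name addContractUnrollPatterns, the 4x4x2 native shape, and the include paths are illustrative assumptions; the {2, 0, 1} traversal order mirrors the unroll-order=2,0,1 RUN line in the test above.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"

using namespace mlir;
using namespace mlir::vector;

// Illustrative sketch only (assumed helper): unroll vector.contract ops to
// 4x4x2 (i, j, k) slices, visiting the reduction dimension k outermost by
// returning the loop order {2, 0, 1}.
static void addContractUnrollPatterns(RewritePatternSet &patterns) {
  UnrollVectorOptions options;
  options
      .setNativeShapeFn([](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
        // Native slice shape for every matched contraction.
        return SmallVector<int64_t, 4>{4, 4, 2};
      })
      .setFilterConstraint([](Operation *op) {
        // Only unroll vector.contract; other ops are left untouched.
        return success(isa<vector::ContractionOp>(op));
      })
      .setUnrollTraversalOrderFn(
          [](Operation *op) -> Optional<SmallVector<int64_t>> {
            // "For loop order": outermost to innermost unrolled dimension.
            return SmallVector<int64_t>{2, 0, 1};
          });
  populateVectorUnrollPatterns(patterns, options);
}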