diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -449,6 +449,17 @@
     /// Return the dimensions of the result vector that were formerly ones in the
     /// source tensor and thus correspond to "dim-1" broadcasting.
     llvm::SetVector<int64_t> computeBroadcastedUnitDims();
+
+    /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
+    /// `broadcastedDims` dimensions in the dstShape are broadcasted.
+    /// This requires (and asserts) that the broadcast is free of dim-1
+    /// broadcasting.
+    /// Since vector.broadcast only allows expanding leading dimensions, an extra
+    /// vector.transpose may be inserted to make the broadcast possible.
+    static Value createBroadcastOp(
+        OpBuilder &b, Value value,
+        ArrayRef<int64_t> dstShape,
+        llvm::SetVector<int64_t> &broadcastedDims);
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1699,15 +1699,9 @@
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
-/// Return the dimensions of the result vector that were formerly ones in the
-/// source tensor and thus correspond to "dim-1" broadcasting.
-llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
-  VectorType srcVectorType = getSourceType().dyn_cast<VectorType>();
-  // Scalar broadcast is without any unit dim broadcast.
-  if (!srcVectorType)
-    return {};
-  ArrayRef<int64_t> srcShape = srcVectorType.getShape();
-  ArrayRef<int64_t> dstShape = getVectorType().getShape();
+static llvm::SetVector<int64_t>
+computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
+                           ArrayRef<int64_t> dstShape) {
   int64_t rankDiff = dstShape.size() - srcShape.size();
   int64_t dstDim = rankDiff;
   llvm::SetVector<int64_t> res;
@@ -1721,6 +1715,114 @@
   return res;
 }
 
+/// Return the dimensions of the result vector that were formerly ones in the
+/// source tensor and thus correspond to "dim-1" broadcasting.
+llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
+  // Scalar broadcast is without any unit dim broadcast.
+  auto srcVectorType = getSourceType().dyn_cast<VectorType>();
+  if (!srcVectorType)
+    return {};
+  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
+                                      getVectorType().getShape());
+}
+
+static bool allBitsSet(llvm::SmallBitVector &bv, int64_t lb, int64_t ub) {
+  for (int64_t i = lb; i < ub; ++i)
+    if (!bv.test(i))
+      return false;
+  return true;
+}
+
+/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
+/// `broadcastedDims` dimensions in the dstShape are broadcasted.
+/// This requires (and asserts) that the broadcast is free of dim-1
+/// broadcasting.
+/// Since `vector.broadcast` only allows expanding leading dimensions, an extra
+/// `vector.transpose` may be inserted to make the broadcast possible.
+Value BroadcastOp::createBroadcastOp(
+    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
+    llvm::SetVector<int64_t> &broadcastedDims) {
+  // Step 1. If no dstShape to broadcast to, just return `value`.
+  if (dstShape.empty())
+    return value;
+
+  Location loc = value.getLoc();
+  Type elementType = getElementTypeOrSelf(value.getType());
+  VectorType srcVectorType = value.getType().dyn_cast<VectorType>();
+  VectorType dstVectorType = VectorType::get(dstShape, elementType);
+
+  // Step 2. If scalar -> dstShape broadcast, just do it.
+  if (!srcVectorType) {
+    // Well-formedness check: for a scalar broadcast every dst dimension is
+    // broadcasted, so broadcastedDims must be exactly {0, ..., size - 1}.
+    for (int64_t i = 0, e = broadcastedDims.size(); i < e; ++i)
+      assert(broadcastedDims.contains(i) && "broadcastedDims is ill-formed");
+    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
+  }
+
+  // Step 3. Since vector.broadcast only allows creating leading dims,
+  // vector -> dstShape broadcast may require a transpose.
+  // Traverse the dims in order and construct:
+  //   1. The leading entries of the broadcastShape that is guaranteed to be
+  //      achievable by a simple broadcast.
+  //   2. The induced permutation for the subsequent vector.transpose that will
+  //      bring us from `broadcastShape` back to the desired `dstShape`.
+  // If the induced permutation is not the identity, create a vector.transpose.
+  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
+  broadcastShape.reserve(dstShape.size());
+  // Consider the example:
+  //   srcShape        = 2x4
+  //   dstShape        = 1x2x3x4x5
+  //   broadcastedDims = [0, 2, 4]
+  //
+  // We want to build:
+  //   broadcastShape = 1x3x5x2x4 (leading broadcast part 1x3x5 followed by
+  //                               the original srcShape 2x4)
+  //   permutation    = [0, 3, 1, 4, 2]
+  //
+  // Note that the trailing dims of broadcastShape are exactly the srcShape
+  // by construction.
+  // nextSrcShapeDim is used to keep track of where in the permutation the
+  // "src shape part" occurs.
+  int64_t nextSrcShapeDim = broadcastedDims.size();
+  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
+    if (broadcastedDims.contains(i)) {
+      // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
+      // bring it to the head of the broadcastShape.
+      // It will need to be permuted back from `broadcastShape.size() - 1` into
+      // position `i`.
+      broadcastShape.push_back(dstShape[i]);
+      permutation[i] = broadcastShape.size() - 1;
+    } else {
+      // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
+      // shape and needs to be permuted into position `i`.
+      // Don't touch `broadcastShape` here, the whole srcShape will be
+      // appended after.
+      permutation[i] = nextSrcShapeDim++;
+    }
+  }
+  // 3.c. Append the srcShape.
+  llvm::append_range(broadcastShape, srcVectorType.getShape());
+
+  // Ensure there are no dim-1 broadcasts.
+  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
+             .empty() &&
+         "unexpected dim-1 broadcast");
+
+  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
+  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
+             vector::BroadcastableToResult::Success &&
+         "must be broadcastable");
+  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
+  // Step 4. If we find any dimension that indeed needs to be permuted,
+  // immediately return a new vector.transpose.
+  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
+    if (permutation[i] != i)
+      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
+  // Otherwise return res.
+  return res;
+}
+
 BroadcastableToResult
 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
                                 std::pair<int, int> *mismatchingDims) {
diff --git a/mlir/test/Dialect/Vector/test-create-broadcast.mlir b/mlir/test/Dialect/Vector/test-create-broadcast.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/test-create-broadcast.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -test-create-vector-broadcast --split-input-file | FileCheck %s
+
+// The test hardcodes that a scalar is broadcasted to targetShape 1x2 with
+// broadcastedDims = [0, 1].
+func.func @foo(%a : f32) {
+  // CHECK: vector.broadcast {{.*}} : f32 to vector<1x2xf32>
+  // CHECK-NOT: vector.transpose
+  return
+}
+
+// -----
+
+// The test hardcodes that 2x2 is broadcasted to targetShape 2x2x3xf32 with
+// broadcastedDims = [2].
+func.func @foo(%a : vector<2x2xf32>) {
+  // CHECK: vector.broadcast {{.*}} : vector<2x2xf32> to vector<3x2x2xf32>
+  // CHECK: vector.transpose {{.*}}, [1, 2, 0] : vector<3x2x2xf32> to vector<2x2x3xf32>
+  return
+}
+
+// -----
+
+// The test hardcodes that 3x3 is broadcasted to targetShape 4x3x3xf32 with
+// broadcastedDims = [0].
+func.func @foo(%a : vector<3x3xf32>) {
+  // CHECK: vector.broadcast {{.*}} : vector<3x3xf32> to vector<4x3x3xf32>
+  // CHECK-NOT: vector.transpose
+  return
+}
+
+// -----
+
+// The test hardcodes that 2x4 is broadcasted to targetShape 1x2x3x4x5 with
+// broadcastedDims = [0, 2, 4].
+func.func @foo(%a : vector<2x4xf32>) {
+  // CHECK: vector.broadcast {{.*}} : vector<2x4xf32> to vector<1x3x5x2x4xf32>
+  // CHECK: vector.transpose {{.*}}, [0, 3, 1, 4, 2] : vector<1x3x5x2x4xf32> to vector<1x2x3x4x5xf32>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -22,10 +22,13 @@
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -820,6 +823,51 @@
   }
 };
 
+// Test pass that broadcasts each function block argument to a hardcoded
+// target shape via BroadcastOp::createBroadcastOp, so the FileCheck tests
+// can inspect the generated vector.broadcast / vector.transpose ops.
+struct TestCreateVectorBroadcast
+    : public PassWrapper<TestCreateVectorBroadcast,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)
+
+  StringRef getArgument() const final {
+    return "test-create-vector-broadcast";
+  }
+  StringRef getDescription() const final {
+    return "Test the creation of vector broadcast ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    OpBuilder b(getOperation());
+    b.setInsertionPointToStart(&getOperation().getBody().front());
+    for (BlockArgument bbArg : getOperation().getArguments()) {
+      auto vectorType = bbArg.getType().dyn_cast<VectorType>();
+      llvm::SetVector<int64_t> broadcastedDims;
+      SmallVector<int64_t> targetShape;
+      // Note: insert the dims one by one; initializing an ArrayRef from a
+      // braced list would dangle once the temporary initializer_list dies.
+      if (!vectorType) {
+        targetShape = {1, 2};
+        broadcastedDims.insert(0);
+        broadcastedDims.insert(1);
+      } else if (vectorType.getShape().equals({2, 2})) {
+        targetShape = {2, 2, 3};
+        broadcastedDims.insert(2);
+      } else if (vectorType.getShape().equals({3, 3})) {
+        targetShape = {4, 3, 3};
+        broadcastedDims.insert(0);
+      } else if (vectorType.getShape().equals({2, 4})) {
+        targetShape = {1, 2, 3, 4, 5};
+        broadcastedDims.insert(0);
+        broadcastedDims.insert(2);
+        broadcastedDims.insert(4);
+      } else {
+        continue;
+      }
+      vector::BroadcastOp::createBroadcastOp(b, bbArg, targetShape,
+                                             broadcastedDims);
+    }
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -856,6 +904,8 @@
   PassRegistration<TestVectorDistribution>();
 
   PassRegistration<TestVectorExtractStridedSliceLowering>();
+
+  PassRegistration<TestCreateVectorBroadcast>();
 }
 } // namespace test
 } // namespace mlir