diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -449,6 +449,23 @@
     /// Return the dimensions of the result vector that were formerly ones in the
     /// source tensor and thus correspond to "dim-1" broadcasting.
     llvm::SetVector<int64_t> computeBroadcastedUnitDims();
+
+    /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
+    /// `broadcastedDims` dimensions in the dstShape are broadcasted.
+    /// This requires (and asserts) that the broadcast is free of dim-1
+    /// broadcasting.
+    /// Since vector.broadcast only allows expanding leading dimensions, an extra
+    /// vector.transpose may be inserted to make the broadcast possible.
+    /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
+    /// the helper will assert. This means:
+    ///   1. `dstShape` must not be empty.
+    ///   2. `broadcastedDims` must be confined to [0 .. dstShape.size()).
+    ///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
+    ///      must match the `value` shape.
+    static Value createOrFoldBroadcastOp(
+      OpBuilder &b, Value value,
+      ArrayRef<int64_t> dstShape,
+      const llvm::SetVector<int64_t> &broadcastedDims);
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1725,13 +1725,9 @@
 
 /// Return the dimensions of the result vector that were formerly ones in the
 /// source tensor and thus correspond to "dim-1" broadcasting.
-llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
-  VectorType srcVectorType = getSourceType().dyn_cast<VectorType>();
-  // Scalar broadcast is without any unit dim broadcast.
-  if (!srcVectorType)
-    return {};
-  ArrayRef<int64_t> srcShape = srcVectorType.getShape();
-  ArrayRef<int64_t> dstShape = getVectorType().getShape();
+static llvm::SetVector<int64_t>
+computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
+                           ArrayRef<int64_t> dstShape) {
   int64_t rankDiff = dstShape.size() - srcShape.size();
   int64_t dstDim = rankDiff;
   llvm::SetVector<int64_t> res;
@@ -1745,6 +1741,122 @@
   return res;
 }
 
+llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
+  // Scalar broadcast is without any unit dim broadcast.
+  auto srcVectorType = getSourceType().dyn_cast<VectorType>();
+  if (!srcVectorType)
+    return {};
+  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
+                                      getVectorType().getShape());
+}
+
+/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
+/// `broadcastedDims` dimensions in the dstShape are broadcasted.
+/// This requires (and asserts) that the broadcast is free of dim-1
+/// broadcasting.
+/// Since vector.broadcast only allows expanding leading dimensions, an extra
+/// vector.transpose may be inserted to make the broadcast possible.
+/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
+/// the helper will assert. This means:
+///   1. `dstShape` must not be empty.
+///   2. `broadcastedDims` must be confined to [0 .. dstShape.size()).
+///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
+///      must match the `value` shape.
+Value BroadcastOp::createOrFoldBroadcastOp(
+    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
+    const llvm::SetVector<int64_t> &broadcastedDims) {
+  assert(!dstShape.empty() && "unexpected empty dst shape");
+
+  // Well-formedness check.
+  SmallVector<int64_t> checkShape;
+  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
+    if (broadcastedDims.contains(i))
+      continue;
+    checkShape.push_back(dstShape[i]);
+  }
+  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
+         "ill-formed broadcastedDims contains values not confined to "
+         "destVectorShape");
+
+  Location loc = value.getLoc();
+  Type elementType = getElementTypeOrSelf(value.getType());
+  VectorType srcVectorType = value.getType().dyn_cast<VectorType>();
+  VectorType dstVectorType = VectorType::get(dstShape, elementType);
+
+  // Step 2. If scalar -> dstShape broadcast, just do it.
+  if (!srcVectorType) {
+    assert(checkShape.empty() &&
+           "ill-formed createOrFoldBroadcastOp arguments");
+    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
+  }
+
+  assert(srcVectorType.getShape().equals(checkShape) &&
+         "ill-formed createOrFoldBroadcastOp arguments");
+
+  // Step 3. Since vector.broadcast only allows creating leading dims,
+  //         vector -> dstShape broadcast may require a transpose.
+  // Traverse the dims in order and construct:
+  //   1. The leading entries of the broadcastShape that is guaranteed to be
+  //      achievable by a simple broadcast.
+  //   2. The induced permutation for the subsequent vector.transpose that will
+  //      bring us from `broadcastShape` back to the desired `dstShape`.
+  // If the induced permutation is not the identity, create a vector.transpose.
+  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
+  broadcastShape.reserve(dstShape.size());
+  // Consider the example:
+  //   srcShape     = 2x4
+  //   dstShape     = 1x2x3x4x5
+  //   broadcastedDims = [0, 2, 4]
+  //
+  // We want to build:
+  //   broadcastShape  = 1x3x5 x 2x4
+  //                     (leading broadcast part) x (src shape part)
+  //   permutation     = [0, 3, 1, 4, 2]
+  //   i.e. dst dim `i` is taken from broadcastShape dim `permutation[i]`.
+  //
+  // Note that the trailing dims of broadcastShape are exactly the srcShape
+  // by construction.
+  // nextSrcShapeDim is used to keep track of where in the permutation the
+  // "src shape part" occurs.
+  int64_t nextSrcShapeDim = broadcastedDims.size();
+  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
+    if (broadcastedDims.contains(i)) {
+      // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
+      // bring it to the head of the broadcastShape.
+      // It will need to be permuted back from `broadcastShape.size() - 1` into
+      // position `i`.
+      broadcastShape.push_back(dstShape[i]);
+      permutation[i] = broadcastShape.size() - 1;
+    } else {
+      // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
+      // shape and needs to be permuted into position `i`.
+      // Don't touch `broadcastShape` here, the whole srcShape will be
+      // appended after.
+      permutation[i] = nextSrcShapeDim++;
+    }
+  }
+  // 3.c. Append the srcShape.
+  llvm::append_range(broadcastShape, srcVectorType.getShape());
+
+  // Ensure there are no dim-1 broadcasts.
+  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
+             .empty() &&
+         "unexpected dim-1 broadcast");
+
+  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
+  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
+             vector::BroadcastableToResult::Success &&
+         "must be broadcastable");
+  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
+  // Step 4. If we find any dimension that indeed needs to be permuted,
+  // immediately return a new vector.transpose.
+  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
+    if (permutation[i] != i)
+      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
+  // Otherwise return res.
+  return res;
+}
+
 BroadcastableToResult
 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
                                 std::pair *mismatchingDims) {
diff --git a/mlir/test/Dialect/Vector/test-create-broadcast.mlir b/mlir/test/Dialect/Vector/test-create-broadcast.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/test-create-broadcast.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s --test-create-vector-broadcast --allow-unregistered-dialect --split-input-file | FileCheck %s
+
+func.func @foo(%a : f32) -> vector<1x2xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 0, 1>} : (f32) -> vector<1x2xf32>
+  // CHECK: vector.broadcast {{.*}} : f32 to vector<1x2xf32>
+  // CHECK-NOT: vector.transpose
+  return %0: vector<1x2xf32>
+}
+
+// -----
+
+func.func @foo(%a : vector<2x2xf32>) -> vector<2x2x3xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 2>}
+    : (vector<2x2xf32>) -> vector<2x2x3xf32>
+  // CHECK: vector.broadcast {{.*}} : vector<2x2xf32> to vector<3x2x2xf32>
+  // CHECK: vector.transpose {{.*}}, [1, 2, 0] : vector<3x2x2xf32> to vector<2x2x3xf32>
+  return %0: vector<2x2x3xf32>
+}
+
+// -----
+
+func.func @foo(%a : vector<3x3xf32>) -> vector<4x3x3xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 0>}
+    : (vector<3x3xf32>) -> vector<4x3x3xf32>
+  // CHECK: vector.broadcast {{.*}} : vector<3x3xf32> to vector<4x3x3xf32>
+  // CHECK-NOT: vector.transpose
+  return %0: vector<4x3x3xf32>
+}
+
+// -----
+
+func.func @foo(%a : vector<2x4xf32>) -> vector<1x2x3x4x5xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 0, 2, 4>}
+    : (vector<2x4xf32>) -> vector<1x2x3x4x5xf32>
+  // CHECK: vector.broadcast {{.*}} : vector<2x4xf32> to vector<1x3x5x2x4xf32>
+  // CHECK: vector.transpose {{.*}}, [0, 3, 1, 4, 2] : vector<1x3x5x2x4xf32> to vector<1x2x3x4x5xf32>
+  return %0: vector<1x2x3x4x5xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -820,6 +820,41 @@
   }
 };
 
+/// Test pass replacing each "test_create_broadcast" op by the value built by
+/// vector::BroadcastOp::createOrFoldBroadcastOp; the target shape is the op's
+/// result shape and the broadcasted dims come from `broadcast_dims`.
+struct TestCreateVectorBroadcast
+    : public PassWrapper<TestCreateVectorBroadcast,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)
+
+  StringRef getArgument() const final { return "test-create-vector-broadcast"; }
+  StringRef getDescription() const final {
+    return "Test the vector::BroadcastOp::createOrFoldBroadcastOp helper";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    getOperation()->walk([](Operation *op) {
+      if (op->getName().getStringRef() != "test_create_broadcast")
+        return;
+      auto targetShape =
+          op->getResult(0).getType().cast<VectorType>().getShape();
+      auto arrayAttr =
+          op->getAttr("broadcast_dims").cast<DenseI64ArrayAttr>().asArrayRef();
+      llvm::SetVector<int64_t> broadcastedDims;
+      broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
+      OpBuilder b(op);
+      Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
+          b, op->getOperand(0), targetShape, broadcastedDims);
+      op->getResult(0).replaceAllUsesWith(bcast);
+      op->erase();
+    });
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -856,6 +891,8 @@
   PassRegistration();
 
   PassRegistration();
+
+  PassRegistration<TestCreateVectorBroadcast>();
 }
 } // namespace test
 } // namespace mlir