diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -445,6 +445,10 @@
     VectorType getVectorType() {
       return getVector().getType().cast<VectorType>();
     }
+
+    /// Return the dimensions of the result vector that were formerly ones in
+    /// the source vector and thus correspond to "dim-1" broadcasting.
+    llvm::SetVector<int64_t> computeBroadcastedUnitDims();
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1351,7 +1351,12 @@
   auto getRank = [](Type type) {
     return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
   };
+  // If splat or broadcast from a scalar, just return the source scalar. This
+  // is only valid when the extract yields that scalar, not a sub-vector.
   unsigned broadcastSrcRank = getRank(source.getType());
+  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
+    return source;
+
   unsigned extractResultRank = getRank(extractOp.getType());
   if (extractResultRank >= broadcastSrcRank)
     return Value();
@@ -1362,13 +1367,29 @@
         extractVecType.getShape() !=
             broadcastVecType.getShape().take_back(extractResultRank))
     return Value();
+
+  // A vector source implies `defOp` is a BroadcastOp: splat sources are
+  // scalars and never reach this point.
+  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
+  // Number of leading result dims that the broadcast prepends to the source.
+  int64_t broadcastRankDiff =
+      broadcastOp.getVectorType().getRank() - broadcastSrcRank;
+  // Positions along "dim-1" broadcasted dims may be anything in the result but
+  // index a size-1 dim in the source operand, so they must become `0` there.
+  llvm::SetVector<int64_t> broadcastedUnitDims =
+      broadcastOp.computeBroadcastedUnitDims();
   auto extractPos = extractVector<int64_t>(extractOp.getPosition());
+  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
+    if (broadcastedUnitDims.contains(i))
+      extractPos[i] = 0;
+  // Drop the leading `extractPos.size() - rankDiff` (== `broadcastRankDiff`)
+  // positions: they index dims that only exist in the broadcast result.
   unsigned rankDiff = broadcastSrcRank - extractResultRank;
   extractPos.erase(extractPos.begin(),
                    std::next(extractPos.begin(), extractPos.size() - rankDiff));
   extractOp.setOperand(source);
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
   extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                      b.getI64ArrayAttr(extractPos));
   return extractOp.getResult();
@@ -1683,6 +1704,28 @@
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
+/// Return the dimensions of the result vector that were formerly ones in the
+/// source vector and thus correspond to "dim-1" broadcasting.
+llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
+  VectorType srcVectorType = getSourceType().dyn_cast<VectorType>();
+  // Scalar broadcast is without any unit dim broadcast.
+  if (!srcVectorType)
+    return {};
+  ArrayRef<int64_t> srcShape = srcVectorType.getShape();
+  ArrayRef<int64_t> dstShape = getVectorType().getShape();
+  int64_t rankDiff = dstShape.size() - srcShape.size();
+  int64_t dstDim = rankDiff;
+  llvm::SetVector<int64_t> res;
+  for (auto [s1, s2] : llvm::zip(srcShape, dstShape.drop_front(rankDiff))) {
+    if (s1 != s2) {
+      assert(s1 == 1 && "expected dim-1 broadcasting");
+      res.insert(dstDim);
+    }
+    ++dstDim;
+  }
+  return res;
+}
+
 BroadcastableToResult
 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
                                 std::pair<int, int> *mismatchingDims) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2020,3 +2020,39 @@
   %1 = vector.transfer_read %0[%c0, %i4, %c0], %f0 {in_bounds = [true]} : tensor<1x4x4xf32>, vector<4xf32>
   return %1 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @extract_from_broadcast
+func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
+  %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
+
+  //  CHECK-NEXT:  %0 = vector.extract {{.*}}[0, 0] : vector<1x1x1xf32>
+  //  CHECK-NEXT:  return %0 : vector<1xf32>
+  %1 = vector.extract %0[0, 0, 31] : vector<1x1x32x1xf32>
+  return %1: vector<1xf32>
+}
+
+// -----
+
+// The broadcasted unit dim is not among the new leading dims: its extract
+// position must still be reset to `0`.
+// CHECK-LABEL: func.func @fold_extract_broadcast_inner_unit_dim
+//  CHECK-NEXT:   %[[E:.*]] = vector.extract %{{.*}}[2, 0] : vector<4x1x8xf32>
+//  CHECK-NEXT:   return %[[E]] : vector<8xf32>
+func.func @fold_extract_broadcast_inner_unit_dim(%src: vector<4x1x8xf32>) -> vector<8xf32> {
+  %0 = vector.broadcast %src : vector<4x1x8xf32> to vector<4x32x8xf32>
+  %1 = vector.extract %0[2, 31] : vector<4x32x8xf32>
+  return %1 : vector<8xf32>
+}
+
+// -----
+
+// Extracting a sub-vector from a scalar broadcast must not fold to the scalar.
+// CHECK-LABEL: func.func @fold_extract_vector_from_scalar_broadcast
+//       CHECK:   return %{{.*}} : vector<8xf32>
+func.func @fold_extract_vector_from_scalar_broadcast(%a: f32) -> vector<8xf32> {
+  %0 = vector.broadcast %a : f32 to vector<4x8xf32>
+  %1 = vector.extract %0[1] : vector<4x8xf32>
+  return %1 : vector<8xf32>
+}