diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -608,6 +608,11 @@
     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
     }
+
+    /// The `dest` type is the same as the result type.
+    RankedTensorType getDestType() {
+      return getType();
+    }
 
     /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -1117,6 +1122,10 @@
       return getSource().getType().cast<RankedTensorType>();
     }
 
+    RankedTensorType getDestType() {
+      return getDest().getType().cast<RankedTensorType>();
+    }
+
     ParallelCombiningOpInterface getParallelCombiningParent() {
       return dyn_cast<ParallelCombiningOpInterface>(
           getOperation()->getParentOp());
@@ -1125,7 +1134,7 @@
     /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank = getSourceType().getRank();
+      unsigned rank = getDestType().getRank();
       return {rank, rank, rank};
     }
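For intuition on the `getArrayAttrMaxRanks` change: `parallel_insert_slice` may be rank-reducing, so the source rank can be smaller than the destination rank, while the `static_offsets`/`static_sizes`/`static_strides` arrays always carry one entry per destination dimension. A minimal sketch with made-up values (shown out of context; the op must actually live inside a `perform_concurrently` region):

```mlir
// Rank-reducing parallel insert: a rank-1 source into a rank-2 destination.
// The offsets/sizes/strides lists each have 2 entries (the destination rank),
// so validating their rank against the source rank would reject valid IR.
tensor.parallel_insert_slice %src into %dest[0, 4] [1, 8] [1, 1]
    : tensor<8xf32> into tensor<16x16xf32>
```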
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1123,9 +1123,9 @@
 /// Verifier for ExtractSliceOp.
 LogicalResult ExtractSliceOp::verify() {
   // Verify result type against inferred type.
-  auto expectedType = ExtractSliceOp::inferResultType(
+  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
       getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
-  auto result = isRankReducedType(expectedType.cast<ShapedType>(), getType());
+  SliceVerificationResult result = isRankReducedType(expectedType, getType());
   return produceSliceErrorMsg(result, *this, expectedType);
 }
 
@@ -1487,17 +1487,18 @@
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+/// Rank-reducing type verification for both InsertSliceOp and
+/// ParallelInsertSliceOp.
 static SliceVerificationResult
 verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
                     ArrayAttr staticOffsets, ArrayAttr staticSizes,
                     ArrayAttr staticStrides,
                     ShapedType *expectedType = nullptr) {
   // insert_slice is the inverse of extract_slice, use the same type inference.
-  auto expected = ExtractSliceOp::inferResultType(
-                      dstType, extractFromI64ArrayAttr(staticOffsets),
-                      extractFromI64ArrayAttr(staticSizes),
-                      extractFromI64ArrayAttr(staticStrides))
-                      .cast<ShapedType>();
+  RankedTensorType expected = ExtractSliceOp::inferResultType(
+      dstType, extractFromI64ArrayAttr(staticOffsets),
+      extractFromI64ArrayAttr(staticSizes),
+      extractFromI64ArrayAttr(staticStrides));
   if (expectedType)
     *expectedType = expected;
   return isRankReducedType(expected, srcType);
@@ -1506,7 +1507,7 @@
 /// Verifier for InsertSliceOp.
 LogicalResult InsertSliceOp::verify() {
   ShapedType expectedType;
-  auto result =
+  SliceVerificationResult result =
       verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                           getStaticSizes(), getStaticStrides(), &expectedType);
   return produceSliceErrorMsg(result, *this, expectedType);
@@ -1514,6 +1515,7 @@
 
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
 /// can mutate the second InsertSliceOp's destination to the first one's.
+/// This works similarly when the second op is a ParallelInsertSliceOp.
 ///
 /// Example:
 ///
@@ -1527,8 +1529,11 @@
 /// ```mlir
 ///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
 /// ```
-static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
-  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
+  auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
 
   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
   if (!prevInsertOp ||
@@ -1540,14 +1545,32 @@
   return success();
 }
 
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
-  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
-      getSourceType() == getType() &&
-      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
-    return this->getSource();
-  if (succeeded(foldInsertAfterInsertSlice(*this)))
-    return getResult();
-  return OpFoldResult();
+/// Folding logic shared by InsertSliceOp and ParallelInsertSliceOp. The two
+/// ops have different return types, so the result is wrapped in a FailureOr.
+template <typename InsertOpTy>
+FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp,
+                                     ArrayRef<Attribute>) {
+  if (insertOp.getSourceType().hasStaticShape() &&
+      insertOp.getDestType().hasStaticShape() &&
+      insertOp.getSourceType() == insertOp.getDestType() &&
+      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
+          insertOp, insertOp.getDestType())))
+    return static_cast<OpFoldResult>(insertOp.getSource());
+  if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
+    // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
+    // return OpFoldResult().
+    if (std::is_same<InsertOpTy, InsertSliceOp>::value)
+      return static_cast<OpFoldResult>(insertOp->getResult(0));
+    else
+      return OpFoldResult();
+  }
+  return failure();
+}
+
+OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
+  auto maybeOpFoldResult = foldInsertOp(*this, operands);
+  return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
 }
 
 LogicalResult InsertSliceOp::reifyResultShapes(
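As a concrete illustration of the first case handled by the shared fold logic, here is a sketch with made-up values: source and destination types match statically and the offsets/sizes/strides form the identity, so the whole op folds to its source operand.

```mlir
// Folds to plain %src: the slice covers the entire destination, with zero
// offsets and unit strides, and both types are the same static tensor type.
%r = tensor.insert_slice %src into %dst[0, 0] [4, 4] [1, 1]
    : tensor<4x4xf32> into tensor<4x4xf32>
```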
@@ -1562,12 +1585,15 @@
 
 namespace {
 /// Pattern to rewrite a insert_slice op with constant arguments.
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
 class InsertSliceOpConstantArgumentFolder final
-    : public OpRewritePattern<InsertSliceOp> {
+    : public OpRewritePattern<InsertOpTy> {
 public:
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
     // No constant operand, just return.
     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
@@ -1587,13 +1613,20 @@
 
     // Create the new op in canonical form.
     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
-        insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
+        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
         mixedOffsets, mixedSizes, mixedStrides);
     Value toInsert = insertSliceOp.getSource();
-    if (sourceType != insertSliceOp.getSourceType())
+    if (sourceType != insertSliceOp.getSourceType()) {
+      OpBuilder::InsertionGuard g(rewriter);
+      // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+      // that the insertion point is just before the ParallelCombiningOp in the
+      // parallel case.
+      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
       toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                  sourceType, toInsert);
-    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+    }
+    rewriter.replaceOpWithNewOp<InsertOpTy>(
         insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
         mixedSizes, mixedStrides);
     return success();
   }
@@ -1618,10 +1651,13 @@
 /// Note: When folding a cast on the destination tensor, the result of the
 /// insert_slice operation is casted to ensure that the type of the result did
 /// not change.
-struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
+  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
           return matchPattern(operand, matchConstantIndex());
@@ -1643,24 +1679,27 @@
     auto src =
         (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
     auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
-
-    auto srcType = src.getType().cast<ShapedType>();
-    auto dstType = dst.getType().cast<ShapedType>();
+    auto srcType = src.getType().template cast<ShapedType>();
+    auto dstType = dst.getType().template cast<ShapedType>();
     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                             insertSliceOp.getStaticSizes(),
                             insertSliceOp.getStaticStrides()) !=
         SliceVerificationResult::Success)
       return failure();
 
-    Value replacement = rewriter.create<InsertSliceOp>(
+    Operation *replacement = rewriter.create<InsertOpTy>(
        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
 
-    if (replacement.getType() != insertSliceOp.getType()) {
-      replacement = rewriter.create<tensor::CastOp>(
-          insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
+    // In the parallel case there is no result and so nothing to cast.
+    bool isParallelInsert =
+        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
+    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
+      replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+                                                    insertSliceOp.getDestType(),
+                                                    replacement->getResult(0));
     }
-    rewriter.replaceOp(insertSliceOp, replacement);
+    rewriter.replaceOp(insertSliceOp, replacement->getResults());
     return success();
   }
 };
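To make the now-templated cast folder concrete, here is a before/after sketch with made-up values (for ParallelInsertSliceOp the same rewrite applies, minus the result cast, since that op has no result):

```mlir
// Before: the source reaches the insert through a type-erasing cast.
%c = tensor.cast %t : tensor<64x64xf32> to tensor<?x?xf32>
%r = tensor.insert_slice %c into %dst[0, 0] [64, 64] [1, 1]
    : tensor<?x?xf32> into tensor<128x128xf32>

// After: the cast is folded into the insert itself, which still verifies
// against the original (more static) source type.
%r = tensor.insert_slice %t into %dst[0, 0] [64, 64] [1, 1]
    : tensor<64x64xf32> into tensor<128x128xf32>
```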
@@ -1684,14 +1723,17 @@
 ///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
 ///       : tensor<64x64xf32> into ...
 /// ```
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
 struct InsertSliceOpSourceCastInserter final
-    : public OpRewritePattern<InsertSliceOp> {
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+    : public OpRewritePattern<InsertOpTy> {
+  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
     RankedTensorType srcType = insertSliceOp.getSourceType();
-    if (srcType.getRank() != insertSliceOp.getType().getRank())
+    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
       return failure();
     SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
                                      srcType.getShape().end());
@@ -1713,12 +1755,19 @@
     // 2) "More static" than srcType.
     // 3) Cast-compatible with srcType.
     // Insert the cast.
+    OpBuilder::InsertionGuard g(rewriter);
+    // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+    // that the insertion point is just before the ParallelCombiningOp in the
+    // parallel case.
+    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
     Value cast = rewriter.create<tensor::CastOp>(
         insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
-    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+    rewriter.replaceOpWithNewOp<InsertOpTy>(
         insertSliceOp, cast, insertSliceOp.getDest(),
         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
         insertSliceOp.getMixedStrides());
     return success();
   }
 };
@@ -1726,8 +1775,9 @@
 
 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
-              InsertSliceOpSourceCastInserter>(context);
+  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
+              InsertSliceOpCastFolder<InsertSliceOp>,
+              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
 }
 
 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
@@ -2234,7 +2284,12 @@
   if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
     return this->emitError("expected ParallelCombiningOpInterface parent, got:")
            << *(getOperation()->getParentOp());
-  return success();
+
+  ShapedType expectedType;
+  SliceVerificationResult result =
+      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
+                          getStaticSizes(), getStaticStrides(), &expectedType);
+  return produceSliceErrorMsg(result, *this, expectedType);
 }
 
 namespace {
@@ -2263,51 +2318,37 @@
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
+    auto sourceType =
+        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
+            insertSliceOp.getSourceType().getRank(),
+            insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
+            mixedStrides);
+    Value toInsert = insertSliceOp.getSource();
+    if (sourceType != insertSliceOp.getSourceType()) {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+                                                 sourceType, toInsert);
+    }
     rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
-        insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
-        mixedOffsets, mixedSizes, mixedStrides);
+        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
+        mixedSizes, mixedStrides);
     return success();
   }
 };
 } // namespace
 
-/// Fold a parallel_insert_slice source coming from a tensor.cast op.
-///
-/// Example:
-/// ```
-/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-///   %1 = compute_some_tensor() : tensor<64xf32>
-///   %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
-///   scf.foreach_thread.perform_concurrently {
-///     scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
-///       tensor<?xf32> into tensor<128xf32>
-///   }
-/// }
-/// ```
-///
-/// is folded into:
-/// ```
-/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-///   %1 = compute_some_tensor() : tensor<64xf32>
-///   scf.foreach_thread.perform_concurrently {
-///     scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
-///       tensor<64xf32> into tensor<128xf32>
-///   }
-/// }
-/// ```
 LogicalResult ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
                                           SmallVectorImpl<OpFoldResult> &results) {
-  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
-  if (!sourceCast)
-    return failure();
-  getSourceMutable().assign(sourceCast.getSource());
-  return success();
+  return foldInsertOp(*this, operands);
 }
 
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
+              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
+              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
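The insertion-point adjustment in the patterns above exists because a `perform_concurrently` region may hold nothing but parallel insert ops, so any helper `tensor.cast` has to be created just before that combining op. A sketch of where the rewritten IR ends up (made-up values):

```mlir
%0 = scf.foreach_thread (%tid) in (%n) -> (tensor<?x?xf32>) {
  // Casts built by the patterns land here, before the combining op...
  %cast = tensor.cast %src : tensor<1x5xf32> to tensor<?x5xf32>
  scf.foreach_thread.perform_concurrently {
    // ...because only parallel_insert_slice ops are allowed inside.
    tensor.parallel_insert_slice %cast into %dst[%tid, 0] [1, 5] [1, 1]
        : tensor<?x5xf32> into tensor<?x?xf32>
  }
}
```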
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1429,23 +1429,25 @@
 // -----
 
 // CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
-// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<?x5xf32>,
+// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
 // CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
 // CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index
 func.func @canonicalize_parallel_insert_slice_indices(
-    %arg0 : tensor<?x5xf32>, %arg1: tensor<?x?xf32>,
+    %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>,
     %num_threads : index) -> tensor<?x?xf32> {
   %cst = arith.constant 4.200000e+01 : f32
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
 
+  // CHECK-NOT: tensor.cast
   // CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
   // CHECK-NEXT: scf.foreach_thread.perform_concurrently {
   // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
   %2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
+    %3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor<?x5xf32>
     scf.foreach_thread.perform_concurrently {
-      tensor.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %3 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
     }
   }
   return %2 : tensor<?x?xf32>
 }
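For reference, the CHECK lines correspond to canonicalized IR of roughly this shape (reconstructed from the CHECK directives rather than copied from FileCheck output):

```mlir
func.func @canonicalize_parallel_insert_slice_indices(
    %arg0: tensor<1x5xf32>, %arg1: tensor<?x?xf32>,
    %num_threads: index) -> tensor<?x?xf32> {
  %0 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
    scf.foreach_thread.perform_concurrently {
      // The tensor.cast has been folded away and the dynamic offsets, sizes,
      // and strides replaced by their constant values.
      tensor.parallel_insert_slice %arg0 into %arg1[%tidx, 0] [1, 5] [1, 1]
          : tensor<1x5xf32> into tensor<?x?xf32>
    }
  }
  return %0 : tensor<?x?xf32>
}
```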