diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -821,6 +821,45 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RewriteInDestinationPassingStyleOp.
+//===----------------------------------------------------------------------===//
+
+def RewriteInDestinationPassingStyleOp : Op<
+    Transform_Dialect, "structured.rewrite_in_destination_passing_style",
+    [MemoryEffectsOpInterface,
+     NavigationTransformOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Rewrite a supported tensor operation that is not in destination-passing
+    style into a form that is in destination-passing style.
+    Currently supported operations are:
+      - tensor.pad
+      - tensor.generate
+      - tensor.from_elements
+    This hard-coded list hints at a future interface; for now the
+    implementation just switches between op-specific rewrites.
+
+    #### Return modes
+
+    This operation ignores unsupported ops and drops them from the return.
+    If all the operations referred to by the `target` PDLOperation rewrite
+    properly, the transform succeeds. Otherwise the transform silently fails.
+    The return handle points to a subset of successfully produced operations:
+      - tensor.pad case, the returned handle points to the tensor.insert_slice.
+      - tensor.generate case, the returned handle points to the linalg.generic.
+      - tensor.from_elements case, the returned handle points to the last
+        tensor.insert.
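+
+    #### Example
+
+    A minimal usage sketch, mirroring the tests below (assuming `%arg1` is
+    the root handle of an enclosing transform.sequence and the payload ops
+    are first matched by name):
+
+    ```mlir
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+      : (!pdl.operation) -> !pdl.operation
+    %1 = transform.structured.rewrite_in_destination_passing_style %0
+      : (!pdl.operation) -> !pdl.operation
+    ```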
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat = [{
+    $target attr-dict
+    `:` functional-type($target, results)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SplitOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1177,6 +1177,20 @@
               linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
               ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm);
 
+/// Rewrite tensor.from_elements to a sequence of chained tensor.insert.
+FailureOr<Operation *>
+rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                 tensor::FromElementsOp fromElementsOp);
+
+/// Rewrite tensor.generate to linalg.generic.
+FailureOr<Operation *>
+rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                 tensor::GenerateOp generateOp);
+
+/// Rewrite tensor.pad to linalg.generic + tensor.insert_slice.
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        tensor::PadOp padOp);
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -39,6 +39,7 @@
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
@@ -57,8 +58,8 @@
 template <typename PatternTy, typename... Args>
 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
   // Check if the given operation has the type expected by the pattern.
-  using OpTy = typename llvm::function_traits<
-      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+  using OpTy = typename llvm::function_traits<decltype(
+      &PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
   auto op = dyn_cast<OpTy>(operation);
   if (!op)
     return failure();
@@ -1895,6 +1896,32 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// RewriteInDestinationPassingStyleOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::RewriteInDestinationPassingStyleOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  SmallVector<Operation *> res;
+  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+  for (Operation *target : targetOps) {
+    IRRewriter rewriter(target->getContext());
+    rewriter.setInsertionPoint(target);
+    FailureOr<Operation *> maybeResult =
+        TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+            .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+                [&rewriter](auto op) {
+                  return rewriteInDestinationPassingStyle(rewriter, op);
+                });
+    if (failed(maybeResult))
+      return emitDefaultSilenceableFailure(target);
+    res.push_back(*maybeResult);
+  }
+  results.set(getResult().cast<OpResult>(), res);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplitOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -14,6 +14,7 @@
 //===----------------------------------------------------------------------===//
 //
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -48,89 +49,210 @@
   return destination;
 }
 
-namespace {
+FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
+    RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
 
-/// Lower tensor.from_elements to a sequence of chained tensor.insert.
-struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
-  using OpRewritePattern<FromElementsOp>::OpRewritePattern;
+  Location loc = fromElementsOp.getLoc();
+  RankedTensorType tensorType =
+      fromElementsOp.getType().cast<RankedTensorType>();
+  auto shape = tensorType.getShape();
 
-  LogicalResult matchAndRewrite(FromElementsOp elementsOp,
-                                PatternRewriter &rewriter) const override {
-    Location loc = elementsOp.getLoc();
-    RankedTensorType tensorType = elementsOp.getType().cast<RankedTensorType>();
-    auto shape = tensorType.getShape();
-
-    // Create tensor.empty.
-    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
-
-    // Case: tensor<elem_type>.
-    if (shape.empty()) {
-      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
-          elementsOp, elementsOp.getElements().front(), emptyOp.getResult(),
-          ValueRange());
-      return success();
-    }
+  // Create tensor.empty.
+  auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
 
-    // Create constants for the range of possible indices [0, max{shape_i}).
-    auto maxDim = *std::max_element(shape.begin(), shape.end());
-    SmallVector<Value, 2> constants;
-    constants.reserve(maxDim);
-    for (int i = 0; i < maxDim; ++i)
-      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
-
-    // Traverse all elements and create tensor.insert ops.
-    auto elementIt = elementsOp.getElements().begin();
-    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
-    Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
-                                 shape, constants, elementIt, indices);
-
-    // Replace tensor.from_elements.
-    rewriter.replaceOp(elementsOp, result);
-    return success();
+  // Case: tensor<elem_type>.
+  if (shape.empty()) {
+    auto insertOp = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+        fromElementsOp, fromElementsOp.getElements().front(),
+        emptyOp.getResult(), ValueRange());
+    return insertOp.getOperation();
   }
-};
 
-/// Lower tensor.generate to linalg.generic.
-struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
-  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+  // Create constants for the range of possible indices [0, max{shape_i}).
+  auto maxDim = *std::max_element(shape.begin(), shape.end());
+  SmallVector<Value, 2> constants;
+  constants.reserve(maxDim);
+  for (int i = 0; i < maxDim; ++i)
+    constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
 
-  LogicalResult matchAndRewrite(GenerateOp generateOp,
-                                PatternRewriter &rewriter) const override {
-    // Only ops with exactly one block are supported.
-    if (!generateOp.getBody().hasOneBlock())
-      return failure();
+  // Traverse all elements and create tensor.insert ops.
+  auto elementIt = fromElementsOp.getElements().begin();
+  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
+  Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
+                               shape, constants, elementIt, indices);
 
-    Location loc = generateOp.getLoc();
-    RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
+  // Replace tensor.from_elements.
+  rewriter.replaceOp(fromElementsOp, result);
 
-    // Create tensor.empty.
-    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType,
-                                            generateOp.getDynamicExtents());
+  return result.getDefiningOp();
+}
+
+FailureOr<Operation *>
+mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                               tensor::GenerateOp generateOp) {
+  // Only ops with exactly one block are supported.
+  if (!generateOp.getBody().hasOneBlock())
+    return failure();
+
+  Location loc = generateOp.getLoc();
+  RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
+
+  // Create tensor.empty.
+  auto emptyOp =
+      rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());
+  // Create linalg.generic.
+  SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
+                                                 utils::IteratorType::parallel);
+  SmallVector<AffineMap> indexingMaps(
+      1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, tensorType, /*inputs=*/ValueRange(),
+      /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
+      indexingMaps, iteratorTypes);
+  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                     tensorType.getElementType(), loc);
+  rewriter.setInsertionPointToStart(body);
+  SmallVector<Value> bbArgReplacements;
+  for (int64_t i = 0; i < tensorType.getRank(); ++i)
+    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+  rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);
+
+  // Update terminator.
+  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+
+  // Replace tensor.generate.
+  rewriter.replaceOp(generateOp, genericOp->getResult(0));
+
+  return genericOp.getOperation();
+}
+
+FailureOr<Operation *>
+mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                               tensor::PadOp padOp) {
+  // Only ops with exactly one block are supported.
+  if (!padOp.getBodyRegion().hasOneBlock())
+    return failure();
+
+  // Create tensor.empty.
+  Location loc = padOp.getLoc();
+  RankedTensorType resultType = padOp.getResultType();
+  ReifiedRankedShapedTypeDims reifiedShape;
+  if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
+                 .reifyResultShapes(rewriter, reifiedShape)))
+    return rewriter.notifyMatchFailure(
+        padOp, "failed to reify tensor.pad op result shape");
+  SmallVector<Value> dynamicSizes;
+  for (int64_t i = 0; i < resultType.getRank(); ++i)
+    if (resultType.isDynamicDim(i))
+      dynamicSizes.push_back(reifiedShape[0][i]);
+
+  Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
+
+  // Examine the yielded value to decide if a linalg.generic is needed or a
+  // linalg.fill is sufficient.
+  Value filled;
+  Value yieldedValue =
+      cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
+  Attribute constYieldedValue;
+  // Is the yielded value a bbArg defined outside of the PadOp?
+  bool outsideBbArg =
+      yieldedValue.isa<BlockArgument>() &&
+      yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
+          padOp.getOperation();
+  // Is the yielded value an OpResult defined outside of the PadOp?
+  bool outsideOpResult =
+      yieldedValue.isa<OpResult>() &&
+      yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
+  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
+
+  if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
+    // Padding with a constant: Create linalg.fill.
+    Dialect *arithDialect =
+        rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
+    Value fillValue =
+        arithDialect
+            ->materializeConstant(rewriter, constYieldedValue,
+                                  yieldedValue.getType(), yieldedValue.getLoc())
+            ->getResult(0);
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
+                                                  ValueRange(empty));
+    rewriter.setInsertionPointAfter(fillOp);
+    filled = fillOp.getResult(0);
+  } else if (invariantYieldedValue) {
+    // Padding with an invariant value.
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
+                                                  ValueRange(empty));
+    rewriter.setInsertionPointAfter(fillOp);
+    filled = fillOp.getResult(0);
+  } else {
     // Create linalg.generic.
     SmallVector<utils::IteratorType> iteratorTypes(
-        tensorType.getRank(), utils::IteratorType::parallel);
+        resultType.getRank(), utils::IteratorType::parallel);
     SmallVector<AffineMap> indexingMaps(
-        1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
+        1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
     auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, tensorType, /*inputs=*/ValueRange(),
-        /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
+        loc, resultType, /*inputs=*/ValueRange(),
+        /*outputs=*/ValueRange{empty}, /*indexingMaps=*/
         indexingMaps, iteratorTypes);
     Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
-                                       tensorType.getElementType(), loc);
+                                       resultType.getElementType(), loc);
     rewriter.setInsertionPointToStart(body);
     SmallVector<Value> bbArgReplacements;
-    for (int64_t i = 0; i < tensorType.getRank(); ++i)
+    for (int64_t i = 0; i < resultType.getRank(); ++i)
       bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
-    rewriter.mergeBlocks(&generateOp.getBody().front(), body,
-                         bbArgReplacements);
+    rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
 
     // Update terminator.
     auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
     rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+    rewriter.setInsertionPointAfter(genericOp);
+    filled = genericOp->getResult(0);
+  }
 
-    // Replace tensor.generate.
-    rewriter.replaceOp(generateOp, genericOp->getResult(0));
+  // If the `padOp` has a nofold attribute, explicitly insert a `linalg.copy`.
+  if (padOp.getNofoldAttr()) {
+    using bufferization::AllocTensorOp;
+    Value allocated =
+        rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
+    auto resultLinalgOp =
+        rewriter.create<linalg::CopyOp>(loc, filled, allocated);
+    filled = resultLinalgOp->getResult(0);
+  }
+
+  // Create tensor::InsertSliceOp.
+  SmallVector<OpFoldResult> sliceSizes =
+      getMixedSizes(rewriter, loc, padOp.getSource());
+  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
+                                         rewriter.getIndexAttr(1));
+  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+      padOp, padOp.getSource(), filled,
+      /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
+  return insertSliceOp.getOperation();
+}
+
+namespace {
+/// Lower tensor.from_elements to a sequence of chained tensor.insert.
+struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
+  using OpRewritePattern<FromElementsOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromElementsOp elementsOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, elementsOp)))
+      return failure();
+    return success();
+  }
+};
+
+/// Lower tensor.generate to linalg.generic.
+struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
+  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, generateOp)))
+      return failure();
     return success();
   }
 };
@@ -141,97 +263,11 @@
 
   LogicalResult matchAndRewrite(PadOp padOp,
                                 PatternRewriter &rewriter) const override {
-    // Only ops with exactly one block are supported.
-    if (!padOp.getBodyRegion().hasOneBlock())
+    if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, padOp)))
       return failure();
-
-    // Create tensor.empty.
-    Location loc = padOp.getLoc();
-    RankedTensorType resultType = padOp.getResultType();
-    ReifiedRankedShapedTypeDims reifiedShape;
-    if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
-                   .reifyResultShapes(rewriter, reifiedShape)))
-      return rewriter.notifyMatchFailure(
-          padOp, "failed to reify tensor.pad op result shape");
-    SmallVector<Value> dynamicSizes;
-    for (int64_t i = 0; i < resultType.getRank(); ++i)
-      if (resultType.isDynamicDim(i))
-        dynamicSizes.push_back(reifiedShape[0][i]);
-    auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
-
-    // Examine the yielded value to decide if a linalg.generic is neede or a
-    // linalg.fill is sufficient.
-    Value filled;
-    Value yieldedValue =
-        cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
-    Attribute constYieldedValue;
-    // Is the yielded value a bbArg defined outside of the PadOp?
-    bool outsideBbArg =
-        yieldedValue.isa<BlockArgument>() &&
-        yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
-            padOp.getOperation();
-    // Is the yielded value an OpResult defined outside of the PadOp?
-    bool outsideOpResult =
-        yieldedValue.isa<OpResult>() &&
-        yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
-    bool invariantYieldedValue = outsideBbArg || outsideOpResult;
-    if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
-      // Padding with a constant: Create linalg.fill.
-      Dialect *arithDialect =
-          rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
-      Value fillValue = arithDialect
-                            ->materializeConstant(rewriter, constYieldedValue,
-                                                  yieldedValue.getType(),
-                                                  yieldedValue.getLoc())
-                            ->getResult(0);
-      auto fillOp = rewriter.create<linalg::FillOp>(
-          loc, ValueRange(fillValue), ValueRange(emptyOp.getResult()));
-      rewriter.setInsertionPointAfter(fillOp);
-      filled = fillOp.getResult(0);
-    } else if (invariantYieldedValue) {
-      // Padding with an invariant value.
-      auto fillOp = rewriter.create<linalg::FillOp>(
-          loc, ValueRange(yieldedValue), ValueRange(emptyOp.getResult()));
-      rewriter.setInsertionPointAfter(fillOp);
-      filled = fillOp.getResult(0);
-    } else {
-      // Create linalg.generic.
-      SmallVector<utils::IteratorType> iteratorTypes(
-          resultType.getRank(), utils::IteratorType::parallel);
-      SmallVector<AffineMap> indexingMaps(
-          1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
-      auto genericOp = rewriter.create<linalg::GenericOp>(
-          loc, resultType, /*inputs=*/ValueRange(),
-          /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
-          indexingMaps, iteratorTypes);
-      Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
-                                         resultType.getElementType(), loc);
-      rewriter.setInsertionPointToStart(body);
-      SmallVector<Value> bbArgReplacements;
-      for (int64_t i = 0; i < resultType.getRank(); ++i)
-        bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
-      rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
-
-      // Update terminator.
-      auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
-      rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
-      rewriter.setInsertionPointAfter(genericOp);
-      filled = genericOp->getResult(0);
-    }
-
-    // Create tensor::InsertSliceOp.
-    SmallVector<OpFoldResult> sliceSizes =
-        getMixedSizes(rewriter, loc, padOp.getSource());
-    SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
-                                           rewriter.getIndexAttr(1));
-    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        padOp, padOp.getSource(), filled,
-        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
-
     return success();
   }
 };
-
 } // namespace
 
 void linalg::populateConvertToDestinationStylePatterns(
diff --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
rename from mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
rename to mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
--- a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @tensor_from_elements_0d(
 // CHECK-SAME: %[[arg0:.*]]: index
@@ -10,6 +10,14 @@
   return %0 : tensor<index>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK-LABEL: func @tensor_from_elements_1d(
@@ -46,6 +54,14 @@
   return %0 : tensor<3x2xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -70,6 +86,14 @@
   return %0 : tensor<?x?xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.generate"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -103,6 +127,14 @@
   return %0 : tensor<?x?xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -129,6 +161,14 @@
   return %0 : tensor<?x?xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -152,3 +192,45 @@
   } : tensor<?x10xindex> to tensor<?x?xindex>
   return %0 : tensor<?x?xindex>
 }
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK: #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
+// CHECK-LABEL: func @tensor_pad_invariant(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index, %[[padding:.*]]: index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
+// CHECK: %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
+// CHECK: %[[filled:.*]] = linalg.fill ins(%[[padding]] : index) outs(%[[empty]] : tensor<?x?xindex>)
+// CHECK: %[[alloc_tensor:.*]] = bufferization.alloc_tensor(%{{.*}}) : tensor<?x?xindex>
+// CHECK: %[[copied:.*]] = linalg.copy ins(%[[filled]] : tensor<?x?xindex>) outs(%[[alloc_tensor]] : tensor<?x?xindex>) -> tensor<?x?xindex>
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+// CHECK: %[[inserted:.*]] = tensor.insert_slice %[[t1]] into %[[copied]][5, %[[l2]]] [%[[dim0]], 10] [1, 1] : tensor<?x10xindex> into tensor<?x?xindex>
+// CHECK: return %[[inserted]]
+func.func @tensor_pad_invariant(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
+                                %h2: index, %padding: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t1 nofold low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %padding : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -128,10 +128,6 @@
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
-  Option<bool> testConvertToDestinationStylePatterns{
-      *this, "test-convert-to-destination-style-patterns",
-      llvm::cl::desc("Test patterns that convert ops to destination style"),
-      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -222,12 +218,6 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyConvertToDestinationStylePatterns(Operation *rootOp) {
-  RewritePatternSet patterns(rootOp->getContext());
-  populateConvertToDestinationStylePatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
-}
-
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -254,8 +244,6 @@
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
-  if (testConvertToDestinationStylePatterns)
-    applyConvertToDestinationStylePatterns(getOperation());
 }
 
 namespace mlir {