diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -873,6 +873,45 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RewriteInDestinationPassingStyleOp
+//===----------------------------------------------------------------------===//
+
+def RewriteInDestinationPassingStyleOp : Op<
+    Transform_Dialect, "structured.rewrite_in_destination_passing_style",
+    [MemoryEffectsOpInterface,
+     NavigationTransformOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Rewrite a supported tensor operation that is not in destination-passing
+    style into a form that is in destination-passing style.
+    Currently supported operations are:
+      - tensor.pad
+      - tensor.generate
+      - tensor.from_elements
+    This variety of operations hints at a future interface; for now the
+    implementation just switches between separate per-operation rewrites.
+
+    #### Return modes
+
+    This operation ignores unsupported ops and drops them from the return.
+    If all the operations referred to by the `target` PDLOperation are
+    rewritten properly, the transform succeeds. Otherwise the transform
+    silently fails.
+    The return handle points to the subset of successfully produced
+    operations:
+      - in the tensor.pad case, the returned handle points to the
+        tensor.insert_slice;
+      - in the tensor.generate case, the returned handle points to the
+        linalg.generic;
+      - in the tensor.from_elements case, the returned handle points to the
+        last tensor.insert.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat = [{
+    $target attr-dict
+    `:` functional-type($target, results)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SplitOp
 //===----------------------------------------------------------------------===//
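For reference, the new op is driven from a transform script in the same way as the tests added at the end of this patch; a minimal driver, taken directly from that test file, looks like this:

```mlir
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
    : (!pdl.operation) -> !pdl.operation
  transform.structured.rewrite_in_destination_passing_style %0
    : (!pdl.operation) -> !pdl.operation
}
```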
+/// Rewrite tensor.pad to linalg.generic + tensor.insert_slice.
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        tensor::PadOp padOp);
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -39,6 +39,7 @@
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
@@ -57,8 +58,8 @@
 template <typename PatternTy, typename... Args>
 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
   // Check if the given operation has the type expected by the pattern.
-  using OpTy = typename llvm::function_traits<
-      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+  using OpTy = typename llvm::function_traits<decltype(
+      &PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
   auto op = dyn_cast<OpTy>(operation);
   if (!op)
     return failure();
@@ -1919,6 +1920,32 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// RewriteInDestinationPassingStyleOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::RewriteInDestinationPassingStyleOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  SmallVector<Operation *> res;
+  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+  for (Operation *target : targetOps) {
+    IRRewriter rewriter(target->getContext());
+    rewriter.setInsertionPoint(target);
+    FailureOr<Operation *> maybeResult =
+        TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+            .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+                [&rewriter](auto op) {
+                  return rewriteInDestinationPassingStyle(rewriter, op);
+                });
+    if (failed(maybeResult))
+      return emitDefaultSilenceableFailure(target);
+    res.push_back(*maybeResult);
+  }
+  results.set(getResult().cast<OpResult>(), res);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplitOp
 //===----------------------------------------------------------------------===//
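The TypeSwitch above dispatches to one of the three rewrites declared in Transforms.h. Sketched on IR for the 0-d tensor.from_elements case exercised by the first test below (value names illustrative), the op becomes a tensor.empty destination plus a single tensor.insert, which is the operation the returned handle then points to:

```mlir
// Before.
%0 = tensor.from_elements %a : tensor<index>

// After (sketch).
%empty = tensor.empty() : tensor<index>
%0 = tensor.insert %a into %empty[] : tensor<index>
```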
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -50,94 +50,6 @@
   return destination;
 }
 
-namespace {
-
-/// Lower tensor.from_elements to a sequence of chained tensor.insert.
-struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
-  using OpRewritePattern<FromElementsOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(FromElementsOp elementsOp,
-                                PatternRewriter &rewriter) const override {
-    Location loc = elementsOp.getLoc();
-    RankedTensorType tensorType =
-        elementsOp.getType().cast<RankedTensorType>();
-    auto shape = tensorType.getShape();
-
-    // Create tensor.empty.
-    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
-
-    // Case: tensor<elem_type>.
-    if (shape.empty()) {
-      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
-          elementsOp, elementsOp.getElements().front(), emptyOp.getResult(),
-          ValueRange());
-      return success();
-    }
-
-    // Create constants for the range of possible indices [0, max{shape_i}).
-    auto maxDim = *std::max_element(shape.begin(), shape.end());
-    SmallVector<Value> constants;
-    constants.reserve(maxDim);
-    for (int i = 0; i < maxDim; ++i)
-      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
-
-    // Traverse all elements and create tensor.insert ops.
-    auto elementIt = elementsOp.getElements().begin();
-    SmallVector<Value> indices(tensorType.getRank(), constants[0]);
-    Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
-                                 shape, constants, elementIt, indices);
-
-    // Replace tensor.from_elements.
-    rewriter.replaceOp(elementsOp, result);
-    return success();
-  }
-};
-
-/// Lower tensor.generate to linalg.generic.
-struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
-  using OpRewritePattern<GenerateOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenerateOp generateOp,
-                                PatternRewriter &rewriter) const override {
-    // Only ops with exactly one block are supported.
-    if (!generateOp.getBody().hasOneBlock())
-      return failure();
-
-    Location loc = generateOp.getLoc();
-    RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
-
-    // Create tensor.empty.
-    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType,
-                                            generateOp.getDynamicExtents());
-
-    // Create linalg.generic.
-    SmallVector<utils::IteratorType> iteratorTypes(
-        tensorType.getRank(), utils::IteratorType::parallel);
-    SmallVector<AffineMap> indexingMaps(
-        1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, tensorType, /*inputs=*/ValueRange(),
-        /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
-        indexingMaps, iteratorTypes);
-    Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
-                                       tensorType.getElementType(), loc);
-    rewriter.setInsertionPointToStart(body);
-    SmallVector<Value> bbArgReplacements;
-    for (int64_t i = 0; i < tensorType.getRank(); ++i)
-      bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
-    rewriter.mergeBlocks(&generateOp.getBody().front(), body,
-                         bbArgReplacements);
-
-    // Update terminator.
-    auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
-    rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
-
-    // Replace tensor.generate.
-    rewriter.replaceOp(generateOp, genericOp->getResult(0));
-    return success();
-  }
-};
-} // namespace
-
 static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
                                                Location loc, PadOp padOp,
                                                Value dest) {
@@ -287,49 +199,122 @@
   return toTensorOp;
 }
 
-namespace {
+/// Lower tensor.from_elements to a sequence of chained tensor.insert.
+FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
+    RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
+  Location loc = fromElementsOp.getLoc();
+  RankedTensorType tensorType =
+      fromElementsOp.getType().cast<RankedTensorType>();
+  auto shape = tensorType.getShape();
+
+  // Create tensor.empty.
+  auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
+
+  // Case: tensor<elem_type>.
+  if (shape.empty()) {
+    Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+        fromElementsOp, fromElementsOp.getElements().front(),
+        emptyOp.getResult(), ValueRange());
+    return res;
+  }
+
+  // Create constants for the range of possible indices [0, max{shape_i}).
+  auto maxDim = *std::max_element(shape.begin(), shape.end());
+  SmallVector<Value> constants;
+  constants.reserve(maxDim);
+  for (int i = 0; i < maxDim; ++i)
+    constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+
+  // Traverse all elements and create tensor.insert ops.
+  auto elementIt = fromElementsOp.getElements().begin();
+  SmallVector<Value> indices(tensorType.getRank(), constants[0]);
+  Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
+                               shape, constants, elementIt, indices);
+
+  // Replace tensor.from_elements. The last insert in the chain is the
+  // operation handed back to callers.
+  rewriter.replaceOp(fromElementsOp, result);
+  return result.getDefiningOp();
+}
+
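For non-0-d shapes, createInserts emits one tensor.insert per element, chained through the destination. A sketch for the tensor<3x2xindex> test case, with hypothetical element values %e0 ... %e5:

```mlir
// Sketch only; %e0 ... %e5 stand for the from_elements operands.
%empty = tensor.empty() : tensor<3x2xindex>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%i0 = tensor.insert %e0 into %empty[%c0, %c0] : tensor<3x2xindex>
%i1 = tensor.insert %e1 into %i0[%c0, %c1] : tensor<3x2xindex>
// ... continues row by row; the final insert at [%c2, %c1] is the op the
// returned handle points to.
```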
+/// Lower tensor.generate to linalg.generic.
+FailureOr<Operation *>
+mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                               tensor::GenerateOp generateOp) {
+  // Only ops with exactly one block are supported.
+  if (!generateOp.getBody().hasOneBlock())
+    return failure();
+
+  Location loc = generateOp.getLoc();
+  RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
+
+  // Create tensor.empty.
+  auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType,
+                                          generateOp.getDynamicExtents());
+
+  // Create linalg.generic.
+  SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
+                                                 utils::IteratorType::parallel);
+  SmallVector<AffineMap> indexingMaps(
+      1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, tensorType, /*inputs=*/ValueRange(),
+      /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
+      indexingMaps, iteratorTypes);
+  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                     tensorType.getElementType(), loc);
+  rewriter.setInsertionPointToStart(body);
+  SmallVector<Value> bbArgReplacements;
+  for (int64_t i = 0; i < tensorType.getRank(); ++i)
+    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+  rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);
+
+  // Update terminator.
+  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+
+  // Replace tensor.generate and return the linalg.generic.
+  rewriter.replaceOp(generateOp, genericOp->getResult(0));
+  return genericOp.getOperation();
+}
+
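The net effect on IR, sketched for a dynamic 2-d tensor.generate (the body shown is illustrative): the generate body moves into an all-parallel linalg.generic over a tensor.empty destination, and the index block arguments are rematerialized with linalg.index:

```mlir
// Before.
%0 = tensor.generate %s0, %s1 {
^bb0(%i: index, %j: index):
  %v = arith.addi %i, %j : index
  tensor.yield %v : index
} : tensor<?x?xindex>

// After (sketch).
%empty = tensor.empty(%s0, %s1) : tensor<?x?xindex>
%0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    outs(%empty : tensor<?x?xindex>) {
^bb0(%out: index):
  %i = linalg.index 0 : index
  %j = linalg.index 1 : index
  %v = arith.addi %i, %j : index
  linalg.yield %v : index
} -> tensor<?x?xindex>
```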
 /// Lower tensor.pad to linalg.generic + tensor.insert_slice.
-struct PadOpConverter : public OpRewritePattern<PadOp> {
-  using OpRewritePattern<PadOp>::OpRewritePattern;
+FailureOr<Operation *>
+mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                               tensor::PadOp padOp) {
+  // Only ops with exactly one block are supported.
+  if (!padOp.getBodyRegion().hasOneBlock())
+    return failure();
+
+  // Create tensor.empty.
+  Location loc = padOp.getLoc();
+  RankedTensorType resultType = padOp.getResultType();
+  ReifiedRankedShapedTypeDims reifiedShape;
+  if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
+                 .reifyResultShapes(rewriter, reifiedShape)))
+    return rewriter.notifyMatchFailure(
+        padOp, "failed to reify tensor.pad op result shape");
+  SmallVector<Value> dynamicSizes;
+  for (int64_t i = 0; i < resultType.getRank(); ++i)
+    if (resultType.isDynamicDim(i))
+      dynamicSizes.push_back(reifiedShape[0][i]);
+  auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
 
-  LogicalResult matchAndRewrite(PadOp padOp,
-                                PatternRewriter &rewriter) const override {
-    // Only ops with exactly one block are supported.
-    if (!padOp.getBodyRegion().hasOneBlock())
-      return failure();
+  // Create linalg.fill or linalg.generic.
+  Operation *fillOp =
+      movePaddingToFillOrGenericOp(rewriter, loc, padOp, emptyOp.getResult());
+  rewriter.setInsertionPointAfter(fillOp);
 
-    // Create tensor.empty.
-    Location loc = padOp.getLoc();
-    RankedTensorType resultType = padOp.getResultType();
-    ReifiedRankedShapedTypeDims reifiedShape;
-    if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
-                   .reifyResultShapes(rewriter, reifiedShape)))
-      return rewriter.notifyMatchFailure(
-          padOp, "failed to reify tensor.pad op result shape");
-    SmallVector<Value> dynamicSizes;
-    for (int64_t i = 0; i < resultType.getRank(); ++i)
-      if (resultType.isDynamicDim(i))
-        dynamicSizes.push_back(reifiedShape[0][i]);
-    auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
-
-    // Create linalg.fill or linalg.generic.
-    Operation *fillOp =
-        movePaddingToFillOrGenericOp(rewriter, loc, padOp, emptyOp.getResult());
-    rewriter.setInsertionPointAfter(fillOp);
-
-    // Create tensor::InsertSliceOp.
-    SmallVector<OpFoldResult> sliceSizes =
-        getMixedSizes(rewriter, loc, padOp.getSource());
-    SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
-                                           rewriter.getIndexAttr(1));
-    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        padOp, padOp.getSource(), fillOp->getResult(0),
-        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
+  // Create tensor::InsertSliceOp.
+  SmallVector<OpFoldResult> sliceSizes =
+      getMixedSizes(rewriter, loc, padOp.getSource());
+  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
+                                         rewriter.getIndexAttr(1));
+  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+      padOp, padOp.getSource(), fillOp->getResult(0),
+      /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
 
-    return success();
-  }
-};
-} // namespace
+  return insertSliceOp.getOperation();
+}
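Sketched on IR for a pad whose body yields a value that is invariant in the pad region, so movePaddingToFillOrGenericOp produces a linalg.fill (a non-invariant body produces a linalg.generic instead). The size and dim value names below are illustrative:

```mlir
// Before.
%0 = tensor.pad %t low[5, %l2] high[%h1, %h2] {
^bb0(%i: index, %j: index):
  tensor.yield %pad : index
} : tensor<?x10xindex> to tensor<?x?xindex>

// After (sketch): reified result sizes %sz0, %sz1 feed the destination; the
// source is re-inserted at the low-pad offsets.
%empty = tensor.empty(%sz0, %sz1) : tensor<?x?xindex>
%filled = linalg.fill ins(%pad : index)
            outs(%empty : tensor<?x?xindex>) -> tensor<?x?xindex>
%0 = tensor.insert_slice %t into %filled[5, %l2] [%d0, 10] [1, 1]
    : tensor<?x10xindex> into tensor<?x?xindex>
```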
 
 Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
                                     Attribute memorySpace) {
@@ -368,6 +353,45 @@
   return toTensorOp;
 }
 
+namespace {
+/// Lower tensor.from_elements to a sequence of chained tensor.insert.
+struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
+  using OpRewritePattern<FromElementsOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromElementsOp fromElementsOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(linalg::rewriteInDestinationPassingStyle(rewriter,
+                                                        fromElementsOp)))
+      return failure();
+    return success();
+  }
+};
+
+/// Lower tensor.generate to linalg.generic.
+struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
+  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, generateOp)))
+      return failure();
+    return success();
+  }
+};
+
+/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
+struct PadOpConverter : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, padOp)))
+      return failure();
+    return success();
+  }
+};
+} // namespace
+
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
   patterns.insert<FromElementsOpConverter, GenerateOpConverter,
                   PadOpConverter>(patterns.getContext());
diff --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
rename from mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
rename to mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
--- a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @tensor_from_elements_0d(
 //  CHECK-SAME:     %[[arg0:.*]]: index
@@ -10,6 +10,14 @@
   return %0 : tensor<index>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK-LABEL: func @tensor_from_elements_1d(
@@ -46,6 +54,14 @@
   return %0 : tensor<3x2xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -70,6 +86,14 @@
   return %0 : tensor<?x?xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.generate"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -103,6 +127,14 @@
   return %0 : tensor<?x?xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -129,6 +161,14 @@
   return %0 : tensor<?x?xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -152,3 +192,45 @@
   } : tensor<?x10xindex> to tensor<?x?xindex>
   return %0 : tensor<?x?xindex>
 }
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK: #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
+// CHECK-LABEL: func @tensor_pad_invariant(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index, %[[padding:.*]]: index
+//   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+//   CHECK-DAG:   %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
+//       CHECK:   %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
+//       CHECK:   %[[filled:.*]] = linalg.fill ins(%[[padding]] : index) outs(%[[empty]] : tensor<?x?xindex>)
+//       CHECK:   %[[alloc_tensor:.*]] = bufferization.alloc_tensor(%{{.*}}) : tensor<?x?xindex>
+//       CHECK:   %[[copied:.*]] = linalg.copy ins(%[[filled]] : tensor<?x?xindex>) outs(%[[alloc_tensor]] : tensor<?x?xindex>) -> tensor<?x?xindex>
+//   CHECK-DAG:   %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+//       CHECK:   %[[inserted:.*]] = tensor.insert_slice %[[t1]] into %[[copied]][5, %[[l2]]] [%[[dim0]], 10] [1, 1] : tensor<?x10xindex> into tensor<?x?xindex>
+//       CHECK:   return %[[inserted]]
+func.func @tensor_pad_invariant(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
+                                %h2: index, %padding: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t1 nofold low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %padding : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -128,10 +128,6 @@
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
-  Option<bool> testConvertToDestinationStylePatterns{
-      *this, "test-convert-to-destination-style-patterns",
-      llvm::cl::desc("Test patterns that convert ops to destination style"),
-      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -222,12 +218,6 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyConvertToDestinationStylePatterns(Operation *rootOp) {
-  RewritePatternSet patterns(rootOp->getContext());
-  populateConvertToDestinationStylePatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
-}
-
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -254,8 +244,6 @@
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
-  if (testConvertToDestinationStylePatterns)
-    applyConvertToDestinationStylePatterns(getOperation());
 }
 
 namespace mlir {