diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -873,6 +873,45 @@ }]; } +//===----------------------------------------------------------------------===// +// RewriteInDestinationPassingStyleOp. +//===----------------------------------------------------------------------===// + +def RewriteInDestinationPassingStyleOp : Op< + Transform_Dialect, "structured.rewrite_in_destination_passing_style", + [MemoryEffectsOpInterface, + NavigationTransformOpTrait, + DeclareOpInterfaceMethods]> { + let description = [{ + Rewrite a supported tensor operation that is not in destination-passing style + into a form that is in destination-passing style. + Currently supported operations are: + - tensor.pad + - tensor.generate + - tensor.from_elements + This dichotomy hints at a future interface, for now the implementation just + switches between different implementations. + + #### Return modes + + This operation ignores unsupported ops and drops them from the return. + If all the operations referred to by the `target` PDLOperation generalize + properly, the transform succeeds. Otherwise the transform silently fails. + The return handle points to a subset of successfully produced operations: + - tensor.pad case, the returned handle points to the tensor.insert_slice. + - tensor.generate case, the returned handle points to the linalg.generic. + - tensor.from_elements case, the returned handle points to the last + tensor.insert. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + let assemblyFormat = [{ + $target attr-dict + `:` functional-type($target, results) + }]; +} + //===----------------------------------------------------------------------===// // SplitOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1206,6 +1206,20 @@ linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef outerPerm, ArrayRef innerPerm); +/// Rewrite tensor.from_elements to a sequence of chained tensor.insert. +FailureOr +rewriteInDestinationPassingStyle(RewriterBase &rewriter, + tensor::FromElementsOp fromElementsOp); + +/// Rewrite tensor.generate to linalg.generic. +FailureOr +rewriteInDestinationPassingStyle(RewriterBase &rewriter, + tensor::GenerateOp generateOp); + +/// Rewrite tensor.pad to linalg.generic + tensor.insert_slice. +FailureOr rewriteInDestinationPassingStyle(RewriterBase &rewriter, + tensor::PadOp padOp); + } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -21,6 +21,10 @@ namespace mlir { +/// Return true if `v` is an IntegerAttr with value `0` or a ConstantIndexOp +/// with attribute with value `0`. +bool isZeroIndex(OpFoldResult v); + /// Represents a range (offset, size, and stride) where each element of the /// triple may be dynamic or static. 
struct Range { @@ -30,8 +34,8 @@ }; /// Given an array of Range values, return a tuple of (offset vector, sizes -/// vector, and strides vector) formed by separating out the individual elements -/// of each range. +/// vector, and strides vector) formed by separating out the individual +/// elements of each range. std::tuple, SmallVector, SmallVector> getOffsetsSizesAndStrides(ArrayRef ranges); @@ -40,14 +44,15 @@ /// a) it is an IntegerAttr /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. /// In such dynamic cases, ShapedType::kDynamic is also pushed to -/// `staticVec`. This is useful to extract mixed static and dynamic entries that -/// come from an AttrSizedOperandSegments trait. +/// `staticVec`. This is useful to extract mixed static and dynamic entries +/// that come from an AttrSizedOperandSegments trait. void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl &dynamicVec, SmallVectorImpl &staticVec); -/// Helper function to dispatch multiple OpFoldResults according to the behavior -/// of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single OpFoldResult. +/// Helper function to dispatch multiple OpFoldResults according to the +/// behavior of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single +/// OpFoldResult. void dispatchIndexOpFoldResults(ArrayRef ofrs, SmallVectorImpl &dynamicVec, SmallVectorImpl &staticVec); @@ -72,27 +77,28 @@ /// Return true if `ofr` is constant integer equal to `value`. bool isConstantIntValue(OpFoldResult ofr, int64_t value); -/// Return true if ofr1 and ofr2 are the same integer constant attribute values -/// or the same SSA value. -/// Ignore integer bitwitdh and type mismatch that come from the fact there is -/// no IndexAttr and that IndexType have no bitwidth. +/// Return true if ofr1 and ofr2 are the same integer constant attribute +/// values or the same SSA value. 
Ignore integer bitwidth and type mismatch +/// that come from the fact there is no IndexAttr and that IndexType have no +/// bitwidth. bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); /// Helper function to convert a vector of `OpFoldResult`s into a vector of -/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold result -/// if it casts to a `Value` or create an index-type constant if it casts to -/// `IntegerAttr`. No other attribute types are supported. +/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold +/// result if it casts to a `Value` or create an index-type constant if it +/// casts to `IntegerAttr`. No other attribute types are supported. SmallVector getAsValues(OpBuilder &b, Location loc, ArrayRef valueOrAttrVec); -/// Return a vector of OpFoldResults with the same size a staticValues, but all -/// elements for which ShapedType::isDynamic is true, will be replaced by +/// Return a vector of OpFoldResults with the same size as staticValues, but +/// all elements for which ShapedType::isDynamic is true, will be replaced by /// dynamicValues. SmallVector getMixedValues(ArrayRef staticValues, ValueRange dynamicValues, Builder &b); -/// Decompose a vector of mixed static or dynamic values into the corresponding -/// pair of arrays. This is the inverse function of `getMixedValues`. +/// Decompose a vector of mixed static or dynamic values into the +/// corresponding pair of arrays. This is the inverse function of +/// `getMixedValues`. 
std::pair> decomposeMixedValues(Builder &b, const SmallVectorImpl &mixedValues); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -39,6 +39,7 @@ #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" using namespace mlir; @@ -1919,6 +1920,32 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// RewriteInDestinationPassingStyleOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::RewriteInDestinationPassingStyleOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + SmallVector res; + ArrayRef targetOps = state.getPayloadOps(getTarget()); + for (Operation *target : targetOps) { + IRRewriter rewriter(target->getContext()); + rewriter.setInsertionPoint(target); + FailureOr maybeResult = + TypeSwitch>(target) + .Case( + [&rewriter](auto op) { + return rewriteInDestinationPassingStyle(rewriter, op); + }); + if (failed(maybeResult)) + return emitDefaultSilenceableFailure(target); + res.push_back(*maybeResult); + } + results.set(getResult().cast(), res); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // SplitOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ 
b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -19,8 +19,10 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" using namespace mlir; @@ -50,94 +52,6 @@ return destination; } -namespace { - -/// Lower tensor.from_elements to a sequence of chained tensor.insert. -struct FromElementsOpConverter : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(FromElementsOp elementsOp, - PatternRewriter &rewriter) const override { - Location loc = elementsOp.getLoc(); - RankedTensorType tensorType = elementsOp.getType().cast(); - auto shape = tensorType.getShape(); - - // Create tensor.empty. - auto emptyOp = rewriter.create(loc, tensorType, ValueRange()); - - // Case: tensor. - if (shape.empty()) { - rewriter.replaceOpWithNewOp( - elementsOp, elementsOp.getElements().front(), emptyOp.getResult(), - ValueRange()); - return success(); - } - - // Create constants for the range of possible indices [0, max{shape_i}). - auto maxDim = *std::max_element(shape.begin(), shape.end()); - SmallVector constants; - constants.reserve(maxDim); - for (int i = 0; i < maxDim; ++i) - constants.push_back(rewriter.create(loc, i)); - - // Traverse all elements and create tensor.insert ops. - auto elementIt = elementsOp.getElements().begin(); - SmallVector indices(tensorType.getRank(), constants[0]); - Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(), - shape, constants, elementIt, indices); - - // Replace tensor.from_elements. - rewriter.replaceOp(elementsOp, result); - return success(); - } -}; - -/// Lower tensor.generate to linalg.generic. 
-struct GenerateOpConverter : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GenerateOp generateOp, - PatternRewriter &rewriter) const override { - // Only ops with exactly one block are supported. - if (!generateOp.getBody().hasOneBlock()) - return failure(); - - Location loc = generateOp.getLoc(); - RankedTensorType tensorType = generateOp.getType().cast(); - - // Create tensor.empty. - auto emptyOp = rewriter.create(loc, tensorType, - generateOp.getDynamicExtents()); - - // Create linalg.generic. - SmallVector iteratorTypes( - tensorType.getRank(), utils::IteratorType::parallel); - SmallVector indexingMaps( - 1, rewriter.getMultiDimIdentityMap(tensorType.getRank())); - auto genericOp = rewriter.create( - loc, tensorType, /*inputs=*/ValueRange(), - /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/ - indexingMaps, iteratorTypes); - Block *body = rewriter.createBlock(&genericOp->getRegion(0), {}, - tensorType.getElementType(), loc); - rewriter.setInsertionPointToStart(body); - SmallVector bbArgReplacements; - for (int64_t i = 0; i < tensorType.getRank(); ++i) - bbArgReplacements.push_back(rewriter.create(loc, i)); - rewriter.mergeBlocks(&generateOp.getBody().front(), body, - bbArgReplacements); - - // Update terminator. - auto yieldOp = cast(body->getTerminator()); - rewriter.replaceOpWithNewOp(yieldOp, yieldOp.getValue()); - - // Replace tensor.generate. - rewriter.replaceOp(generateOp, genericOp->getResult(0)); - return success(); - } -}; -} // namespace - static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, Location loc, PadOp padOp, Value dest) { @@ -287,49 +201,133 @@ return toTensorOp; } -namespace { -/// Lower tensor.pad to linalg.generic + tensor.insert_slice. -struct PadOpConverter : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +/// Lower tensor.from_elements to a sequence of chained tensor.insert. 
+FailureOr mlir::linalg::rewriteInDestinationPassingStyle( + RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) { + Location loc = fromElementsOp.getLoc(); + RankedTensorType tensorType = + fromElementsOp.getType().cast(); + auto shape = tensorType.getShape(); + + // Create tensor.empty. + auto emptyOp = rewriter.create(loc, tensorType, ValueRange()); + + // Case: tensor. + if (shape.empty()) { + Operation *res = rewriter.replaceOpWithNewOp( + fromElementsOp, fromElementsOp.getElements().front(), + emptyOp.getResult(), ValueRange()); + return res; + } - LogicalResult matchAndRewrite(PadOp padOp, - PatternRewriter &rewriter) const override { - // Only ops with exactly one block are supported. - if (!padOp.getBodyRegion().hasOneBlock()) - return failure(); + // Create constants for the range of possible indices [0, max{shape_i}). + auto maxDim = *std::max_element(shape.begin(), shape.end()); + SmallVector constants; + constants.reserve(maxDim); + for (int i = 0; i < maxDim; ++i) + constants.push_back(rewriter.create(loc, i)); + + // Traverse all elements and create tensor.insert ops. + auto elementIt = fromElementsOp.getElements().begin(); + SmallVector indices(tensorType.getRank(), constants[0]); + Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(), + shape, constants, elementIt, indices); + + // Replace tensor.from_elements. + rewriter.replaceOp(fromElementsOp, result); + return result.getDefiningOp(); +} - // Create tensor.empty. 
- Location loc = padOp.getLoc(); - RankedTensorType resultType = padOp.getResultType(); - ReifiedRankedShapedTypeDims reifiedShape; - if (failed(cast(padOp.getOperation()) - .reifyResultShapes(rewriter, reifiedShape))) - return rewriter.notifyMatchFailure( - padOp, "failed to reify tensor.pad op result shape"); - SmallVector dynamicSizes; - for (int64_t i = 0; i < resultType.getRank(); ++i) - if (resultType.isDynamicDim(i)) - dynamicSizes.push_back(reifiedShape[0][i]); - auto emptyOp = rewriter.create(loc, resultType, dynamicSizes); - - // Create linalg.fill or linalg.generic. - Operation *fillOp = - movePaddingToFillOrGenericOp(rewriter, loc, padOp, emptyOp.getResult()); - rewriter.setInsertionPointAfter(fillOp); - - // Create tensor::InsertSliceOp. - SmallVector sliceSizes = - getMixedSizes(rewriter, loc, padOp.getSource()); - SmallVector sliceStrides(resultType.getRank(), - rewriter.getIndexAttr(1)); - rewriter.replaceOpWithNewOp( - padOp, padOp.getSource(), fillOp->getResult(0), - /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides); +/// Lower tensor.generate to linalg.generic. +FailureOr +mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, + tensor::GenerateOp generateOp) { + // Only ops with exactly one block are supported. + if (!generateOp.getBody().hasOneBlock()) + return failure(); - return success(); + Location loc = generateOp.getLoc(); + RankedTensorType tensorType = generateOp.getType().cast(); + + // Create tensor.empty. + auto emptyOp = + rewriter.create(loc, tensorType, generateOp.getDynamicExtents()); + + // Create linalg.generic. 
+ SmallVector iteratorTypes(tensorType.getRank(), + utils::IteratorType::parallel); + SmallVector indexingMaps( + 1, rewriter.getMultiDimIdentityMap(tensorType.getRank())); + auto genericOp = rewriter.create( + loc, tensorType, /*inputs=*/ValueRange(), + /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/ + indexingMaps, iteratorTypes); + Block *body = rewriter.createBlock(&genericOp->getRegion(0), {}, + tensorType.getElementType(), loc); + rewriter.setInsertionPointToStart(body); + SmallVector bbArgReplacements; + for (int64_t i = 0; i < tensorType.getRank(); ++i) + bbArgReplacements.push_back(rewriter.create(loc, i)); + rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements); + + // Update terminator. + auto yieldOp = cast(body->getTerminator()); + rewriter.replaceOpWithNewOp(yieldOp, yieldOp.getValue()); + + // Replace tensor.generate. + rewriter.replaceOp(generateOp, genericOp->getResult(0)); + return genericOp.getOperation(); +} + +/// Lower tensor.pad to linalg.generic + tensor.insert_slice. +FailureOr +mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, + tensor::PadOp padOp) { + // Only ops with exactly one block are supported. + if (!padOp.getBodyRegion().hasOneBlock()) + return failure(); + + // Create tensor.empty. + Location loc = padOp.getLoc(); + RankedTensorType resultType = padOp.getResultType(); + ReifiedRankedShapedTypeDims reifiedShape; + if (failed(cast(padOp.getOperation()) + .reifyResultShapes(rewriter, reifiedShape))) + return rewriter.notifyMatchFailure( + padOp, "failed to reify tensor.pad op result shape"); + SmallVector dynamicSizes; + for (int64_t i = 0; i < resultType.getRank(); ++i) + if (resultType.isDynamicDim(i)) + dynamicSizes.push_back(reifiedShape[0][i]); + + // If the `padOp` has a nofold attribute and all paddings are known to be 0, + // explicitly insert a `linalg.copy`. 
+ if (padOp.getNofoldAttr() && + llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) && + llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) { + using bufferization::AllocTensorOp; + Value allocated = + rewriter.create(loc, resultType, dynamicSizes); + auto copyOp = rewriter.replaceOpWithNewOp( + padOp, padOp.getSource(), allocated); + return copyOp.getOperation(); } -}; -} // namespace + + Value empty = rewriter.create(loc, resultType, dynamicSizes); + // Create linalg.fill or linalg.generic. + Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty); + rewriter.setInsertionPointAfter(fillOp); + + // Create tensor::InsertSliceOp. + SmallVector sliceSizes = + getMixedSizes(rewriter, loc, padOp.getSource()); + SmallVector sliceStrides(resultType.getRank(), + rewriter.getIndexAttr(1)); + auto insertSliceOp = rewriter.replaceOpWithNewOp( + padOp, padOp.getSource(), fillOp->getResult(0), + /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides); + return insertSliceOp.getOperation(); +} Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value, Attribute memorySpace) { @@ -368,6 +366,45 @@ return toTensorOp; } +namespace { +/// Lower tensor.from_elements to a sequence of chained tensor.insert. +struct FromElementsOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FromElementsOp fromElementsOp, + PatternRewriter &rewriter) const override { + if (failed( + linalg::rewriteInDestinationPassingStyle(rewriter, fromElementsOp))) + return failure(); + return success(); + } +}; + +/// Lower tensor.generate to linalg.generic. 
+struct GenerateOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GenerateOp generateOp, + PatternRewriter &rewriter) const override { + if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, generateOp))) + return failure(); + return success(); + } +}; + +/// Lower tensor.pad to linalg.generic + tensor.insert_slice. +struct PadOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PadOp padOp, + PatternRewriter &rewriter) const override { + if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, padOp))) + return failure(); + return success(); + } +}; +} // namespace + void linalg::populateConvertToDestinationStylePatterns( RewritePatternSet &patterns) { patterns.insert( diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -44,18 +44,6 @@ using namespace mlir::linalg; using namespace mlir::scf; -static bool isZero(OpFoldResult v) { - if (!v) - return false; - if (auto attr = v.dyn_cast()) { - IntegerAttr intAttr = attr.dyn_cast(); - return intAttr && intAttr.getValue().isZero(); - } - if (auto cst = v.get().getDefiningOp()) - return cst.value() == 0; - return false; -} - namespace { // Helper visitor to determine whether an AffineExpr is tiled. 
@@ -70,7 +58,7 @@ TileCheck(ArrayRef tileSizes) : tileSizes(tileSizes) {} void visitDimExpr(AffineDimExpr expr) { - isTiled |= !isZero(tileSizes[expr.getPosition()]); + isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]); } void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { visit(expr.getLHS()); @@ -869,7 +857,7 @@ SmallVector offsets; for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n"); - bool isTiled = !isZero(tileSizes[idx]); + bool isTiled = !isZeroIndex(tileSizes[idx]); offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0)); LLVM_DEBUG(llvm::dbgs() << "computeTileOffsets: " << offsets.back() << "\n"); @@ -882,7 +870,7 @@ ArrayRef sizeBounds) { SmallVector sizes; for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) { - bool isTiled = !isZero(tileSizes[idx]); + bool isTiled = !isZeroIndex(tileSizes[idx]); // Before composing, we need to make range a closed interval. OpFoldResult size = isTiled ? 
tileSizes[idx] : sizeBounds[idx]; AffineExpr d0 = getAffineDimExpr(0, b.getContext()); @@ -938,7 +926,7 @@ bool omitPartialTileCheck) { assert(ivs.size() == static_cast(llvm::count_if( llvm::make_range(tileSizes.begin(), tileSizes.end()), - [](OpFoldResult v) { return !isZero(v); })) && + [](OpFoldResult v) { return !isZeroIndex(v); })) && "expected as many ivs as non-zero sizes"); // Construct (potentially temporary) mins and maxes on which to apply maps diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -14,6 +14,18 @@ namespace mlir { +bool isZeroIndex(OpFoldResult v) { + if (!v) + return false; + if (auto attr = v.dyn_cast()) { + IntegerAttr intAttr = attr.dyn_cast(); + return intAttr && intAttr.getValue().isZero(); + } + if (auto cst = v.get().getDefiningOp()) + return cst.value() == 0; + return false; +} + std::tuple, SmallVector, SmallVector> getOffsetsSizesAndStrides(ArrayRef ranges) { diff --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir rename from mlir/test/Dialect/Linalg/convert-to-destination-style.mlir rename to mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir --- a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns -canonicalize %s | FileCheck %s +// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file -canonicalize %s | FileCheck %s // CHECK-LABEL: func @tensor_from_elements_0d( // CHECK-SAME: %[[arg0:.*]]: index @@ -10,6 +10,14 @@ return %0 : tensor } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = 
transform.structured.match ops{["tensor.from_elements"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.rewrite_in_destination_passing_style %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK-LABEL: func @tensor_from_elements_1d( @@ -25,6 +33,14 @@ return %0 : tensor<2xindex> } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.rewrite_in_destination_passing_style %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK-LABEL: func @tensor_from_elements_2d( @@ -46,6 +62,14 @@ return %0 : tensor<3x2xindex> } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.rewrite_in_destination_passing_style %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)> @@ -70,6 +94,14 @@ return %0 : tensor } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.generate"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.rewrite_in_destination_passing_style %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)> @@ -103,6 +135,14 @@ return %0 : tensor } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.rewrite_in_destination_passing_style %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)> @@ -129,6 +169,14 @@ return %0 : tensor } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = 
transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.rewrite_in_destination_passing_style %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)> @@ -152,3 +200,39 @@ } : tensor to tensor return %0 : tensor } + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.rewrite_in_destination_passing_style %0 + : (!pdl.operation) -> !pdl.operation +} + +// ----- + +// CHECK-LABEL: func @tensor_pad_nofold( +// CHECK-SAME: %[[t1:.*]]: tensor, %[[padding:.*]]: index +// CHECK-NOT: linalg.fill +// CHECK-NOT: generic +// CHECK-NOT: insert_slice +// CHECK: %[[alloc_tensor:.*]] = bufferization.alloc_tensor(%{{.*}}) : tensor +// CHECK: %[[copied:.*]] = linalg.copy ins(%[[t1]] : tensor) outs(%[[alloc_tensor]] : tensor) -> tensor +// CHECK: return %[[copied]] +func.func @tensor_pad_nofold(%t1: tensor, %padding: index) + -> tensor { + %c0 = arith.constant 0 : index + %0 = tensor.pad %t1 nofold low[0, %c0] high[%c0, 0] { + ^bb0(%arg0: index, %arg1: index): + tensor.yield %padding : index + } : tensor to tensor + return %0: tensor +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.rewrite_in_destination_passing_style %0 + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -128,10 +128,6 @@ *this, "test-erase-unnecessary-inputs", llvm::cl::desc("Test patterns to erase unnecessary inputs"), llvm::cl::init(false)}; - Option 
testConvertToDestinationStylePatterns{ - *this, "test-convert-to-destination-style-patterns", - llvm::cl::desc("Test patterns that convert ops to destination style"), - llvm::cl::init(false)}; }; } // namespace @@ -222,12 +218,6 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } -static void applyConvertToDestinationStylePatterns(Operation *rootOp) { - RewritePatternSet patterns(rootOp->getContext()); - populateConvertToDestinationStylePatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); -} - /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnOperation() { if (testPatterns) @@ -254,8 +244,6 @@ return applyEraseUnusedOperandsAndResultsPatterns(getOperation()); if (testEraseUnnecessaryInputs) return applyEraseUnnecessaryInputs(getOperation()); - if (testConvertToDestinationStylePatterns) - applyConvertToDestinationStylePatterns(getOperation()); } namespace mlir {