diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -69,12 +69,14 @@ /// {4, 3, 6, 2, 1, 5, 8, 7, 9} /// void getForwardSlice(Operation *op, SetVector *forwardSlice, - TransitiveFilter filter = nullptr /* pass-through*/); + TransitiveFilter filter = nullptr /* pass-through*/, + bool inclusive = false); /// Value-rooted version of `getForwardSlice`. Return the union of all forward /// slices for the uses of the value `root`. void getForwardSlice(Value root, SetVector *forwardSlice, - TransitiveFilter filter = nullptr /* pass-through*/); + TransitiveFilter filter = nullptr /* pass-through*/, + bool inclusive = false); /// Fills `backwardSlice` with the computed backward slice (i.e. /// all the transitive defs of op), **without** including that operation. @@ -111,12 +113,14 @@ /// {1, 2, 5, 3, 4, 6} /// void getBackwardSlice(Operation *op, SetVector *backwardSlice, - TransitiveFilter filter = nullptr /* pass-through*/); + TransitiveFilter filter = nullptr /* pass-through*/, + bool inclusive = false); /// Value-rooted version of `getBackwardSlice`. Return the union of all backward /// slices for the op defining or owning the value `root`. void getBackwardSlice(Value root, SetVector *backwardSlice, - TransitiveFilter filter = nullptr /* pass-through*/); + TransitiveFilter filter = nullptr /* pass-through*/, + bool inclusive = false); /// Iteratively computes backward slices and forward slices until /// a fixed point is reached. Returns an `SetVector` which @@ -198,7 +202,8 @@ SetVector getSlice(Operation *op, TransitiveFilter backwardFilter = nullptr /* pass-through*/, - TransitiveFilter forwardFilter = nullptr /* pass-through*/); + TransitiveFilter forwardFilter = nullptr /* pass-through*/, + bool inclusive = false); /// Multi-root DAG topological sort. /// Performs a topological sort of the Operation in the `toSort` SetVector. diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -27,6 +27,7 @@ namespace tensor { class PackOp; +class PadOp; class UnPackOp; } // namespace tensor diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -744,7 +744,6 @@ DefaultValuedAttr:$padding_values, DefaultValuedAttr:$padding_dimensions, DefaultValuedAttr:$pack_paddings, - DefaultValuedAttr:$hoist_paddings, DefaultValuedAttr< TypedArrayAttrBase, "{}">:$transpose_paddings); @@ -761,6 +760,58 @@ }]; } +//===----------------------------------------------------------------------===// +// HoistPadOp +//===----------------------------------------------------------------------===// + +def HoistPadOp : Op { + let description = [{ + Hoist the tensor.pad target operation by at most the given number of loops. + Optionally apply the transpose attribute to the inner dimensions. + + TODO: In the future, we should consider rewriting as a tensor.pack after + hoisting since this abstraction is now available. + TODO: Maybe also return the linalg.generic transpose created at some point. 
+ + #### Return modes + + This operation ignores non-tensor.pad ops and drops them from the result. + If any non-tensor.pad op is passed, the transform emits a silenceable failure. + + If all the operations referred to by the `target` handle pad properly, the + transform succeeds. Otherwise the transform silently fails. + + The return handle points to only the subset of successfully hoisted + tensor.pad operations, which can be empty. + }]; + + // Also allow any !pdl.operation for simpler composition. Non-tensor.pad ops + // will be dropped from the results. + let arguments = + (ins TransformHandleTypeInterface:$target, + I64Attr:$num_loops, + DefaultValuedAttr:$transpose); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = [{ + $target + `by` $num_loops `loops` + (`,` `transpose` `by` $transpose^)? + attr-dict + `:` functional-type(operands, results) + }]; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::tensor::PadOp, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + //===----------------------------------------------------------------------===// // PromoteOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h b/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h +++ /dev/null @@ -1,77 +0,0 @@ -//===- HoistPadding.h - Hoisting for tensor::PadOp -*- C++ --------------*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_HOISTPADDING_H -#define MLIR_DIALECT_LINALG_TRANSFORMS_HOISTPADDING_H - -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -class Value; - -namespace tensor { -class PadOp; -} // namespace tensor - -namespace linalg { -class GenericOp; - -/// Mechanically hoist padding operations on tensors by `numLoops` into a new, -/// generally larger tensor. This achieves packing of multiple padding ops into -/// a larger tensor. On success, `opToHoist` is replaced by the cloned version -/// in the packing loop so the caller can continue reasoning about the padding -/// operation. If `transposeVector` is non-empty, hoist padding introduces a -/// GenericOp to transpose the padded tensor before inserting it into the packed -/// tensor. A `transposeVector` can change the storage order of the padded -/// tensor but does not change the order of the pack or compute loops. -/// -/// -/// Example in pseudo-mlir: -/// ======================= -/// -/// If hoistPaddingOnTensors is called with `nLoops` = 2 on the following IR. -/// ``` -/// scf.for (%i, %j, %k) -/// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor -/// %0 = tensor.pad %st0 low[0, 0] high[...] { -/// ^bb0( ... ): -/// linalg.yield %pad -/// } : tensor to tensor<4x8xf32> -/// compute(%0) -/// ``` -/// -/// IR resembling the following is produced: -/// -/// ``` -/// scf.for (%i) { -/// %packed_init = tensor.empty range(%j) : tensor -/// %packed = scf.for (%k) iter_args(%p : %packed_init) { -/// %st0 = tensor.extract_slice f(%i, %k) : ...
to tensor -/// %0 = tensor.pad %st0 low[0, 0] high[...] { -/// ^bb0( ... ): -/// linalg.yield %pad -/// } : tensor to tensor<4x8xf32> -/// %1 = tensor.insert_slice %0 ... -/// : tensor<4x8xf32> to tensor -/// scf.yield %1: tensor -/// } -> tensor -/// scf.for (%j, %k) { -/// %st0 = tensor.extract_slice %packed [%k, 0, 0][1, 4, 8][1, 1, 1] : -/// tensor to tensor<4x8xf32> -/// compute(%st0) -/// } -/// } -/// ``` -FailureOr hoistPaddingOnTensors( - tensor::PadOp opToHoist, int numLoops, ArrayRef transposeVector, - tensor::PadOp &hoistedOp, SmallVectorImpl &transposeOps); - -} // namespace linalg -} // namespace mlir - -#endif // MLIR_DIALECT_LINALG_TRANSFORMS_HOISTPADDING_H diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -357,14 +357,70 @@ /// shaped `paddingDimensions` and return the extracted dynamically shaped /// results. If padding fails, return failure. FailureOr> -rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, +rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, ArrayRef paddingDimensions, ArrayRef paddingValues, ArrayRef packPaddings, LinalgOp &paddedOp); -/// Apply padding to `linalgOp` -FailureOr padLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp, - LinalgPaddingOptions options); +/// Mechanically hoist padding operations on tensors by `numLoops` into a new, +/// generally larger tensor. This achieves packing of multiple padding ops into +/// a larger tensor. On success, `opToHoist` is replaced by the cloned version +/// in the packing loop so the caller can continue reasoning about the padding +/// operation. If `transposeVector` is non-empty, hoist padding introduces a +/// GenericOp to transpose the padded tensor before inserting it into the packed +/// tensor. A `transposeVector` can change the storage order of the padded +/// tensor but does not change the order of the pack or compute loops. +/// +/// TODO: In the future, we should consider rewriting as a tensor.pack after +/// hoisting since this abstraction is now available. +/// +/// Example in pseudo-mlir: +/// ======================= +/// +/// If hoistPaddingOnTensors is called with `nLoops` = 2 on the following IR. +/// ``` +/// scf.for (%i, %j, %k) +/// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor +/// %0 = tensor.pad %st0 low[0, 0] high[...] { +/// ^bb0( ... ): +/// linalg.yield %pad +/// } : tensor to tensor<4x8xf32> +/// compute(%0) +/// ``` +/// +/// IR resembling the following is produced: +/// +/// ``` +/// scf.for (%i) { +/// %packed_init = tensor.empty range(%j) : tensor +/// %packed = scf.for (%k) iter_args(%p : %packed_init) { +/// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor +/// %0 = tensor.pad %st0 low[0, 0] high[...] { +/// ^bb0( ... ): +/// linalg.yield %pad +/// } : tensor to tensor<4x8xf32> +/// %1 = tensor.insert_slice %0 ... +/// : tensor<4x8xf32> to tensor +/// scf.yield %1: tensor +/// } -> tensor +/// scf.for (%j, %k) { +/// %st0 = tensor.extract_slice %packed [%k, 0, 0][1, 4, 8][1, 1, 1] : +/// tensor to tensor<4x8xf32> +/// compute(%st0) +/// } +/// } +/// ``` +FailureOr +hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops, + ArrayRef transposeVector, + tensor::PadOp &hoistedOp, + SmallVectorImpl &transposeOps); + +/// Apply padding and hoisting to `linalgOp` according to the configuration +/// specified in `options`. 
+FailureOr padAndHoistLinalgOp(RewriterBase &rewriter, + LinalgOp linalgOp, + LinalgPaddingOptions options); /// Split the given `op` into two parts along the given iteration space /// `dimension` at the specified `splitPoint`, and return the two parts. diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -51,11 +51,13 @@ } void mlir::getForwardSlice(Operation *op, SetVector *forwardSlice, - TransitiveFilter filter) { + TransitiveFilter filter, bool inclusive) { getForwardSliceImpl(op, forwardSlice, filter); - // Don't insert the top level operation, we just queried on it and don't - // want it in the results. - forwardSlice->remove(op); + if (!inclusive) { + // Don't insert the top level operation, we just queried on it and don't + // want it in the results. + forwardSlice->remove(op); + } // Reverse to get back the actual topological order. // std::reverse does not work out of the box on SetVector and I want an @@ -65,7 +67,7 @@ } void mlir::getForwardSlice(Value root, SetVector *forwardSlice, - TransitiveFilter filter) { + TransitiveFilter filter, bool inclusive) { for (Operation *user : root.getUsers()) getForwardSliceImpl(user, forwardSlice, filter); @@ -114,27 +116,30 @@ void mlir::getBackwardSlice(Operation *op, SetVector *backwardSlice, - TransitiveFilter filter) { + TransitiveFilter filter, bool inclusive) { getBackwardSliceImpl(op, backwardSlice, filter); - // Don't insert the top level operation, we just queried on it and don't - // want it in the results. - backwardSlice->remove(op); + if (!inclusive) { + // Don't insert the top level operation, we just queried on it and don't + // want it in the results. + backwardSlice->remove(op); + } } void mlir::getBackwardSlice(Value root, SetVector *backwardSlice, - TransitiveFilter filter) { + TransitiveFilter filter, bool inclusive) { if (Operation *definingOp = root.getDefiningOp()) { - getBackwardSlice(definingOp, backwardSlice, filter); + getBackwardSlice(definingOp, backwardSlice, filter, inclusive); return; } Operation *bbAargOwner = root.cast().getOwner()->getParentOp(); - getBackwardSlice(bbAargOwner, backwardSlice, filter); + getBackwardSlice(bbAargOwner, backwardSlice, filter, inclusive); } SetVector mlir::getSlice(Operation *op, TransitiveFilter backwardFilter, - TransitiveFilter forwardFilter) { + TransitiveFilter forwardFilter, + bool inclusive) { SetVector slice; slice.insert(op); @@ -145,12 +150,12 @@ auto *currentOp = (slice)[currentIndex]; // Compute and insert the backwardSlice starting from currentOp. backwardSlice.clear(); - getBackwardSlice(currentOp, &backwardSlice, backwardFilter); + getBackwardSlice(currentOp, &backwardSlice, backwardFilter, inclusive); slice.insert(backwardSlice.begin(), backwardSlice.end()); // Compute and insert the forwardSlice starting from currentOp. 
forwardSlice.clear(); - getForwardSlice(currentOp, &forwardSlice, forwardFilter); + getForwardSlice(currentOp, &forwardSlice, forwardFilter, inclusive); slice.insert(forwardSlice.begin(), forwardSlice.end()); ++currentIndex; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1718,18 +1718,19 @@ transposePaddings.push_back( extractFromI64ArrayAttr(transposeVector.cast())); - LinalgPaddingOptions paddingOptions; - paddingOptions.setPaddingValues(paddingValues); - paddingOptions.setPaddingDimensions( - extractFromI64ArrayAttr(getPaddingDimensions())); - paddingOptions.setPackPaddings(packPaddings); - paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); - paddingOptions.setTransposePaddings(transposePaddings); - IRRewriter rewriter(target->getContext()); - FailureOr result = padLinalgOp(rewriter, target, paddingOptions); + LinalgOp paddedOp; + FailureOr> result = rewriteAsPaddedOp( + rewriter, target, extractFromI64ArrayAttr(getPaddingDimensions()), + paddingValues, packPaddings, paddedOp); if (succeeded(result)) { - results.push_back(result->getOperation()); + // We need to perform our own replacement here because this API is still + // used in patterns that "pad and hoist", for which the replacement values + // need to be different. + // TODO: clean this up and stop "pad and hoist" behavior more globally now + // that we have more composable abstractions. + rewriter.replaceOp(target, *result); + results.push_back(paddedOp); return DiagnosedSilenceableFailure::success(); } @@ -1756,15 +1757,6 @@ << getPaddingDimensions(); } - SmallVector hoistPaddings = - extractFromI64ArrayAttr(getHoistPaddings()); - if (any_of(hoistPaddings, - [](int64_t hoistPadding) { return hoistPadding < 0; })) { - return emitOpError() - << "expects hoist_paddings to contain positive integers, found " - << getHoistPaddings(); - } - ArrayAttr transposes = getTransposePaddings(); for (Attribute attr : transposes) { SmallVector transpose = extractFromI64ArrayAttr(attr); @@ -1779,6 +1771,44 @@ return success(); } +//===---------------------------------------------------------------------===// +// HoistPadOp +//===---------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::HoistPadOp::applyToOne(tensor::PadOp target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + IRRewriter rewriter(target->getContext()); + tensor::PadOp hoistedPadOp; + SmallVector transposeOps; + // TODO: Pass rewriter down to hoistPaddingOnTensors, in a followup commit. + FailureOr result = hoistPaddingOnTensors( + target, getNumLoops(), getTranspose(), hoistedPadOp, transposeOps); + if (succeeded(result)) { + // We need to perform our own replacement here because this API is still + // used in patterns that "pad and hoist", for which the replacement values + // need to be different. + // TODO: clean this up and stop "pad and hoist" behavior more globally now + // that we have more composable abstractions. 
+ rewriter.replaceOp(target, *result); + results.push_back(hoistedPadOp); + return DiagnosedSilenceableFailure::success(); + } + return emitDefaultSilenceableFailure(target); +} + +LogicalResult transform::HoistPadOp::verify() { + ArrayRef transpose = getTranspose(); + auto sequence = llvm::to_vector(llvm::seq(0, transpose.size())); + if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(), + transpose.end())) { + return emitOpError() << "expects transpose to be a permutation, found " + << getTranspose(); + } + return success(); +} + //===----------------------------------------------------------------------===// // PromoteOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -10,7 +10,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -25,7 +24,9 @@ #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dominance.h" + #include "mlir/IR/Matchers.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" @@ -165,6 +166,30 @@ return transposedTensorType; } +// Get all the ops in the backwards slice starting from `padOp` and that +// are dominated by the outermost enclosing loop. +// This also requires tracking ops defining values used in the region but +// defined above. +static void computeBackwardSlice(tensor::PadOp padOp, + scf::ForOp outermostEnclosingForOp, + SetVector &backwardSlice) { + DominanceInfo domInfo(outermostEnclosingForOp); + auto filter = [&](Operation *op) { + return domInfo.dominates(outermostEnclosingForOp, op) && + !padOp->isProperAncestor(op); + }; + // First, add the ops required to compute the region to the backwardSlice. + SetVector valuesDefinedAbove; + getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(), + valuesDefinedAbove); + for (Value v : valuesDefinedAbove) { + getBackwardSlice(v, &backwardSlice, filter, /*inclusive=*/true); + } + // Then, add the backward slice from padOp itself. + getBackwardSlice(padOp.getOperation(), &backwardSlice, filter, + /*inclusive=*/true); +} + HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) { valid = false; @@ -218,16 +243,9 @@ return; } - // Get all the ops in the backwards slice starting from `padOp` and that - // are dominated by the outermost enclosing loop. - DominanceInfo domInfo(outermostEnclosingForOp); - getBackwardSlice(padOp.getOperation(), &backwardSlice, [&](Operation *op) { - return domInfo.dominates(outermostEnclosingForOp, op); - }); - if (backwardSlice.empty()) + computeBackwardSlice(padOp, outermostEnclosingForOp, backwardSlice); + if (backwardSlice.size() <= 1) return; - // Add `padOp` itself to the backward slice. - backwardSlice.insert(padOp.getOperation()); // Remove all ops in the backward slice that are not used to index the padded // tensor. 
In particular, keep `padOp`, `sliceOp`, and the loop and @@ -394,9 +412,11 @@ ValueRange{ivVal, lbVal, stepVal}); } -FailureOr mlir::linalg::hoistPaddingOnTensors( - tensor::PadOp opToHoist, int numLoops, ArrayRef transposeVector, - tensor::PadOp &hoistedOp, SmallVectorImpl &transposeOps) { +FailureOr +mlir::linalg::hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops, + ArrayRef transposeVector, + tensor::PadOp &hoistedOp, + SmallVectorImpl &transposeOps) { LLVM_DEBUG(DBGS() << "Try to hoist " << *(opToHoist) << " by " << numLoops << " loops\n"); HoistingAnalysis analysis(opToHoist, numLoops); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -16,7 +16,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -54,7 +53,7 @@ /// dimensions `paddingDimensions` and return the tensor::PadOp result if /// padding succeeds or failure otherwise. static FailureOr padOperandToSmallestStaticBoundingBox( - OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand, + RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, ArrayRef paddingDimensions, ArrayRef paddingValues, ArrayRef packPaddings) { AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand); @@ -79,14 +78,15 @@ return opOperand->get(); // Fail if `paddingValues` specifies no padding value. - if (opOperand->getOperandNumber() >= paddingValues.size()) - return failure(); + if (opOperand->getOperandNumber() >= paddingValues.size()) { + return rewriter.notifyMatchFailure(opToPad, "no padding value specified"); + } Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()]; - Type paddingType = b.getType(); + Type paddingType = rewriter.getType(); if (auto typedAttr = paddingAttr.dyn_cast()) paddingType = typedAttr.getType(); - Value paddingValue = - b.create(opToPad.getLoc(), paddingType, paddingAttr); + Value paddingValue = rewriter.create( + opToPad.getLoc(), paddingType, paddingAttr); // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp. OpOperand *currOpOperand = opOperand; @@ -98,8 +98,14 @@ // Fail if `currOpOperand` is not defined by an ExtractSliceOp or EmptyOp. auto sliceOp = currOpOperand->get().getDefiningOp(); auto emptyOp = currOpOperand->get().getDefiningOp(); - if (!sliceOp && !emptyOp) - return failure(); + if (!sliceOp && !emptyOp) { + // TODO: may want to add support for going through loop iter args. + // This is not strictly necessary as we can pad before hoisting but it would + // make the system overall more resilient to minor transformation + // reorderings. 
+ return rewriter.notifyMatchFailure( + opToPad, "not defined by an extractSlice or emptyOp"); + } llvm::SmallBitVector droppedDims; SmallVector mixedSizes; @@ -135,8 +141,9 @@ FailureOr upperBound = getConstantUpperBoundForIndex(en.value().get()); if (failed(upperBound)) { - LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding"); - return failure(); + LLVM_DEBUG(DBGS() << "could not compute a bounding box for padding"); + return rewriter.notifyMatchFailure( + opToPad, "could not compute a bounding box for padding"); } paddedShape[shapeIdx++] = *upperBound; } @@ -146,7 +153,7 @@ // Pad the operand to the bounding box defined by `paddedShape`. auto paddedTensorType = RankedTensorType::get( paddedShape, getElementTypeOrSelf(opOperand->get())); - return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType, + return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType, opOperand->get(), paddingValue, nofold); } @@ -164,7 +171,7 @@ // rewriteAsPaddedOp transformation. //===----------------------------------------------------------------------===// FailureOr> -linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, +linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, ArrayRef paddingDimensions, ArrayRef paddingValues, ArrayRef packPaddings, LinalgOp &paddedOp) { @@ -174,15 +181,17 @@ assert(opToPad.hasTensorSemantics() && "expected operation to have tensor semantics"); - OpBuilder::InsertionGuard g(b); + OpBuilder::InsertionGuard g(rewriter); // Set IP after op because we also take the dims of the original output. - b.setInsertionPointAfter(opToPad); + rewriter.setInsertionPointAfter(opToPad); + // Make a copy of the shaped operands and update it. SmallVector newOperands; newOperands.reserve(opToPad->getNumOperands()); for (OpOperand &opOperand : opToPad->getOpOperands()) { FailureOr paddedOperand = padOperandToSmallestStaticBoundingBox( - b, opToPad, &opOperand, paddingDimensions, paddingValues, packPaddings); + rewriter, opToPad, &opOperand, paddingDimensions, paddingValues, + packPaddings); // Exit if `paddingDimensions` cannot be bounded statically. if (failed(paddedOperand)) return failure(); @@ -191,7 +200,7 @@ SmallVector> reifiedResultShapes; if (failed(cast(opToPad.getOperation()) - .reifyResultShapes(b, reifiedResultShapes))) + .reifyResultShapes(rewriter, reifiedResultShapes))) return failure(); assert(reifiedResultShapes.size() == opToPad->getNumResults() && "expected same number of results"); @@ -199,25 +208,26 @@ // Clone `opToPad` to operate on the statically padded shapes. auto resultTensorTypes = ValueRange(newOperands).take_back(opToPad.getNumDpsInits()).getTypes(); - paddedOp = clone(b, opToPad, resultTensorTypes, newOperands); + // clone **should** properly notify the rewriter. + paddedOp = clone(rewriter, opToPad, resultTensorTypes, newOperands); // Recover the slice out of the new static results. This keeps the original // linalg op around because it uses the dims of the original results.
- SmallVector paddedSubviewResults; - paddedSubviewResults.reserve(opToPad->getNumResults()); + SmallVector paddedSubtensorResults; + paddedSubtensorResults.reserve(opToPad->getNumResults()); for (const auto &en : llvm::enumerate(paddedOp->getResults())) { Value paddedResult = en.value(); int64_t resultNumber = en.index(); int64_t rank = paddedResult.getType().cast().getRank(); - SmallVector offsets(rank, b.getIndexAttr(0)); + SmallVector offsets(rank, rewriter.getIndexAttr(0)); SmallVector sizes; for (Value v : reifiedResultShapes[resultNumber]) sizes.push_back(getAsOpFoldResult(v)); - SmallVector strides(rank, b.getIndexAttr(1)); - paddedSubviewResults.push_back(b.create( + SmallVector strides(rank, rewriter.getIndexAttr(1)); + paddedSubtensorResults.push_back(rewriter.create( loc, paddedResult, offsets, sizes, strides)); } - return paddedSubviewResults; + return paddedSubtensorResults; } //===----------------------------------------------------------------------===// @@ -253,9 +263,9 @@ // pad transformation. //===----------------------------------------------------------------------===// -FailureOr mlir::linalg::padLinalgOp(RewriterBase &rewriter, - LinalgOp linalgOp, - LinalgPaddingOptions options) { +FailureOr +mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp, + LinalgPaddingOptions options) { if (!linalgOp.hasTensorSemantics()) return rewriter.notifyMatchFailure( linalgOp, "only applies to Linalg ops with tensor semantics"); @@ -753,8 +763,9 @@ return *this; } -/// Linalg padding pattern. - +/// +/// Padding pattern. +/// mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( MLIRContext *context, LinalgPaddingOptions options, PatternBenefit benefit) : OpInterfaceRewritePattern(context, benefit), @@ -762,7 +773,7 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite( LinalgOp op, PatternRewriter &rewriter) const { - return padLinalgOp(rewriter, op, options); + return padAndHoistLinalgOp(rewriter, op, options); } LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( @@ -770,8 +781,10 @@ return vectorizeCopy(rewriter, copyOp); } -/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to +/// +/// Pattern to rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to /// initialize with pad_val) and GenericOp (to copy contents). 
+/// LogicalResult PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const { diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -171,7 +171,6 @@ Sequence[Attribute]]] = None, padding_dimensions: OptionalIntList = None, pack_paddings: OptionalIntList = None, - hoist_paddings: OptionalIntList = None, transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[ ArrayAttr, IntOrAttrList]]]] = None, loc=None, @@ -180,7 +179,6 @@ padding_values_attr = _get_array_attr(padding_values) padding_dimensions_attr = _get_int_array_attr(padding_dimensions) pack_paddings_attr = _get_int_array_attr(pack_paddings) - hoist_paddings_attr = _get_int_array_attr(hoist_paddings) transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) super().__init__( pdl_operation_type, @@ -188,7 +186,6 @@ padding_values=padding_values_attr, padding_dimensions=padding_dimensions_attr, pack_paddings=pack_paddings_attr, - hoist_paddings=hoist_paddings_attr, transpose_paddings=transpose_paddings_attr, loc=loc, ip=ip) diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir @@ -0,0 +1,151 @@ +// RUN: mlir-opt --test-transform-dialect-interpreter -canonicalize -split-input-file --verify-diagnostics %s | FileCheck %s + +func.func @pad_and_hoist_rhs( + %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>) + -> tensor<24x25xf32> +{ + // expected-note @below {{payload operation}} + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + func.return %0 : tensor<24x25xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + + + %matmul_l1, %loops_l1 = transform.structured.tile_to_scf_for %matmul [5] + + %matmul_padded = transform.structured.pad %matmul_l1 { + padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1, 2] + } + + // In this case, the pad op is actually empty: we only tile the first dimension + // and it does not have an impact on the RHS operand. + // expected-error @below {{incompatible payload operation name}} + %pad = transform.get_producer_of_operand %matmul_padded[1] + : (!pdl.operation) -> !transform.op<"tensor.pad"> + + // We do not even reach this transform op. 
+ transform.structured.hoist_pad %pad by 1 loops + : (!transform.op<"tensor.pad">) -> !pdl.operation +} + +// ----- + +func.func @pad_and_hoist_init( + %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>) + -> tensor<24x25xf32> +{ + // expected-note @below {{when applied to this op}} + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + func.return %0 : tensor<24x25xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + + + %matmul_l1, %loops_l1 = transform.structured.tile_to_scf_for %matmul [5] + + %matmul_padded = transform.structured.pad %matmul_l1 { + padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1, 2] + } + + %pad = transform.get_producer_of_operand %matmul_padded[2] + : (!pdl.operation) -> !transform.op<"tensor.pad"> + + // We do not know yet how to hoist the init. + // expected-error @below {{transform.structured.hoist_pad failed to apply}} + transform.structured.hoist_pad %pad by 1 loops + : (!transform.op<"tensor.pad">) -> !pdl.operation +} + +// ----- + +// CHECK-LABEL: pad_and_hoist_lhs +func.func @pad_and_hoist_lhs( + %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>) + -> tensor<24x25xf32> +{ + // CHECK: %[[PACKED:.*]] = scf.for %{{.*}} -> (tensor<5x5x12xf32>) { + // CHECK: tensor.pad %{{.*}} + // CHECK: : tensor to tensor<5x12xf32> + // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0, 0] [1, 5, 12] [1, 1, 1] + // CHECK-SAME: : tensor<5x12xf32> into tensor<5x5x12xf32> + // CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) { + // CHECK: %[[PADDED:.*]] = tensor.extract_slice %[[PACKED]][%{{.*}}, 0, 0] [1, 5, 12] [1, 1, 1] + // CHECK-SAME: : tensor<5x5x12xf32> to tensor<5x12xf32> + // CHECK: linalg.matmul ins(%[[PADDED]] + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + func.return %0 : tensor<24x25xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + + + %matmul_l1, %loops_l1 = transform.structured.tile_to_scf_for %matmul [5] + + %matmul_padded = transform.structured.pad %matmul_l1 { + padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1, 2] + } + + %pad = transform.get_producer_of_operand %matmul_padded[0] + : (!pdl.operation) -> !pdl.operation + + transform.structured.hoist_pad %pad by 1 loops + : (!pdl.operation) -> !pdl.operation +} + +// ----- + +// CHECK-LABEL: pad_and_hoist_lhs_transpose +func.func @pad_and_hoist_lhs_transpose( + %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>) + -> tensor<24x25xf32> +{ + // CHECK: %[[PACKED:.*]] = scf.for %{{.*}} -> (tensor<5x12x5xf32>) { + // CHECK: tensor.pad %{{.*}} + // CHECK: : tensor to tensor<5x12xf32> + // CHECK: linalg.generic + // CHECK: -> tensor<12x5xf32> + // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0, 0] [1, 12, 5] [1, 1, 1] + // CHECK-SAME: : tensor<12x5xf32> into tensor<5x12x5xf32> + // CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) { + // CHECK: %[[PADDED:.*]] = tensor.extract_slice %[[PACKED]][%{{.*}}, 0, 0] [1, 12, 5] [1, 1, 1] + // CHECK-SAME: : tensor<5x12x5xf32> to tensor<12x5xf32> + // CHECK: %[[TRANSPOSED:.*]] = linalg.generic + 
// CHECK: -> tensor<5x12xf32> + // CHECK: linalg.matmul ins(%[[TRANSPOSED]] + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + func.return %0 : tensor<24x25xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + + + %matmul_l1, %loops_l1 = transform.structured.tile_to_scf_for %matmul [5] + + %matmul_padded = transform.structured.pad %matmul_l1 { + padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1, 2] + } + + %pad = transform.get_producer_of_operand %matmul_padded[0] + : (!pdl.operation) -> !pdl.operation + + transform.structured.hoist_pad %pad by 1 loops, transpose by [1, 0] + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir --- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir @@ -34,7 +34,11 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]} + %1 = transform.structured.pad %0 { + padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1, 2], + pack_paddings=[1, 1, 0] + } } // ----- @@ -72,7 +76,11 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]} + %1 = transform.structured.pad %0 { + padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1, 2], + pack_paddings=[1, 1, 0] + } } // ----- @@ -89,7 +97,11 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation // expected-error @below {{op expects a padding value of type 'f32', got 0 : i32}} - %1 = transform.structured.pad %0 {padding_values=[0: i32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]} + %1 = transform.structured.pad %0 { + padding_values=[0: i32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1, 2], + pack_paddings=[1, 1, 0] + } } // ----- @@ -106,7 +118,11 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation // expected-error @below {{expects a padding that parses to 'f32', got "foo"}} - %1 = transform.structured.pad %0 {padding_values=["foo", 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]} + %1 = transform.structured.pad %0 { + padding_values=["foo", 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1, 2], + pack_paddings=[1, 1, 0] + } } // ----- @@ -125,5 +141,9 @@ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation // This error is silenceable and is not reported by this transform // {{transform.structured.pad failed to apply}} - %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]} + %1 = transform.structured.pad %0 { + 
padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1, 2], + pack_paddings=[1, 1, 0] + } } diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir --- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir @@ -24,14 +24,6 @@ // ----- -transform.sequence failures(propagate) { -^bb0(%arg0: !pdl.operation): - // expected-error@below {{expects hoist_paddings to contain positive integers, found [1, -7]}} - transform.structured.pad %arg0 {hoist_paddings=[1, -7]} -} - -// ----- - transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}} diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -17,7 +17,6 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -84,7 +84,7 @@ # CHECK-DAG: padding_values = [4.200000e+01 : f32] # CHECK-DAG: padding_dimensions = [1] # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]] - # (hoist_paddings and pack_paddings have default values) + # (pack_paddings has default values) @run def testScalarize():
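
For illustration, here is a minimal C++ sketch of how the decomposed entry points exposed by this patch (`rewriteAsPaddedOp` and `hoistPaddingOnTensors`) compose into a "pad then hoist" flow, mirroring the `transform.structured.pad` + `transform.structured.hoist_pad` pair exercised in the tests above. This is a sketch and not part of the patch: the helper name `padThenHoist` is hypothetical, and the angle-bracket template arguments are assumptions reconstructed from context (they are elided in the signatures as rendered above).

```
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: pad `op` along `paddingDimensions`, then hoist each
// resulting tensor.pad out of `numLoops` enclosing scf.for loops.
static LogicalResult padThenHoist(RewriterBase &rewriter, linalg::LinalgOp op,
                                  ArrayRef<int64_t> paddingDimensions,
                                  ArrayRef<Attribute> paddingValues,
                                  ArrayRef<bool> packPaddings,
                                  int64_t numLoops) {
  // Step 1: pad to a static bounding box and replace the original op with the
  // slices extracted from the padded results.
  linalg::LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = linalg::rewriteAsPaddedOp(
      rewriter, op, paddingDimensions, paddingValues, packPaddings, paddedOp);
  if (failed(newResults))
    return failure();
  rewriter.replaceOp(op, *newResults);

  // Step 2: best-effort hoisting of every tensor.pad now feeding the padded op.
  for (OpOperand &operand : paddedOp->getOpOperands()) {
    auto padOp = operand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      continue;
    tensor::PadOp hoistedOp;
    SmallVector<linalg::GenericOp> transposeOps;
    FailureOr<Value> replacement = linalg::hoistPaddingOnTensors(
        padOp, numLoops, /*transposeVector=*/{}, hoistedOp, transposeOps);
    if (failed(replacement))
      continue; // Leave this pad in place if the hoisting analysis does not apply.
    rewriter.replaceOp(padOp, *replacement);
  }
  return success();
}
```

The point of the split, which this sketch assumes, is that padding and hoisting can now be applied and tested independently instead of going through the former monolithic `padLinalgOp`/`hoist_paddings` path.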