diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h b/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h
@@ -19,12 +19,17 @@
 } // namespace tensor

 namespace linalg {
+class GenericOp;

 /// Mechanically hoist padding operations on tensors by `numLoops` into a new,
 /// generally larger tensor. This achieves packing of multiple padding ops into
-/// a larger tensor. On success, `padTensorOp` is replaced by the cloned version
+/// a larger tensor. On success, `opToHoist` is replaced by the cloned version
 /// in the packing loop so the caller can continue reasoning about the padding
-/// operation.
+/// operation. If `transposeVector` is non-empty, hoist padding introduces a
+/// GenericOp to transpose the padded tensor before inserting it into the packed
+/// tensor. A `transposeVector` can change the storage order of the padded
+/// tensor but does not change the order of the pack or compute loops.
+///
 ///
 /// Example in pseudo-mlir:
 /// =======================
@@ -33,7 +38,7 @@
 /// ```
 /// scf.for (%i, %j, %k)
 ///   %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
-///   %0 = linalg.pad_tensor %st0 low[0, 0] high[...] {
+///   %0 = tensor.pad %st0 low[0, 0] high[...] {
 ///   ^bb0( ... ):
 ///     linalg.yield %pad
 ///   } : tensor<?x?xf32> to tensor<4x8xf32>
@@ -47,7 +52,7 @@
 ///   %packed_init = linalg.init_tensor range(%j) : tensor<?x4x8xf32>
 ///   %packed = scf.for (%k) iter_args(%p : %packed_init) {
 ///     %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
-///     %0 = linalg.pad_tensor %st0 low[0, 0] high[...] {
+///     %0 = tensor.pad %st0 low[0, 0] high[...] {
 ///     ^bb0( ... ):
 ///       linalg.yield %pad
 ///     } : tensor<?x?xf32> to tensor<4x8xf32>
@@ -62,8 +67,9 @@
 ///   }
 /// }
 /// ```
-FailureOr<Value> hoistPaddingOnTensors(tensor::PadOp opToHoist, int numLoops,
-                                       tensor::PadOp &hoistedOp);
+FailureOr<Value> hoistPaddingOnTensors(
+    tensor::PadOp opToHoist, int numLoops, ArrayRef<int64_t> transposeVector,
+    tensor::PadOp &hoistedOp, SmallVectorImpl<GenericOp> &transposeOps);

 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -484,14 +484,19 @@
 using PaddingValueComputationFunction =
     std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>;

-/// Callback returning true if the pad tensor operation defining the given
-/// OpOperand shall be marked as nofold to enable packing.
+/// Callback returning true if the PadOp defining the given OpOperand shall be
+/// marked as nofold to enable packing.
 using PaddingNoFoldComputationFunction = std::function<bool(OpOperand &)>;

-/// Callback returning the number of loops to hoist the pad tensor operation
-/// defining the given OpOperand.
+/// Callback returning the number of loops to hoist the PadOp defining the given
+/// OpOperand.
 using PaddingHoistComputationFunction = std::function<int64_t(OpOperand &)>;

+/// Callback returning the transpose vector used to permute the result tensor
+/// dimensions of the PadOp defining the given OpOperand.
+using PaddingTransposeComputationFunction =
+    std::function<SmallVector<int64_t>(OpOperand &)>;
+
 struct LinalgPaddingOptions {
   /// Callback returning the padding value to use for a given OpOperand or
   /// failure for no padding. Padding operations are introduced if
@@ -506,10 +511,10 @@
     return *this;
   }

-  /// Callback returning true if the pad tensor operation defining the given
-  /// OpOperand shall be marked as nofold to enable packing. A padding operation
-  /// is only marked nofold if `paddingNoFoldComputationFunction` is set and
-  /// returns true. Otherwise, the nofold attribute is set to false.
+  /// Callback returning true if the PadOp defining the given OpOperand shall be
+  /// marked as nofold to enable packing. A padding operation is only marked
+  /// nofold if `paddingNoFoldComputationFunction` is set and returns true.
+  /// Otherwise, the nofold attribute is set to false.
   PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr;

   LinalgPaddingOptions &
@@ -518,8 +523,8 @@
     return *this;
   }

-  /// Callback returning the number of loops to hoist the pad tensor operation
-  /// defining the given OpOperand.
+  /// Callback returning the number of loops to hoist the PadOp defining the
+  /// given OpOperand.
   PaddingHoistComputationFunction paddingHoistComputationFunction = nullptr;

   LinalgPaddingOptions &
@@ -527,6 +532,17 @@
     paddingHoistComputationFunction = std::move(fun);
     return *this;
   }
+
+  /// Callback returning the transpose vector used to permute the result tensor
+  /// dimensions of the PadOp defining the given OpOperand.
+  PaddingTransposeComputationFunction paddingTransposeComputationFunction =
+      nullptr;
+
+  LinalgPaddingOptions &setPaddingTransposeComputationFunction(
+      PaddingTransposeComputationFunction fun) {
+    paddingTransposeComputationFunction = std::move(fun);
+    return *this;
+  }
 };

 struct LinalgTilingAndFusionOptions {
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -117,18 +117,24 @@
 /// Example:
 /// ```
 /// %0 = tensor.extract_slice %arg0 [%iv0, %iv1] [%sz0, %sz1]
-/// %1 = linalg.pad_tensor %0 low[0, 0] high[...] { linalg.yield %cst }
+/// %1 = tensor.pad %0 low[0, 0] high[...] { tensor.yield %cst }
 /// %2 = linalg.matmul ins(...) outs(%1)
 /// %3 = tensor.extract_slice %2 [0, 0] [%sz0, %sz1]
 /// ```
 /// makeComposedPadHighOp(source=%3, pad=%cst) returns %2
 /// makeComposedPadHighOp(source=%3, pad=%other_cst) returns %4
 /// ```
-/// %4 = linalg.pad_tensor %3 low[0, 0] high[...] { linalg.yield %other_cst }
+/// %4 = tensor.pad %3 low[0, 0] high[...] { tensor.yield %other_cst }
 /// ```
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold);

+/// Returns a GenericOp that transposes `inputTensor` into `outputTensor` using
+/// `transposeVector` to permute the `inputTensor` dimensions.
+GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
+                          Value outputTensor,
+                          ArrayRef<int64_t> transposeVector);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -53,32 +53,32 @@
 /// 8. There is no enclosing scf::ForOp that indexes the padded data.
 /// Other cases succeed and will trigger hoisting of the pad op.
struct HoistingAnalysis { - HoistingAnalysis(tensor::PadOp padTensorOp, int numLoops); + HoistingAnalysis(tensor::PadOp padOp, int numLoops); bool isValid() { return valid; } /// Footprint of the packedTensor, computed from the packingLoops. SmallVector getPackedTensorSizes(ImplicitLocOpBuilder &b); - /// The outermost loop, determined by `nLevels` above which `padTensorOp` will + /// The outermost loop, determined by `nLevels` above which `padOp` will /// be hoisted. scf::ForOp outermostEnclosingForOp; - /// Backward slice rooted at `padTensorOp` and nested under + /// Backward slice rooted at `padOp` and nested under /// `outermostEnclosingForOp`. SetVector backwardSlice; - /// The scf::ForOp immediately enclosing `padTensorOp` such that: + /// The scf::ForOp immediately enclosing `padOp` such that: /// 1. they are nested under `outermostEnclosingForOp` (inclusive) /// 2. whose induction variable is used, directly or indirectly, in the - /// computation of `padTensorOp`. + /// computation of `padOp`. /// The span of these loops determines the footprint of the packed tensor. SmallVector packingLoops; private: - /// Drop any non-index dependencies of `padTensorOp` and `sliceOp` from + /// Drop any non-index dependencies of `padOp` and `sliceOp` from /// `backwardSlice`. The method follows the use-def chains of the index - /// operands consumed by `padTensorOp` and `sliceOp` and drops the operations + /// operands consumed by `padOp` and `sliceOp` and drops the operations /// not part of this index computation. Afterwards, the filtered /// `backwardSlice` contains only the loops whose induction variable is used, /// directly or indirectly, to index the padded tensor. The method returns @@ -94,24 +94,24 @@ /// %ubi = affine.min #map(%i) /// %ubj = affine.min #map(%j) /// %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj] - /// %padded_slice = linalg.pad_tensor %slice + /// %padded_slice = tensor.pad %slice /// ``` /// dropNonIndexDependencies(%padded_slice, %slice) /// removes [scf.for %k, linalg.fill(%cst, %arg1)] from backwardSlice. - LogicalResult dropNonIndexDependencies(tensor::PadOp padTensorOp, + LogicalResult dropNonIndexDependencies(tensor::PadOp padOp, tensor::ExtractSliceOp sliceOp); /// Encodes whether the analysis is valid and hoisting can proceed. bool valid; }; -/// Return true if all uses of `padTensorOp` are an input tensor of some +/// Return true if all uses of `padOp` are an input tensor of some /// LinalgOp. -static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padTensorOp) { - for (OpOperand &use : padTensorOp.result().getUses()) { +static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padOp) { + for (OpOperand &use : padOp.result().getUses()) { auto linalgUser = dyn_cast(use.getOwner()); if (!linalgUser || !linalgUser.isInputTensor(&use)) { - LLVM_DEBUG(DBGS() << "Found a use of " << *(padTensorOp) + LLVM_DEBUG(DBGS() << "Found a use of " << *(padOp) << "\nthat is not an input tensor of a LinalgOp, " << "cannot hoist\n" << *(use.getOwner()) << "\n"); @@ -126,12 +126,12 @@ /// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm. /// Control-flow and other containing ops with regions are not modeled atm. 
static void -getAtMostNEnclosingLoops(tensor::PadOp padTensorOp, int nLevels, +getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels, SmallVector &reverseEnclosingLoops) { - AsmState state(padTensorOp->getParentOfType()); + AsmState state(padOp->getParentOfType()); (void)state; scf::ForOp outermostEnclosingForOp = nullptr; - Operation *nextEnclosingOp = padTensorOp->getParentOp(); + Operation *nextEnclosingOp = padOp->getParentOp(); while (nLevels-- > 0 && (outermostEnclosingForOp = dyn_cast(nextEnclosingOp))) { LLVM_DEBUG( @@ -143,17 +143,38 @@ } } -HoistingAnalysis::HoistingAnalysis(tensor::PadOp padTensorOp, int numLoops) { +/// Returns the transposed `rankedTensorType` if `transposeVector` is non-empty. +/// Fail if `transposeVector` is no permutation matching the tensor rank. +static FailureOr +computeTransposedType(RankedTensorType rankedTensorType, + ArrayRef transposeVector) { + if (transposeVector.empty()) + return rankedTensorType; + if (!isPermutation(transposeVector) || + transposeVector.size() != static_cast(rankedTensorType.getRank())) + return failure(); + + SmallVector transposedShape(rankedTensorType.getShape().begin(), + rankedTensorType.getShape().end()); + applyPermutationToVector(transposedShape, transposeVector); + + using RTTBuilder = RankedTensorType::Builder; + RankedTensorType transposedTensorType = + RTTBuilder(rankedTensorType).setShape(transposedShape); + return transposedTensorType; +} + +HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) { valid = false; - // Bail on any use that isn't an input of a Linalg op. + // Bail on any use that isn't an input of a LinalgOp. // Hoisting of inplace updates happens after vectorization. - if (!isOnlyUsedAsInputOfLinalgOp(padTensorOp)) + if (!isOnlyUsedAsInputOfLinalgOp(padOp)) return; // Get at most `numLoops` of immediately enclosing loops. SmallVector reverseEnclosingLoops; - getAtMostNEnclosingLoops(padTensorOp, numLoops, reverseEnclosingLoops); + getAtMostNEnclosingLoops(padOp, numLoops, reverseEnclosingLoops); if (reverseEnclosingLoops.empty()) { LLVM_DEBUG(DBGS() << "No immediately enclosing loop -> skip\n"); return; @@ -161,7 +182,7 @@ outermostEnclosingForOp = reverseEnclosingLoops.back(); - // Get the `sliceOp` that defines the source tensor of `padTensorOp` and + // Get the `sliceOp` that defines the source tensor of `padOp` and // check its source is defined outside of the outermost loop. This check // ensures the padded data is available for packing before entering the // outermost enclosing loop. @@ -174,9 +195,9 @@ // scf.for %j // scf.for %k // %slice = tensor.extract_slice %source [%i, %j] - // %padded_slice = linalg.pad_tensor %slice + // %padded_slice = tensor.pad %slice // ``` - auto sliceOp = padTensorOp.source().getDefiningOp(); + auto sliceOp = padOp.source().getDefiningOp(); if (!sliceOp) { LLVM_DEBUG(DBGS() << "Cannot find the extract slice op -> skip\n"); return; @@ -186,32 +207,31 @@ return; } - // Check the region of `padTensorOp` depends on a constant only. Adding + // Check the region of `padOp` depends on a constant only. Adding // hoisting support for arbitrary padding regions would require cloning all // dependencies captured by the padding region. 
- Value paddingValue = padTensorOp.getConstantPaddingValue(); + Value paddingValue = padOp.getConstantPaddingValue(); if (!paddingValue || !isa_and_nonnull(paddingValue.getDefiningOp())) { LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> skip\n"); return; } - // Get all the ops in the backwards slice starting from `padTensorOp` and that + // Get all the ops in the backwards slice starting from `padOp` and that // are dominated by the outermost enclosing loop. DominanceInfo domInfo(outermostEnclosingForOp); - getBackwardSlice(padTensorOp.getOperation(), &backwardSlice, - [&](Operation *op) { - return domInfo.dominates(outermostEnclosingForOp, op); - }); + getBackwardSlice(padOp.getOperation(), &backwardSlice, [&](Operation *op) { + return domInfo.dominates(outermostEnclosingForOp, op); + }); if (backwardSlice.empty()) return; - // Add `padTensorOp` itself to the backward slice. - backwardSlice.insert(padTensorOp.getOperation()); + // Add `padOp` itself to the backward slice. + backwardSlice.insert(padOp.getOperation()); // Remove all ops in the backward slice that are not used to index the padded - // tensor. In particular, keep `padTensorOp`, `sliceOp`, and the loop and + // tensor. In particular, keep `padOp`, `sliceOp`, and the loop and // affine operations used for the index computation. - if (failed(dropNonIndexDependencies(padTensorOp, sliceOp))) + if (failed(dropNonIndexDependencies(padOp, sliceOp))) return; // Add only the loops part of the filtered `backwardSlice` to the packing @@ -232,7 +252,7 @@ } LogicalResult -HoistingAnalysis::dropNonIndexDependencies(tensor::PadOp padTensorOp, +HoistingAnalysis::dropNonIndexDependencies(tensor::PadOp padOp, tensor::ExtractSliceOp sliceOp) { // Set of all values used for index computation. SetVector indexEdges; @@ -252,7 +272,7 @@ }); }; - // Starting from `padTensorOp` and `sliceOp` walk the use-def edges of index + // Starting from `padOp` and `sliceOp` walk the use-def edges of index // type in `backwardSlice`. Add the index operands of an operation to // `indexEdges` and remove all operations from `backwardSlice` that are not // part of the index computation. @@ -267,16 +287,16 @@ // %ubi = affine.min #map(%i) // %ubj = affine.min #map(%j) // %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj] - // %padded_slice = linalg.pad_tensor %slice + // %padded_slice = tensor.pad %slice // ``` // After iterating `backwardSlice` we obtain: // indexEdges = [%i, %j, %ubi, %ubj] // backwardSlice = backwardSlice / [linalg.fill(%cst, %arg1), scf.for %k] SetVector operationsToRemove; for (Operation *op : llvm::reverse(backwardSlice)) { - // Add the index operands of `padTensorOp` and `sliceOp` to start the + // Add the index operands of `padOp` and `sliceOp` to start the // exploration of the index computation. - if (op == padTensorOp || op == sliceOp) { + if (op == padOp || op == sliceOp) { addIndexOperandsToIndexEdges(op); continue; } @@ -310,7 +330,7 @@ continue; } // Remove all other operations not used by the index computation. An - // exception are constant operations that may be used by `padTensorOp`. + // exception are constant operations that may be used by `padOp`. 
if (!isa(op)) operationsToRemove.insert(op); } @@ -373,9 +393,9 @@ ValueRange{ivVal, lbVal, stepVal}); } -FailureOr mlir::linalg::hoistPaddingOnTensors(tensor::PadOp opToHoist, - int numLoops, - tensor::PadOp &hoistedOp) { +FailureOr mlir::linalg::hoistPaddingOnTensors( + tensor::PadOp opToHoist, int numLoops, ArrayRef transposeVector, + tensor::PadOp &hoistedOp, SmallVectorImpl &transposeOps) { LLVM_DEBUG(DBGS() << "Try to hoist " << *(opToHoist) << " by " << numLoops << " loops\n"); HoistingAnalysis analysis(opToHoist, numLoops); @@ -396,14 +416,20 @@ RankedTensorType paddedTensorType = opToHoist.getResultType(); int paddedRank = paddedTensorType.getRank(); - // Create the packed tensor into which we amortize + // Compute the type of the transposed padded tensor. + FailureOr transposedTensorType = + computeTransposedType(paddedTensorType, transposeVector); + if (failed(transposedTensorType)) + return failure(); + + // Create the packed tensor into which we amortize // padding. SmallVector packedShape(nPackedLoops, ShapedType::kDynamicSize); // TODO: go grab dims when necessary, for now tensor::PadOp returns a static // tensor. - llvm::append_range(packedShape, paddedTensorType.getShape()); - auto packedTensorType = - RankedTensorType::get(packedShape, paddedTensorType.getElementType()); + llvm::append_range(packedShape, transposedTensorType->getShape()); + auto packedTensorType = RankedTensorType::get( + packedShape, transposedTensorType->getElementType()); Value packedTensor = b.create( loc, dynamicTensorSizes, packedTensorType.getShape(), packedTensorType.getElementType()); @@ -413,9 +439,10 @@ // The implementation proceeds in a stack-like fashion: // 1. Iteratively clone and step into the loops, pushing the `packedTensor` // deeper in the stack. - // 2. Create a InsertSliceOp at the top of the stack. - // 3. Iteratively pop and yield the result of the InsertSliceOp across - // the cloned loops. + // 2. Create a GenericOp if `transposeVector` is non-empty. + // 3. Create a InsertSliceOp at the top of the stack. + // 4. Iteratively pop and yield the result of the InsertSliceOp across + // the cloned loops. SmallVector clonedLoopIvs, leadingPackedTensorIndexings; clonedLoopIvs.reserve(nPackedLoops); leadingPackedTensorIndexings.reserve(nPackedLoops); @@ -455,16 +482,14 @@ packedTensor = clonedForOp.getRegionIterArgs().front(); } - // Stack step 2. create InsertSliceOp at the top of the stack. // offsets = [clonedLoopIvs, 0 .. 0]. SmallVector offsets(leadingPackedTensorIndexings.begin(), leadingPackedTensorIndexings.end()); offsets.append(paddedRank, b.getIndexAttr(0)); - // sizes = [1 .. 1, paddedShape]. + // sizes = [1 .. 1, transposedShape]. SmallVector sizes(nPackedLoops, b.getIndexAttr(1)); - for (int64_t sz : paddedTensorType.getShape()) { + for (int64_t sz : transposedTensorType->getShape()) { // TODO: go grab dims when necessary, for now tensor::PadOp returns a static - // tensor. assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes"); sizes.push_back(b.getIndexAttr(sz)); } @@ -472,11 +497,21 @@ SmallVector strides(nPackedLoops + paddedRank, b.getIndexAttr(1)); - Value inserted = - b.create(loc, bvm.lookup(opToHoist.result()), - packedTensor, offsets, sizes, strides); + // Stack step 2. create GenericOp if `transposeVector` is non-empty. 
+ Value paddedTensor = bvm.lookup(opToHoist.result()); + if (!transposeVector.empty()) { + Value outputTensor = b.create( + loc, *transposedTensorType, packedTensor, offsets, sizes, strides); + transposeOps.push_back( + makeTransposeOp(b, loc, paddedTensor, outputTensor, transposeVector)); + paddedTensor = transposeOps.back()->getResult(0); + } - // Stack step 3. iteratively pop the stack and propagate the yield. + // Stack step 3. create InsertSliceOp at the top of the stack. + Value inserted = b.create( + loc, paddedTensor, packedTensor, offsets, sizes, strides); + + // Stack step 4. iteratively pop the stack and propagate the yield. Value valueToYield = inserted; for (Value iv : llvm::reverse(clonedLoopIvs)) { auto forOp = scf::getForInductionVarOwner(iv); @@ -498,12 +533,22 @@ // offsets = [originalLoopIvs, 0 .. 0]. offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end()); offsets.append(paddedRank, b.getIndexAttr(0)); - // sizes = [1 .. 1, paddedShape] (definedabove). + // sizes = [1 .. 1, transposedShape] (definedabove). // strides = [1 .. 1] (defined above) packedTensor = scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0); Value newResult = b.create( - loc, opToHoist.getResultType(), packedTensor, offsets, sizes, strides); + loc, *transposedTensorType, packedTensor, offsets, sizes, strides); + + // Transpose the packed tensor back to the original storage order. + if (!transposeVector.empty()) { + Value initTensor = + b.create(loc, ValueRange{}, paddedTensorType.getShape(), + paddedTensorType.getElementType()); + transposeOps.push_back( + makeTransposeOp(b, loc, newResult, initTensor, transposeVector)); + newResult = transposeOps.back()->getResult(0); + } // Make the newly cloned `opToHoist` available to the caller. hoistedOp = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -528,20 +528,29 @@ // Hoist the padding. for (const auto &en : enumerate(depths)) { OpOperand &opOperand = paddedOp->getOpOperand(en.index()); - auto padTensorOp = opOperand.get().getDefiningOp(); - if (!padTensorOp || en.value() == 0) + auto padOp = opOperand.get().getDefiningOp(); + if (!padOp || en.value() == 0) continue; tensor::PadOp hoistedOp; - FailureOr newResult = - hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp); + SmallVector transposeOps; + SmallVector transposeVector = + options.paddingTransposeComputationFunction(opOperand); + + FailureOr newResult = hoistPaddingOnTensors( + padOp, en.value(), transposeVector, hoistedOp, transposeOps); if (failed(newResult)) continue; - rewriter.replaceOp(padTensorOp, newResult.getValue()); + rewriter.replaceOp(padOp, newResult.getValue()); + + // Do not apply hoist padding to the newly introduced transpose operations. + for (GenericOp transposeOp : transposeOps) + filter.replaceLinalgTransformationFilter(rewriter, transposeOp); } // Replace the original operation to pad. rewriter.replaceOp(linalgOp, newResults.getValue()); filter.replaceLinalgTransformationFilter(rewriter, paddedOp); + return paddedOp; } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -340,11 +340,11 @@ OpResult opResult = current.cast(); current = linalgOp.getOutputOperand(opResult.getResultNumber())->get(); } - auto padTensorOp = current ? 
current.getDefiningOp<tensor::PadOp>() : nullptr;
+  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

   // Exit if the search fails to match a tensor::PadOp at the end of the matched
   // LinalgOp sequence.
-  if (!padTensorOp)
+  if (!padOp)
     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

   // Exit if the padded result type does not match.
@@ -352,41 +352,77 @@
     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

   // Exit if the LinalgOps are not high padded.
-  if (llvm::any_of(padTensorOp.getMixedLowPad(), [](OpFoldResult ofr) {
+  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
         return getConstantIntValue(ofr) != static_cast<int64_t>(0);
       }))
     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

-  // Exit if `padTensorOpSliceOp`, which defines the slice used by
-  // `padTensorOp`, is rank-reducing.
-  auto padTensorOpSliceOp =
-      padTensorOp.source().getDefiningOp<tensor::ExtractSliceOp>();
-  if (!padTensorOpSliceOp || sliceOp.getMixedSizes().size() !=
-                                 padTensorOpSliceOp.getMixedSizes().size())
+  // Exit if `padOpSliceOp`, which defines the slice used by
+  // `padOp`, is rank-reducing.
+  auto padOpSliceOp = padOp.source().getDefiningOp<tensor::ExtractSliceOp>();
+  if (!padOpSliceOp ||
+      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

   // Exit if the sizes of the dynamic sizes of `sliceOp` do not match the size
-  // of the slice padded by `padTensorOp`.
-  if (llvm::any_of(llvm::zip(sliceOp.getMixedSizes(),
-                             padTensorOpSliceOp.getMixedSizes()),
-                   [](std::tuple<OpFoldResult, OpFoldResult> it) {
-                     return !isEqualConstantIntOrValue(std::get<0>(it),
-                                                       std::get<1>(it));
-                   }))
+  // of the slice padded by `padOp`.
+  if (llvm::any_of(
+          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
+          [](std::tuple<OpFoldResult, OpFoldResult> it) {
+            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
+          }))
     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

   // Exit if the padding values do not match.
-  Attribute padTensorOpPadAttr, padAttr;
-  Value padTensorOpPad = padTensorOp.getConstantPaddingValue();
-  if (!padTensorOpPad ||
-      !matchPattern(padTensorOpPad, m_Constant(&padTensorOpPadAttr)) ||
-      !matchPattern(pad, m_Constant(&padAttr)) || padTensorOpPadAttr != padAttr)
+  Attribute padOpPadAttr, padAttr;
+  Value padOpPad = padOp.getConstantPaddingValue();
+  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
+      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

   // Return the padded result if the padding values and sizes match.
   return sliceOp.source();
 }

+GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
+                          Value outputTensor,
+                          ArrayRef<int64_t> transposeVector) {
+  auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
+  Type elementType = resultTensorType.getElementType();
+
+  assert(isPermutation(transposeVector) &&
+         "expect transpose vector to be a permutation");
+  assert(transposeVector.size() ==
+             static_cast<size_t>(resultTensorType.getRank()) &&
+         "expect transpose vector size to match result tensor rank");
+
+  // Compute the transpose and the identity indexing maps.
+ SmallVector indexingMaps = { + inversePermutation(AffineMap::getPermutationMap( + SmallVector(transposeVector.begin(), transposeVector.end()), + b.getContext())), + AffineMap::getMultiDimIdentityMap(transposeVector.size(), + b.getContext())}; + SmallVector iteratorTypes(transposeVector.size(), + getParallelIteratorTypeName()); + + // Create a GenericOp to transpose `inputTensor` into `outputTensor`. + auto transposeOp = b.create( + loc, resultTensorType, inputTensor, outputTensor, + b.getAffineMapArrayAttr(indexingMaps), b.getStrArrayAttr(iteratorTypes), + /*doc=*/nullptr, + /*library_call=*/nullptr); + Region &body = transposeOp.getRegion(); + body.push_back(new Block()); + body.front().addArguments({elementType, elementType}, {loc, loc}); + + // Create the body of the transpose operation. + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToEnd(&body.front()); + b.create(loc, transposeOp.getRegion().front().getArgument(0)); + return transposeOp; +} + /// Specialization to build an scf "for" nest. template <> void GenerateLoopNest::doit( diff --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir --- a/mlir/test/Dialect/Linalg/hoist-padding.mlir +++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matvec pad hoist-paddings=1,1,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=MATVEC +// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matvec pad hoist-paddings=1,1,0 transpose-paddings=1:0,0,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=TRANSP // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad hoist-paddings=1,2,1 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=MATMUL // MATVEC-DAG: #[[DIV4:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 4)> @@ -30,7 +31,7 @@ // MATVEC-DAG: %[[T4:.*]] = tensor.extract_slice %[[T0]][%[[IDX0]] %2 = tensor.extract_slice %arg1[%arg3] [4] [1] : tensor<12xf32> to tensor<4xf32> %3 = tensor.pad %2 nofold low[%c0] high[%c0] { - ^bb0(%arg5: index): + ^bb0(%arg5: index): tensor.yield %cst : f32 } : tensor<4xf32> to tensor<4xf32> @@ -81,11 +82,11 @@ %3 = tensor.extract_slice %arg1[%arg3] [%1] [1] : tensor<12xf32> to tensor %4 = affine.apply #map1(%1) %5 = tensor.pad %2 low[%c0, %c0] high[%c0, %4] { - ^bb0(%arg5: index, %arg6: index): + ^bb0(%arg5: index, %arg6: index): tensor.yield %cst : f32 } : tensor<24x?xf32> to tensor<24x5xf32> %6 = tensor.pad %3 low[%c0] high[%4] { - ^bb0(%arg5: index): + ^bb0(%arg5: index): tensor.yield %cst : f32 } : tensor to tensor<5xf32> @@ -141,11 +142,11 @@ %4 = tensor.extract_slice %arg1[%arg3] [%2] [1] : tensor to tensor %5 = affine.apply #map1(%2) %6 = tensor.pad %3 low[%c0, %c0] high[%c0, %5] { - ^bb0(%arg5: index, %arg6: index): + ^bb0(%arg5: index, %arg6: index): tensor.yield %cst : f32 } : tensor<24x?xf32> to tensor<24x4xf32> %7 = tensor.pad %4 nofold low[%c0] high[%5] { - ^bb0(%arg5: index): + ^bb0(%arg5: index): tensor.yield %cst : f32 } : tensor to tensor<4xf32> @@ -177,7 +178,7 @@ // MATVEC: %[[T1:.*]] = tensor.pad %[[T0]] %2 = tensor.extract_slice %arg1[%arg3] [4] [1] : tensor<12xf32> to tensor<4xf32> %3 = tensor.pad %2 nofold low[%c0] high[%c0] { - ^bb0(%arg5: index): + ^bb0(%arg5: index): %5 = arith.index_cast %arg3 : index to i32 %6 = arith.sitofp %5 : i32 to f32 tensor.yield %6 : f32 @@ -214,7 +215,7 @@ %2 = 
tensor.extract_slice %arg1[%arg3] [4] [1] : tensor<12xf32> to tensor<4xf32> %3 = tensor.extract %arg1[%arg3] : tensor<12xf32> %4 = tensor.pad %2 nofold low[%c0] high[%c0] { - ^bb0(%arg5: index): + ^bb0(%arg5: index): tensor.yield %3 : f32 } : tensor<4xf32> to tensor<4xf32> @@ -251,7 +252,7 @@ %2 = tensor.extract_slice %arg1[%arg4] [4] [1] : tensor<12xf32> to tensor<4xf32> %3 = arith.index_cast %arg3 : i32 to index %4 = tensor.pad %2 nofold low[%3] high[%3] { - ^bb0(%arg6: index): + ^bb0(%arg6: index): tensor.yield %cst : f32 } : tensor<4xf32> to tensor<4xf32> @@ -288,7 +289,7 @@ %2 = tensor.extract_slice %arg1[%arg4] [4] [1] : tensor<12xf32> to tensor<4xf32> %3 = memref.load %arg3[%c0] : memref %4 = tensor.pad %2 nofold low[%3] high[%3] { - ^bb0(%arg6: index): + ^bb0(%arg6: index): tensor.yield %cst : f32 } : tensor<4xf32> to tensor<4xf32> @@ -328,7 +329,7 @@ scf.yield %6 : index } %4 = tensor.pad %2 nofold low[%3] high[%3] { - ^bb0(%arg6: index): + ^bb0(%arg6: index): tensor.yield %cst : f32 } : tensor<4xf32> to tensor<4xf32> @@ -373,7 +374,7 @@ // Check the fused and padded fill op does not prevent hoisting. %4 = tensor.pad %2 nofold low[%c0, %c0] high[%3, %c0] { - ^bb0(%arg5: index, %arg6: index): + ^bb0(%arg5: index, %arg6: index): tensor.yield %cst : f32 } : tensor to tensor<5x24xf32> %5 = linalg.fill(%cst, %4) : f32, tensor<5x24xf32> -> tensor<5x24xf32> @@ -394,18 +395,18 @@ %10 = tensor.extract_slice %arg1[%arg5, 0] [3, 24] [1, 1] : tensor<6x24xf32> to tensor<3x24xf32> %11 = tensor.extract_slice %arg6[0, 0] [%1, 24] [1, 1] : tensor to tensor %12 = tensor.pad %9 nofold low[%c0, %c0] high[%3, %c0] { - ^bb0(%arg7: index, %arg8: index): + ^bb0(%arg7: index, %arg8: index): tensor.yield %cst : f32 } : tensor to tensor<5x3xf32> %13 = tensor.pad %10 nofold low[%c0, %c0] high[%c0, %c0] { - ^bb0(%arg7: index, %arg8: index): + ^bb0(%arg7: index, %arg8: index): tensor.yield %cst : f32 } : tensor<3x24xf32> to tensor<3x24xf32> // Check the output padding is not hoisted. // MATMUL: %[[T8:.*]] = tensor.pad %14 = tensor.pad %11 nofold low[%c0, %c0] high[%3, %c0] { - ^bb0(%arg7: index, %arg8: index): + ^bb0(%arg7: index, %arg8: index): tensor.yield %cst : f32 } : tensor to tensor<5x24xf32> @@ -421,3 +422,59 @@ } return %0 : tensor<12x24xf32> } + +// ----- + +#map0 = affine_map<(d0)[s0] -> (4, -d0 + s0)> +#map1 = affine_map<(d0) -> (-d0 + 4)> + +// TRANSP: transpose +// TRANSP-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x?xf32> +func @transpose(%arg0: tensor<24x?xf32>, + %arg1: tensor, + %arg2: tensor<24xf32>) -> tensor<24xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = tensor.dim %arg0, %c1 : tensor<24x?xf32> + + // Transpose the padded matrix. + // TRANSP: %[[T0:.*]] = scf.for %[[PIV0:[0-9a-z]+]] = {{.*}}iter_args(%[[T1:.*]] = + // TRANSP: %[[T2:.*]] = tensor.pad + // TRANSP: %[[T3:.*]] = tensor.extract_slice %[[T1]] + // TRANSP: %[[T4:.*]] = linalg.generic + // TRANSP-SAME: ins(%[[T2]] : tensor<24x4xf32> + // TRANSP-SAME: outs(%[[T3]] : tensor<4x24xf32> + // TRANSP: %[[T5:.*]] = tensor.insert_slice %[[T4]] into %[[T1]] + // TRANSP: scf.yield %[[T5:.*]] + + // TRANSP: scf.for %[[IV0:[0-9a-zA-Z]*]] = + %1 = scf.for %arg3 = %c0 to %0 step %c4 iter_args(%arg4 = %arg2) -> (tensor<24xf32>) { + %2 = affine.min #map0(%arg3)[%0] + %3 = tensor.extract_slice %arg0[0, %arg3] [24, %2] [1, 1] : tensor<24x?xf32> to tensor<24x?xf32> + + // Index the packed vector and transpose back. 
+    // TRANSP:      %[[T6:.*]] = tensor.extract_slice %[[T0]]
+    // TRANSP:      %[[T7:.*]] = linalg.init_tensor
+    // TRANSP:      %[[T8:.*]] = linalg.generic
+    // TRANSP-SAME:   ins(%[[T6]] : tensor<4x24xf32>
+    // TRANSP-SAME:   outs(%[[T7]] : tensor<24x4xf32>
+    %4 = tensor.extract_slice %arg1[%arg3] [%2] [1] : tensor<?xf32> to tensor<?xf32>
+    %5 = affine.apply #map1(%2)
+    %6 = tensor.pad %3 low[%c0, %c0] high[%c0, %5] {
+    ^bb0(%arg5: index, %arg6: index):  // no predecessors
+      tensor.yield %cst : f32
+    } : tensor<24x?xf32> to tensor<24x4xf32>
+    %7 = tensor.pad %4 nofold low[%c0] high[%5] {
+    ^bb0(%arg5: index):  // no predecessors
+      tensor.yield %cst : f32
+    } : tensor<?xf32> to tensor<4xf32>
+
+    // Check matvec uses the packed input vector.
+    // TRANSP:      = linalg.matvec ins(%[[T8]]
+    %8 = linalg.matvec ins(%6, %7 : tensor<24x4xf32>, tensor<4xf32>) outs(%arg4 : tensor<24xf32>) -> tensor<24xf32>
+    scf.yield %8 : tensor<24xf32>
+  }
+  return %1 : tensor<24xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -109,6 +109,17 @@
       *this, "hoist-paddings",
       llvm::cl::desc("Operand hoisting depths when test-pad-pattern."),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<std::string> transposePaddings{
+      *this, "transpose-paddings",
+      llvm::cl::desc(
+          "Transpose paddings when test-pad-pattern. Specify an "
+          "operand dimension interchange using the following format:\n"
+          "-transpose-paddings=1:0:2,0:1,0:1\n"
+          "It defines the interchange [1, 0, 2] for operand one and "
+          "the interchange [0, 1] (no transpose) for the remaining operands. "
+          "All interchange vectors have to be permutations matching the "
+          "operand rank."),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   Option<bool> generalize{*this, "generalize",
                           llvm::cl::desc("Generalize named operations."),
                           llvm::cl::init(false)};
@@ -257,9 +268,21 @@
                ? hoistPaddings[opOperand.getOperandNumber()]
                : 0;
   };
+  auto transposeFunc = [&](OpOperand &opOperand) {
+    SmallVector<int64_t> transposeVector = {};
+    if (opOperand.getOperandNumber() >= transposePaddings.size())
+      return transposeVector;
+    SmallVector<StringRef> elems;
+    StringRef(transposePaddings[opOperand.getOperandNumber()])
+        .split(elems, ':');
+    for (StringRef elem : elems)
+      transposeVector.push_back(std::stoi(elem.str()));
+    return transposeVector;
+  };
   paddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
   paddingOptions.setPaddingNoFoldComputationFunction(packFunc);
   paddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
+  paddingOptions.setPaddingTransposeComputationFunction(transposeFunc);

   // Compute input padding values only an return failure for output operands.
   if (padInputsOnly) {
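
A minimal sketch, not part of the patch, of how a client pass might combine the new transpose callback with the existing hoisting callback on LinalgPaddingOptions. The operand number, the fixed [1, 0] interchange, and the helper name buildPaddingOptions are illustrative assumptions; the only requirements stated by the patch are that a non-empty vector must be a permutation matching the padded tensor rank and that an empty vector keeps the original storage order.

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;
using namespace mlir::linalg;

// Hypothetical helper: hoist the pad of operand 0 by one loop and store the
// packed tile transposed with the [1, 0] interchange; all other operands keep
// their original storage order (empty transpose vector).
static LinalgPaddingOptions buildPaddingOptions() {
  LinalgPaddingOptions options;
  options.setPaddingHoistComputationFunction(
      [](OpOperand &opOperand) -> int64_t {
        return opOperand.getOperandNumber() == 0 ? 1 : 0;
      });
  options.setPaddingTransposeComputationFunction(
      [](OpOperand &opOperand) -> SmallVector<int64_t> {
        if (opOperand.getOperandNumber() == 0)
          return {1, 0};
        return {};
      });
  return options;
}
```

As in TestLinalgCodegenStrategy.cpp above, a complete configuration would also set the padding value callback (and optionally the nofold callback). Note that the transpose callback is only consulted for operands whose hoisting depth is non-zero, since the hoisting loop in Transforms.cpp skips operands with depth 0 before querying it.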