diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -773,6 +773,9 @@
     LogicalResult reifyResultShapes(OpBuilder &b,
         ReifiedRankedShapedTypeDims &reifiedReturnShapes);
 
+    /// Return the index in the indexingMaps vector that corresponds to this `opOperand`.
+    int64_t getIndexingMapIndex(OpOperand *opOperand);
+
     //========================================================================//
     // Forwarding functions to access interface methods from the
     // DestinationStyleOpInterface.
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -363,8 +363,11 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
 def PackOp : Op<Transform_Dialect, "structured.pack", [
                 DeclareOpInterfaceMethods<TransformOpInterface>,
                 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,]> {
   let description = [{
     Pack a LinalgOp by applying a data tiling transformation on the op and
@@ -439,14 +442,73 @@
   }];
 
   let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure apply(
-        transform::TransformResults &transformResults,
-        transform::TransformState &state);
-
     ::llvm::SmallVector<::mlir::OpFoldResult> getMixedPackedSizes();
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PackTransposeOp
+//===----------------------------------------------------------------------===//
+def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
+                         FunctionalStyleTransformOpTrait,
+                         MemoryEffectsOpInterface,
+                         DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and
+    update the `linalg.generic` op that consumes (resp. produces) the operation.
+
+    This transform allows composing a simple `structured.pack` with additional
+    transpositions to e.g. match the data format required by a specific library
+    call or ISA instruction.
+
+    The transpose spec must specify at least one of `outer_perm` or `inner_perm`
+    attributes, which will act upon the `outer_dims_perm` or `inner_dims_pos` of
+    the specified `tensor.pack` or `tensor.unpack` op.
+
+    If the `target` of this op is a `tensor.pack` then a new `tensor.empty` will
+    be created along with transposed versions of the `tensor.pack` and the
+    consuming `linalg.generic`, which is expected to be the sole consumer.
+
+    If the `target` of this op is a `tensor.unpack` then the whole pack / compute
+    / unpack chain will be transposed and transposed clones of `tensor.pack`,
+    the consuming `linalg.generic` and the tail `tensor.unpack` will be created.
+
+    #### Return modes
+
+    This operation targets a single `tensor.pack` / `tensor.unpack` op and a
+    single matching `linalg.generic` that consumes / produces the op. Otherwise,
+    it produces a silenceable failure.
+
+    This operation may produce a silenceable failure if the transpose spec is
+    ill-formed (i.e. `outer_perm` or `inner_perm` are not permutations of the
+    proper rank) or if the transposition of all involved operations fails for any
+    reason.
+
+    This operation returns 3 handles, one to the transformed LinalgOp, one to
+    the transformed `tensor.pack` and one to the transformed `tensor.unpack`.
+    The last handle for `tensor.unpack` is empty if `target_pack_or_un_pack_op`
+    was not itself a `tensor.unpack`.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target_pack_or_un_pack_op,
+                       TransformHandleTypeInterface:$target_linalg_op,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_perm,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$inner_perm);
+  let results = (outs TransformHandleTypeInterface:$packed_op,
+                      TransformHandleTypeInterface:$pack_op,
+                      TransformHandleTypeInterface:$un_pack_op);
+  let assemblyFormat = [{
+    $target_pack_or_un_pack_op
+    `with_compute_op` `(` $target_linalg_op `)`
+    (`outer_perm` `=` $outer_perm^ )?
+    (`inner_perm` `=` $inner_perm^ )?
+    attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // PadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1776,6 +1776,21 @@
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+
+    /// Build and return a new PackOp that is a clone of the current PackOp
+    /// where (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
+    /// innerPermutation (resp. outerPermutation).
+    /// A new `tensor.empty` of the proper shape is built in the process.
+    /// Asserts that:
+    ///   - At least one of innerPermutation or outerPermutation is non-empty.
+    ///   - If not empty, innerPermutation is a valid permutation of size
+    ///     matching innerDimsPos.
+    ///   - If not empty, outerPermutation is a valid permutation of size
+    ///     matching outerDimsPerm.
+    PackOp createTransposedClone(OpBuilder &b,
+                                 Location loc,
+                                 ArrayRef<int64_t> innerPermutation,
+                                 ArrayRef<int64_t> outerPermutation);
   }];
 
   let hasCanonicalizeMethod = 1;
@@ -1832,7 +1847,23 @@
                    CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
   ];
 
-  let extraClassDeclaration = commonExtraClassDeclaration;
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    /// Build and return a new UnPackOp that is a clone of the current UnPackOp
+    /// where (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
+    /// innerPermutation (resp. outerPermutation).
+    /// Asserts that:
+    ///   - At least one of innerPermutation or outerPermutation is non-empty.
+    ///   - If not empty, innerPermutation is a valid permutation of size
+    ///     matching innerDimsPos.
+    ///   - If not empty, outerPermutation is a valid permutation of size
+    ///     matching outerDimsPerm.
+    UnPackOp createTransposedClone(OpBuilder &b,
+                                   Location loc,
+                                   Value transposedSource,
+                                   ArrayRef<int64_t> innerPermutation,
+                                   ArrayRef<int64_t> outerPermutation);
+  }];
+  let hasCanonicalizeMethod = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -621,6 +621,22 @@
   return success();
 }
 
+/// Return the index in the indexingMaps vector that corresponds to this
+/// `opOperand`.
+int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
+  auto operandNumber = opOperand->getOperandNumber();
+  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
+  if (!dpsIface.isDpsInput(opOperand))
+    return operandNumber;
+  auto [start, end] = dpsIface.getDpsInitsPositionRange();
+  assert(!dpsIface.isDpsInit(opOperand));
+  // Account for potential inputs that are not DPS and may not appear in
+  // `indexingMaps`.
+  return cast<DestinationStyleOpInterface>(*this->getOperation())
+             .getNumDpsInputs() +
+         operandNumber - start;
+}
+
 LogicalResult
 mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -17,17 +17,21 @@
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Debug.h"
@@ -1161,16 +1165,12 @@
   // Fail on multi-op handles.
   auto linalgOp = dyn_cast<LinalgOp>(targetOps.front());
   if (targetOps.size() != 1 || !linalgOp) {
-    // TODO: remove this unnecessary set to empty once crashes are fixed.
-    transformResults.set(getPackedOp().cast<OpResult>(), {});
     return emitSilenceableError()
            << "requires target to map to exactly 1 LinalgOp (got "
           << targetOps.size() << ")";
   }
   // Fail on mismatched number of pack sizes.
   if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
-    // TODO: remove this unnecessary set to empty once crashes are fixed.
-    transformResults.set(getPackedOp().cast<OpResult>(), {});
     return emitSilenceableError()
            << "requires number of packed sizes match the number of loops ("
           << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
@@ -1194,6 +1194,263 @@
   transform::consumesHandle(getTarget(), effects);
   transform::onlyReadsHandle(getPackedSizes(), effects);
   transform::producesHandle(getPackedOp(), effects);
+  transform::modifiesPayload(effects);
+}
+
+//===---------------------------------------------------------------------===//
+// PackTransposeOp
+//===---------------------------------------------------------------------===//
+
+namespace {
+enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
+} // namespace
+
+/// Return true if `permutation` is a valid permutation of the `outer_dims_perm`
+/// (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner)
+/// of the `tensor.pack` or `tensor.unpack` `op`.
+/// This is the case when the `permutation` rank matches the rank expected by
+/// `op` and `permutation` is itself a permutation vector.
+/// Return true if either `op` or `permutation` are empty to allow a simpler
+/// polymorphic implementation.
+template <typename RelayoutOpTy>
+bool isValidPackingPermutation(
+    RelayoutOpTy op, ArrayRef<int64_t> permutation,
+    OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
+  static_assert(
+      llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
+      "applies to only pack or unpack operations");
+  if (!op || permutation.empty())
+    return true;
+  int64_t innerRank = op.getInnerDimsPos().size();
+  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
+    return permutation.size() == innerRank && isPermutationVector(permutation);
+  // op.getOuterDimsPerm() may be empty, in which case it is identity.
+  // Don't rely on it.
+  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
+    return permutation.size() == op.getSourceRank() &&
+           isPermutationVector(permutation);
+  }
+  return permutation.size() == op.getDestRank() &&
+         isPermutationVector(permutation);
+}
+
+/// Return a copy of `tensorType` after permutation by `permutationVector`.
+// Note: Should be a new method in MemRef/RankedTensor/VectorType::Builder
+// but this would introduce a dependence on Dialect in IR.
+// TODO: Restructure.
+static RankedTensorType permuteShape(RankedTensorType tensorType,
+                                     ArrayRef<int64_t> permutationVector) {
+  SmallVector<int64_t> shape(tensorType.getShape());
+  applyPermutationToVector(shape, permutationVector);
+  return RankedTensorType::Builder(tensorType).setShape(shape);
+}
+
+/// Return a new GenericOp obtained by transposing opOperand by the permutation
+/// vector:
+///   - the corresponding indexing map is transposed by `permutation`
+///   - the corresponding operand value is replaced by `transposedValue`
+/// `linalgOp` is replaced by the return op in the process.
+/// Asserts that `transposedValue` is of the proper transposed ShapedType.
+static LinalgOp transposeOneLinalgOperandAndReplace(
+    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
+    ArrayRef<int64_t> permutation, Value transposedValue) {
+  // Sanity check the operand.
+  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
+
+  // Sanity check of the expected transposed tensor type.
+  auto tensorType = permuteShape(
+      opOperand.get().getType().cast<RankedTensorType>(), permutation);
+  assert(tensorType == transposedValue.getType() &&
+         "expected tensor type mismatch");
+
+  // Compute the transposed indexing map.
+  // Sigh unsigned pollution.
+  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
+      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
+  AffineMap permutationMap =
+      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
+  AffineMap transposedMap =
+      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
+
+  // Set the transposed indexing map in the proper position.
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
+  // Set the transposedValue in the proper operand position.
+  SmallVector<Value> operands = linalgOp->getOperands();
+  operands[opOperand.getOperandNumber()] = transposedValue;
+
+  ValueRange operandsRef(operands);
+  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
+      /*location=*/linalgOp->getLoc(),
+      /*resultTensorTypes=*/
+      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
+      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
+      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
+      /*indexingMaps=*/indexingMaps,
+      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
+  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
+  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
+
+  return cast<LinalgOp>(transposedGenericOp.getOperation());
+}
+
+LogicalResult transform::PackTransposeOp::verify() {
+  if (!isPermutationVector(getInnerPerm())) {
+    return emitOpError() << getInnerPermAttrName()
+                         << " is not a valid permutation";
+  }
+  if (!isPermutationVector(getOuterPerm())) {
+    return emitOpError() << getOuterPermAttrName()
+                         << " is not a valid permutation";
+  }
+  if (getInnerPerm().empty() && getOuterPerm().empty()) {
+    return emitOpError() << " at least one of " << getInnerPermAttrName()
+                         << " or " << getOuterPermAttrName()
+                         << " must be specified";
+  }
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
+                                  transform::TransformState &state) {
+  ArrayRef<Operation *> packOrUnpackOps =
+      state.getPayloadOps(getTargetPackOrUnPackOp());
+  ArrayRef<Operation *> linalgOps = state.getPayloadOps(getTargetLinalgOp());
+  // Step 1. If nothing to pack, propagate success.
+  if (packOrUnpackOps.empty()) {
+    transformResults.set(getPackedOp().cast<OpResult>(), {});
+    transformResults.set(getPackOp().cast<OpResult>(), {});
+    transformResults.set(getUnPackOp().cast<OpResult>(), {});
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Step 2. Bunch of runtime sanity checks and error messages.
+  // Step 2.1. Fail on multi-op handles.
+  if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) {
+    return emitSilenceableError()
+           << "requires target to map to exactly 1 packing op and 1 packed op ("
+           << "got " << packOrUnpackOps.size() << " and " << linalgOps.size()
+           << ")";
+  }
+
+  // Step 2.2. Fail on wrong type.
+  auto packOp = dyn_cast<tensor::PackOp>(packOrUnpackOps.front());
+  auto unPackOp = dyn_cast<tensor::UnPackOp>(packOrUnpackOps.front());
+  if ((!packOp && !unPackOp)) {
+    return emitSilenceableError() << "requires target to map to a "
+                                     "tensor.pack or tensor.unpack";
+  }
+  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(linalgOps.front());
+  if (!linalgOpTarget)
+    return emitSilenceableError() << "requires a LinalgOp target";
+
+  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
+  LinalgOp linalgOp;
+  if (packOp && packOp.getResult().hasOneUse())
+    linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
+  else if (unPackOp)
+    linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
+  if (linalgOp != linalgOpTarget) {
+    auto errorMsg =
+        packOp ? StringLiteral{"not a single use by the LinalgOp target"}
+               : StringLiteral{"not produced by the LinalgOp target"};
+    return emitSilenceableError() << errorMsg;
+  }
+
+  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical PackOp.
+  if (unPackOp) {
+    assert(!packOp && "packOp must be null on entry when unPackOp is not null");
+    OpOperand *packUse = linalgOp.getDpsInitOperand(
+        unPackOp.getSource().cast<OpResult>().getResultNumber());
+    packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
+    if (!packOp || !packOp.getResult().hasOneUse())
+      return emitSilenceableError() << "could not find matching pack op";
+  }
+
+  // Step 2.5. Fail if any permutation does not validate.
+  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
+    ArrayRef<int64_t> perm =
+        (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
+    auto errorMsg = (permType == OuterOrInnerPerm::Outer)
+                        ? StringLiteral{"invalid outer_perm"}
+                        : StringLiteral{"invalid inner_perm"};
+    if (!isValidPackingPermutation(packOp, perm, permType) ||
+        !isValidPackingPermutation(unPackOp, perm, permType)) {
+      Operation *packOrUnpackOp =
+          unPackOp ? unPackOp.getOperation() : packOp.getOperation();
+      return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
+    }
+  }
+
+  // From here on, packOp and linalgOp are always present, unPackOp may or may
+  // not be present.
+  assert(packOp && linalgOp && "unexpected null op");
+
+  // Step 3. Actually transpose the ops.
+  Location loc = linalgOp.getLoc();
+  IRRewriter rewriter(getContext());
+
+  // Step 3.a. Transpose packOp.
+  rewriter.setInsertionPoint(packOp);
+  tensor::PackOp transposedPackOp = packOp.createTransposedClone(
+      rewriter, loc, getInnerPerm(), getOuterPerm());
+
+  // Step 3.b. Transpose linalgOp.
+  assert(packOp.getResult().hasOneUse() && "expect single use");
+  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
+  // identity. Don't rely on it.
+  int64_t numLeadingDims = packOp.getSourceRank();
+  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
+  // Step 3.b.i. Compute the permutation on the whole operand.
+  // Leading part just reuses the outerPerm.
+  SmallVector<int64_t> permutation(getOuterPerm());
+  if (permutation.empty())
+    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
+  // Trailing part needs to reindex positions by `numLeadingDims`.
+  if (getInnerPerm().empty()) {
+    llvm::append_range(
+        permutation,
+        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
+  } else {
+    llvm::append_range(permutation,
+                       llvm::map_range(getInnerPerm(), [&](int64_t pos) {
+                         return numLeadingDims + pos;
+                       }));
+  }
+  assert(isPermutationVector(permutation) && "invalid permutation");
+  // Step 3.b.ii. Save the transposedPackUse operand number in case we need to
+  // get the tied OpResult after `linalgOp` has been replaced.
+  OpOperand &packUse = *(packOp.getResult().getUses().begin());
+  int64_t packUseOperandNumber = packUse.getOperandNumber();
+  // Step 3.b.iii. Actually perform the transposition.
+  rewriter.setInsertionPoint(linalgOp);
+  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
+      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
+
+  // Step 3.c. Maybe transpose unPackOp.
+  tensor::UnPackOp transposedUnPackOp;
+  if (unPackOp) {
+    OpOperand &opOperand =
+        transposedLinalgOp->getOpOperand(packUseOperandNumber);
+    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
+    rewriter.setInsertionPoint(unPackOp);
+    transposedUnPackOp = unPackOp.createTransposedClone(
+        rewriter, loc, transposedResult, getInnerPerm(), getOuterPerm());
+  }
+
+  // Step 4. Replace and return results.
+  rewriter.replaceOp(packOp, transposedPackOp->getResults());
+  transformResults.set(getPackOp().cast<OpResult>(), {transposedPackOp});
+  // transposedLinalgOp was replaced in `transposeOneLinalgOperandAndReplace`.
+  transformResults.set(getPackedOp().cast<OpResult>(), {transposedLinalgOp});
+  if (unPackOp) {
+    rewriter.replaceOp(unPackOp, transposedUnPackOp->getResults());
+    transformResults.set(getUnPackOp().cast<OpResult>(), {transposedUnPackOp});
+  } else {
+    transformResults.set(getUnPackOp().cast<OpResult>(), {});
+  }
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===---------------------------------------------------------------------===//
@@ -1359,7 +1616,7 @@
     if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
         target->getNumRegions() > 0)
       return emitDefiniteFailure()
-             << "expected target that is isloated from above";
+             << "expected target that is isolated from above";
   }
 
   // Clone and replace.
@@ -1907,32 +2164,31 @@
     scf::SCFTilingOptions tilingOptions;
     unsigned index = en.index();
     if (!tileSizes.empty()) {
-      tilingOptions.setTileSizeComputationFunction(
-          [&, index](OpBuilder &b, Operation *) {
-            SmallVector<Value> sizes;
-            sizes.reserve(tileSizes.size());
-            unsigned dynamicIdx = 0;
-            for (OpFoldResult ofr : getMixedSizes()) {
-              if (auto attr = ofr.dyn_cast<Attribute>()) {
-                sizes.push_back(b.create<arith::ConstantIndexOp>(
-                    getLoc(), attr.cast<IntegerAttr>().getInt()));
-                continue;
-              }
-              ArrayRef<Operation *> dynamicSizes =
-                  dynamicSizeProducers[dynamicIdx];
-              ArrayRef<int64_t> params = paramSizes[dynamicIdx];
-              ++dynamicIdx;
-              assert((dynamicSizes.empty() ^ params.empty()) &&
-                     "expected either dynamic sizes or parameters");
-              if (!params.empty()) {
-                sizes.push_back(
-                    b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
-              } else {
-                sizes.push_back(dynamicSizes[index]->getResult(0));
-              }
-            }
-            return sizes;
-          });
+      tilingOptions.setTileSizeComputationFunction([&, index](OpBuilder &b,
+                                                              Operation *) {
+        SmallVector<Value> sizes;
+        sizes.reserve(tileSizes.size());
+        unsigned dynamicIdx = 0;
+        for (OpFoldResult ofr : getMixedSizes()) {
+          if (auto attr = ofr.dyn_cast<Attribute>()) {
+            sizes.push_back(b.create<arith::ConstantIndexOp>(
+                getLoc(), attr.cast<IntegerAttr>().getInt()));
+            continue;
+          }
+          ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
+          ArrayRef<int64_t> params = paramSizes[dynamicIdx];
+          ++dynamicIdx;
+          assert((dynamicSizes.empty() ^ params.empty()) &&
+                 "expected either dynamic sizes or parameters");
+          if (!params.empty()) {
+            sizes.push_back(
+                b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
+          } else {
+            sizes.push_back(dynamicSizes[index]->getResult(0));
+          }
+        }
+        return sizes;
+      });
     }
 
     tilingOptions.setInterchange(getInterchange());
@@ -2149,27 +2405,27 @@
     // Transform all targets one by one.
     for (Operation *target : targets) {
-    auto tilableOp = dyn_cast<TilingInterface>(target);
-    if (!tilableOp) {
+    auto tileableOp = dyn_cast<TilingInterface>(target);
+    if (!tileableOp) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "only TilingInterface ops are supported";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
-    rewriter.setInsertionPoint(tilableOp);
+    rewriter.setInsertionPoint(tileableOp);
     FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
     if (!mixedNumThreads.empty()) {
-      tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
+      tilingResult = linalg::tileToForeachThreadOp(rewriter, tileableOp,
                                                    mixedNumThreads, mapping);
     } else {
       tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
-          rewriter, tilableOp, mixedTileSizes, mapping);
+          rewriter, tileableOp, mixedTileSizes, mapping);
     }
 
     if (failed(tilingResult))
-      return transformOp.emitDefaultSilenceableFailure(tilableOp);
-    rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
+      return transformOp.emitDefaultSilenceableFailure(tileableOp);
+    rewriter.replaceOp(tileableOp, tilingResult->tileOp->getResults());
     tileOps.push_back(tilingResult->tileOp);
     tiledOps.push_back(tilingResult->tiledOp);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3231,7 +3231,6 @@
           return true;
         }
         return shape == constTileSize.value();
-
       })) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
@@ -3239,6 +3238,57 @@
   return success();
 }
 
+namespace {
+/// Subset of PackOp/UnPackOp fields used to compute the result of applying
+/// various permutations to the op.
+// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
+// these. These may or may not become true foldings / canonicalizations
+// depending on how aggressive we want to be in automatically folding
+// transposes.
+struct PackOrUnPackTransposeResult {
+  SmallVector<int64_t> innerDimsPos;
+  SmallVector<OpFoldResult> innerTiles;
+  SmallVector<int64_t> outerDimsPerm;
+};
+} // namespace
+
+template <typename OpTy>
+static PackOrUnPackTransposeResult
+commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
+                                   ArrayRef<int64_t> innerPermutation,
+                                   ArrayRef<int64_t> outerPermutation) {
+  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+                "applies to only pack or unpack operations");
+  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
+         "some permutation must be non-empty");
+  PackOrUnPackTransposeResult metadata;
+  metadata.innerDimsPos =
+      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
+  metadata.innerTiles =
+      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
+  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
+                             ? packOrUnPackOp.getSourceRank()
+                             : packOrUnPackOp.getDestRank();
+  metadata.outerDimsPerm =
+      packOrUnPackOp.getOuterDimsPerm().empty()
+          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
+          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
+  if (!innerPermutation.empty()) {
+    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
+           isPermutationVector(innerPermutation) &&
+           "invalid inner permutation");
+    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
+    applyPermutationToVector(metadata.innerTiles, innerPermutation);
+  }
+  if (!outerPermutation.empty()) {
+    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
+           isPermutationVector(outerPermutation) &&
+           "invalid outer permutation");
+    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
+  }
+  return metadata;
+}
+
 //===----------------------------------------------------------------------===//
 // PackOp
 //===----------------------------------------------------------------------===//
@@ -3386,6 +3436,19 @@
   return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
 }
 
+PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
+                                     ArrayRef<int64_t> innerPermutation,
+                                     ArrayRef<int64_t> outerPermutation) {
+  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
+      *this, innerPermutation, outerPermutation);
+  Value transposedDest =
+      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
+                              metadata.innerDimsPos, metadata.outerDimsPerm);
+  return b.create<PackOp>(loc, getSource(), transposedDest,
+                          metadata.innerDimsPos, metadata.innerTiles,
+                          getPaddingValue(), metadata.outerDimsPerm);
+}
+
 /// Returns true if the tiles and the tiled dims are constant.
 template <typename OpTy>
 bool areTilesAndTiledDimsAllConstant(OpTy op) {
@@ -3508,6 +3571,17 @@
                     builder.getDenseI64ArrayAttr(staticTileSizes));
 }
 
+UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
+                                         Value transposedSource,
+                                         ArrayRef<int64_t> innerPermutation,
+                                         ArrayRef<int64_t> outerPermutation) {
+  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
+      *this, innerPermutation, outerPermutation);
+  return b.create<UnPackOp>(loc, transposedSource, getDest(),
+                            metadata.innerDimsPos, metadata.innerTiles,
+                            metadata.outerDimsPerm);
+}
+
 /// pack(unpack(x)) -> x
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
diff --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
@@ -49,21 +49,21 @@
   iterator_types = ["reduction", "parallel"]
 }
 
-// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 // CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d1)>
 
 // CHECK-LABEL: @col_reduction_2d_static
 // CHECK-SAME: %[[T0:.+]]: tensor<7x3xf16>,
 // CHECK-SAME: %[[T1:.+]]: tensor<3xf16>
 func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> {
-  // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x4xf16>
+  // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<3x2x4xf16>
   // CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16)
-  // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<2x3x4xf16>
+  // CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<3x2x4xf16>
   // CHECK-NOT: tensor.pack
   // CHECK: linalg.generic
   // CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
   // CHECK-SAME: iterator_types = ["reduction", "parallel",
"reduction"] - // CHECK-SAME: ins(%{{.*}} : tensor<2x3x4xf16>) + // CHECK-SAME: ins(%{{.*}} : tensor<3x2x4xf16>) // CHECK-SAME: outs(%{{.*}} : tensor<3xf16>) %2 = linalg.generic #col_reduction_2d_trait ins(%t0 : tensor<7x3xf16>) outs(%t1 : tensor<3xf16>) { ^bb0(%in: f16, %out: f16): @@ -78,8 +78,15 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - transform.structured.pack %0 packed_sizes = [4, 0] + %1 = transform.structured.pack %0 packed_sizes = [4, 0] : (!pdl.operation) -> (!transform.op<"linalg.generic">) + %pack = transform.get_producer_of_operand %1[0] + : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">) + %2, %pack_2, %empty_unpack_2 = + transform.structured.pack_transpose %pack with_compute_op(%1) + outer_perm = [1, 0] + : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">) + -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !pdl.operation) } // ----- @@ -183,7 +190,7 @@ // K N n k // CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)> // M N m n -// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> +// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d4, d3)> // CHECK-LABEL: @matmul // CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: tensor, @@ -196,19 +203,19 @@ // CHECK-SAME: : tensor -> tensor // CHECK: %[[PACK_B:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [3, 4] // CHECK-SAME: : tensor -> tensor - // CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3] - // CHECK-SAME: : tensor -> tensor + // CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2] + // CHECK-SAME: : tensor -> tensor // CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} // CHECK-SAME: ins(%{{.*}} : tensor, tensor) - // CHECK-SAME: outs(%{{.*}} : tensor) + // CHECK-SAME: outs(%{{.*}} : tensor) %0 = linalg.matmul ins(%A, %B: tensor, tensor) outs(%C: tensor) -> tensor - // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3] - // CHECK-SAME: : tensor -> tensor + // CHECK: tensor.unpack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2] + // CHECK-SAME: : tensor -> tensor return %0 : tensor } @@ -218,6 +225,14 @@ // M N K %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4] : (!pdl.operation) -> (!transform.op<"linalg.generic">) + + %unpack = transform.get_consumers_of_result %1[0] + : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">) + %2, %pack_2, %unpack_2 = + transform.structured.pack_transpose %unpack with_compute_op(%1) + outer_perm = [1, 0] inner_perm = [1, 0] + : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">) + -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">) } // ----- @@ -404,3 +419,177 @@ %1 = transform.structured.pack %0 packed_sizes = [2, 3] : (!pdl.operation) -> (!transform.op<"linalg.generic">) } + +// ----- + +func.func @no_single_packing_op(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) { + %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32> + %1 = tensor.unpack %0 
inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32> + %2 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32> + return +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 + %1 = transform.structured.match ops{["tensor.unpack"]} in %arg1 + // expected-error @below {{requires target to map to exactly 1 packing op and 1 packed op (got 2 and 1)}} + transform.structured.pack_transpose %0 with_compute_op(%1) + inner_perm = [0] + : (!pdl.operation, !pdl.operation) + -> (!pdl.operation, !pdl.operation, !pdl.operation) +} + +// ----- + +func.func @no_single_pack_unpack(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) { + %0 = arith.constant 0 : index + %1 = tensor.empty() : tensor + return +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["arith.constant"]} in %arg1 + %1 = transform.structured.match ops{["tensor.empty"]} in %arg1 + // expected-error @below {{requires target to map to a tensor.pack or tensor.unpack}} + transform.structured.pack_transpose %0 with_compute_op(%1) + inner_perm = [0] + : (!pdl.operation, !pdl.operation) + -> (!pdl.operation, !pdl.operation, !pdl.operation) +} + +// ----- + +func.func @no_linalg_target(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) { + %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32> + %1 = arith.constant 0 : index + return +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 + %1 = transform.structured.match ops{["arith.constant"]} in %arg1 + // expected-error @below {{requires a LinalgOp target}} + transform.structured.pack_transpose %0 with_compute_op(%1) + inner_perm = [0] + : (!pdl.operation, !pdl.operation) + -> (!pdl.operation, !pdl.operation, !pdl.operation) +} + +// ----- + +func.func @no_single_use_by_linalg(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) { + %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32> + %f0 = arith.constant 0.0 : f32 + %1 = tensor.empty() : tensor + %2 = linalg.fill ins(%f0: f32) outs(%1 : tensor) -> tensor + return +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 + %1 = transform.structured.match ops{["linalg.fill"]} in %arg1 + // expected-error @below {{not a single use by the LinalgOp target}} + transform.structured.pack_transpose %0 with_compute_op(%1) + inner_perm = [0] + : (!pdl.operation, !pdl.operation) + -> (!pdl.operation, !pdl.operation, !pdl.operation) +} + +// ----- + +func.func @not_produced_by_linalg(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) { + %a = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32> + %b = tensor.unpack %a inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32> + %f0 = arith.constant 0.0 : f32 + %1 = tensor.empty() : tensor + %2 = linalg.fill ins(%f0: f32) outs(%1 : tensor) -> tensor + return +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = 
transform.structured.match ops{["tensor.unpack"]} in %arg1 + %1 = transform.structured.match ops{["linalg.fill"]} in %arg1 + // expected-error @below {{not produced by the LinalgOp target}} + transform.structured.pack_transpose %0 with_compute_op(%1) + inner_perm = [0] + : (!pdl.operation, !pdl.operation) + -> (!pdl.operation, !pdl.operation, !pdl.operation) +} + +// ----- + +func.func @no_matching_pack(%source: tensor<16xf32>) { + %f0 = arith.constant 0.0 : f32 + %1 = tensor.empty() : tensor<4x4xf32> + %2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32> + %b = tensor.unpack %2 inner_dims_pos = [0] inner_tiles = [4] into %source : tensor<4x4xf32> -> tensor<16xf32> + return +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 + %1 = transform.structured.match ops{["linalg.fill"]} in %arg1 + // expected-error @below {{could not find matching pack op}} + transform.structured.pack_transpose %0 with_compute_op(%1) + inner_perm = [0] + : (!pdl.operation, !pdl.operation) + -> (!pdl.operation, !pdl.operation, !pdl.operation) +} + +// ----- + +func.func @invalid_outer_perm(%A: tensor, %B: tensor, %C: tensor) + -> tensor { + %0 = linalg.matmul ins(%A, %B: tensor, tensor) + outs(%C: tensor) + -> tensor + return %0 : tensor +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4] + : (!pdl.operation) -> (!transform.op<"linalg.generic">) + + %unpack = transform.get_consumers_of_result %1[0] + : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">) + %2, %pack_2, %unpack_2 = + // expected-error @below {{invalid outer_perm}} + transform.structured.pack_transpose %unpack with_compute_op(%1) + outer_perm = [1] + : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">) + -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">) +} + +// ----- + +func.func @invalid_inner_perm(%A: tensor, %B: tensor, %C: tensor) + -> tensor { + %0 = linalg.matmul ins(%A, %B: tensor, tensor) + outs(%C: tensor) + -> tensor + return %0 : tensor +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4] + : (!pdl.operation) -> (!transform.op<"linalg.generic">) + + %unpack = transform.get_consumers_of_result %1[0] + : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">) + %2, %pack_2, %unpack_2 = + // expected-error @below {{invalid inner_perm}} + transform.structured.pack_transpose %unpack with_compute_op(%1) + inner_perm = [1] + : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">) + -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">) +}
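
For reference, the intended composition of the two transforms in a single sequence (this simply mirrors the `@col_reduction_2d_static` test added in transform-op-pack.mlir above; it is illustrative only and not an additional change in this patch):

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    // Pack the reduction; returns a handle to the packed linalg.generic.
    %1 = transform.structured.pack %0 packed_sizes = [4, 0]
      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
    // Recover the tensor.pack that feeds operand #0 of the packed op.
    %pack = transform.get_producer_of_operand %1[0]
      : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
    // Transpose the outer dims of that pack and update the packed op in sync.
    // The third result is empty because the target is a tensor.pack, not a
    // tensor.unpack.
    %2, %pack_2, %empty_unpack_2 =
      transform.structured.pack_transpose %pack with_compute_op(%1)
      outer_perm = [1, 0]
        : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
        -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !pdl.operation)
  }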