diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -428,6 +428,27 @@ result.getResultNumber()); }] >, + InterfaceMethod< + /*desc=*/[{ + Return the index in the indexingMaps vector that corresponds to this `opOperand` + }], + /*retTy=*/"int64_t", + /*methodName=*/"getIndexingMapIndex", + /*args=*/(ins "OpOperand *":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto operandNumber = opOperand->getOperandNumber(); + auto dpsIface = cast(*this->getOperation()); + if (!dpsIface.isDpsInput(opOperand)) return operandNumber; + auto [start, end] = dpsIface.getDpsInitsPositionRange(); + assert(!$_op.isDpsInit(opOperand)); + // Account for potential inputs that are not DPS and may not appear in + // `indexingMaps`. + return cast(*this->getOperation()) + .getNumDpsInputs() + + operandNumber - start; + }] + >, InterfaceMethod< /*desc=*/[{ Return the value yielded by the region corresponding to an output diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -341,6 +341,9 @@ }]; } +//===----------------------------------------------------------------------===// +// PackOp +//===----------------------------------------------------------------------===// def PackOp : Op,]> { @@ -425,6 +428,74 @@ }]; } +//===----------------------------------------------------------------------===// +// PackTransposeOp +//===----------------------------------------------------------------------===// +def PackTransposeOp : Op,]> { + let description = [{ + Apply a transposition to a single `tensor.pack` (resp. 
`tensor.unpack`) and + update the `linalg.generic` op that consumes (resp. produces) the operation. + + This transform allows composing a simple `structured.pack` with additional + transpositions to e.g. match the data format required by a specific library + call or ISA instruction. + + The transpose spec must specify at least one of `outer_perm` or `inner_perm` + attributes, which will act upon the `outer_dims_perm` or `inner_dims_pos` of + the specified `tensor.pack` or `tensor.unpack` op. + + If the `target` of this op is a `tensor.pack` then a new `tensor.empty` will + be created along with transposed versions of the `tensor.pack` and the + consuming `linalg.generic`, which is expected to be the sole consumer. + + If the `target` of this op is a `tensor.unpack` then the whole pack / compute + / unpack chain will be transposed and transposed clones of `tensor.pack`, + the consuming `linalg.generic` and the tail `tensor.unpack` will be created. + + #### Return modes + + This operation targets a single `tensor.pack` / `tensor.unpack` op and a + single matching `linalg.generic` that consumes / produces the op. Otherwise, + it produces a silenceableFailure. + + This operation may produce a silenceableFailure if the transpose spec is + ill-formed (i.e. `outer_perm` or `inner_perm` are not permutations of the + proper rank) or if the transposition of all involved operations fails for any + reason. + + This operation returns 3 handles, one to the transformed LinalgOp, one to + the transformed `tensor.pack` and one to the transformed `tensor.unpack`. + The last handle for `tensor.unpack` is empty if `target_pack_or_unpack_op` + was not itself a `tensor.unpack`. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target_pack_or_unpack_op, + TransformHandleTypeInterface:$target_linalg_op, + DefaultValuedOptionalAttr:$outer_perm, + DefaultValuedOptionalAttr:$inner_perm); + let results = (outs TransformHandleTypeInterface:$packed_op, + TransformHandleTypeInterface:$pack_op, + TransformHandleTypeInterface:$un_pack_op); + let assemblyFormat = [{ + $target_pack_or_unpack_op + `with_compute_op` `(` $target_linalg_op `)` + (`outer_perm` `=` $outer_perm^ )? + (`inner_perm` `=` $inner_perm^ )? + attr-dict + `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + transform::TransformResults &transformResults, + transform::TransformState &state); + }]; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // PadOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1776,6 +1776,21 @@ static Value createDestinationTensor(OpBuilder &b, Location loc, Value source, ArrayRef innerTileSizes, ArrayRef innerDimsPos, ArrayRef outerDimsPerm); + + /// Build and return a new PackOp that is a clone of the current PackOp with + /// (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by + /// innerPermutation (resp. outerPermutation). + /// A new `tensor.empty` of the proper shape is built in the process. + /// Asserts that: + /// - At least one of innerPermutation or outerPermutation is non-empty. + /// - If not empty, innerPermutation is a valid permutation of size + /// matching innerDimPos. + /// - If not empty, outerPermutation is a valid permutation of size + /// matching outerDimsPerm. 
+ PackOp createTransposedClone(OpBuilder &b, + Location loc, + ArrayRef innerPermutation, + ArrayRef outerPermutation); }]; let hasCanonicalizeMethod = 1; @@ -1832,7 +1847,23 @@ CArg<"ArrayRef", "{}">:$outerDimsPerm)> ]; - let extraClassDeclaration = commonExtraClassDeclaration; + let extraClassDeclaration = commonExtraClassDeclaration # [{ + /// Build and return a new UnPackOp that is a clone of the current UnPackOp + /// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by + /// innerPermutation (resp. outerPermutation). + /// Asserts that: + /// - At least one of innerPermutation or outerPermutation is non-empty. + /// - If not empty, innerPermutation is a valid permutation of size + /// matching innerDimPos. + /// - If not empty, outerPermutation is a valid permutation of size + /// matching outerDimsPerm. + UnPackOp createTransposedClone(OpBuilder &b, + Location loc, + Value transposedSource, + ArrayRef innerPermutation, + ArrayRef outerPermutation); + }]; + let hasCanonicalizeMethod = 1; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -17,16 +17,20 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/Dialect/Transform/Utils/Utils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/TilingInterface.h" #include 
"mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" @@ -1100,7 +1104,7 @@ // Fail on multi-op handles. auto linalgOp = dyn_cast(targetOps.front()); if (targetOps.size() != 1 || !linalgOp) { - // TODO: remove this unnecessary set to empty once crashes are fixed. + // TODO: remove this unnecessary "set to empty" once crashes are fixed. transformResults.set(getPackedOp().cast(), {}); return emitSilenceableError() << "requires target to map to exactly 1 LinalgOp (got " @@ -1108,7 +1112,7 @@ } // Fail on mismatched number of pack sizes. if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) { - // TODO: remove this unnecessary set to empty once crashes are fixed. + // TODO: remove this unnecessary "set to empty" once crashes are fixed. transformResults.set(getPackedOp().cast(), {}); return emitSilenceableError() << "requires number of packed sizes match the number of loops (" @@ -1135,6 +1139,272 @@ transform::producesHandle(getPackedOp(), effects); } +//===---------------------------------------------------------------------===// +// PackTransposeOp +//===---------------------------------------------------------------------===// + +void transform::PackTransposeOp::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getTargetPackOrUnpackOp(), effects); + transform::consumesHandle(getTargetLinalgOp(), effects); + transform::producesHandle(getPackedOp(), effects); + transform::producesHandle(getPackOp(), effects); + transform::producesHandle(getUnPackOp(), effects); +} + +namespace { +enum class OuterOrInnerPerm { Outer = 0, Inner = 1 }; +} // namespace + +/// Return true if `permutation` is a valid permutation of the `outer_dims_perm` +/// (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner) +/// of the `tensor.pack` or `tensor.unpack` `op. 
+/// This is the case when the `permutation` rank matches the rank expected by +/// `op` and `permutation` is itself a permutation vector. +/// Return true if either `op` or `permutation` are empty to allow a simpler +/// polymorphic implementation. +template +bool isValidPackingPermutation( + RelayoutOpTy op, ArrayRef permutation, + OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) { + static_assert( + llvm::is_one_of::value, + "applies to only pack or unpack operations"); + if (!op || permutation.empty()) + return true; + int64_t innerRank = op.getInnerDimsPos().size(); + if (outerOrInnerPerm == OuterOrInnerPerm::Inner) + return permutation.size() == innerRank && isPermutationVector(permutation); + // op.getOuterDimsPerm() may be empty, in which case it is identity. + // Don't rely on it. + if (std::is_same::value) { + return permutation.size() == op.getSourceRank() && + isPermutationVector(permutation); + } + return permutation.size() == op.getDestRank() && + isPermutationVector(permutation); +} + +/// Return a copy of `tensorType` after permutation by `permutationVector`. +// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder +// but this would introduce a dependence on Dialect in IR. +// TODO: Restructure. +static RankedTensorType permuteShape(RankedTensorType tensorType, + ArrayRef permutationVector) { + SmallVector shape(tensorType.getShape()); + applyPermutationToVector(shape, permutationVector); + return RankedTensorType::Builder(tensorType).setShape(shape); +} + +/// Return a new GenericOp obtained by transposing opOperand by the permutation +/// vector: +/// - the corresponding indexing map is transposed by `permutation` +/// - the corresponding operand value is replaced by `transposedValue` +/// Asserts that `transposedValue` is of the proper transposed ShapedType. 
+static LinalgOp transposeOneLinalgOperand(RewriterBase &rewriter, + LinalgOp linalgOp, + OpOperand &opOperand, + ArrayRef permutation, + Value transposedValue) { + // Sanity check the operand. + assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand"); + + // Sanity check of the expected transposed tensor type. + auto tensorType = permuteShape( + opOperand.get().getType().cast(), permutation); + assert(tensorType == transposedValue.getType() && + "expected tensor type mismatch"); + + // Compute the transposed indexing map. + // Sigh unsigned pollution. + SmallVector tmpTransposition = llvm::to_vector( + llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; })); + AffineMap permutationMap = + AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext()); + AffineMap transposedMap = + permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand)); + + // Set the transposed indexing map in the proper position. + SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); + indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap; + // Set the transposedValue in the proper operand position. 
+ SmallVector operands = linalgOp->getOperands(); + operands[opOperand.getOperandNumber()] = transposedValue; + + ValueRange operandsRef(operands); + auto transposedGenericOp = rewriter.create( + /*location=*/linalgOp->getLoc(), + /*resultTensorTypes=*/ + operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(), + /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()), + /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()), + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/linalgOp.getIteratorTypesArray()); + transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0)); + + return cast(transposedGenericOp.getOperation()); +} + +LogicalResult transform::PackTransposeOp::verify() { + if (!isPermutationVector(getInnerPerm())) { + return emitOpError() << getInnerPermAttrName() + << " is not a valid permutation"; + } + if (!isPermutationVector(getOuterPerm())) { + return emitOpError() << getOuterPermAttrName() + << " is not a valid permutation"; + } + if (getInnerPerm().empty() && getOuterPerm().empty()) { + return emitOpError() << " at least one of " << getInnerPermAttrName() + << " or " << getOuterPermAttrName() + << " must be specified"; + } + return success(); +} + +DiagnosedSilenceableFailure +transform::PackTransposeOp::apply(transform::TransformResults &transformResults, + transform::TransformState &state) { + ArrayRef packOrUnpackOps = + state.getPayloadOps(getTargetPackOrUnpackOp()); + ArrayRef linalgOps = state.getPayloadOps(getTargetLinalgOp()); + // Step 1. If nothing to pack, propagate success. + if (packOrUnpackOps.empty()) { + transformResults.set(getPackedOp().cast(), {}); + transformResults.set(getPackOp().cast(), {}); + transformResults.set(getUnPackOp().cast(), {}); + return DiagnosedSilenceableFailure::success(); + } + + // Step 2. Bunch of runtime sanity check and error messages. + // Fail on multi-op handles. 
+ if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) { + return emitSilenceableError() + << "requires target to map to exactly 1 " + "tensor.pack or tensor.unpack and exactly 1 LinalgOp (got " + << packOrUnpackOps.size() << " and " << linalgOps.size() << ")"; + } + // Fail on wrong type. + auto packOp = dyn_cast(packOrUnpackOps.front()); + auto unPackOp = dyn_cast(packOrUnpackOps.front()); + if ((!packOp && !unPackOp)) { + // TODO: remove this unnecessary "set to empty" once crashes are fixed. + transformResults.set(getPackedOp().cast(), {}); + transformResults.set(getPackOp().cast(), {}); + transformResults.set(getUnPackOp().cast(), {}); + return emitSilenceableError() << "requires target to map to exactly 1 " + "tensor.pack or tensor.unpack (got " + << packOrUnpackOps.size() << ")"; + } + // Fail if we can't get the producer / consumer Linalg op. + LinalgOp linalgOp; + if (packOp && packOp.getResult().hasOneUse()) + linalgOp = dyn_cast(*(packOp.getResult().getUsers().begin())); + else if (unPackOp) + linalgOp = unPackOp.getSource().getDefiningOp(); + if (!linalgOp || linalgOp != linalgOps.front()) { + // TODO: remove this unnecessary "set to empty" once crashes are fixed. + transformResults.set(getPackedOp().cast(), {}); + transformResults.set(getPackOp().cast(), {}); + transformResults.set(getUnPackOp().cast(), {}); + if (packOp) + return emitSilenceableError() + << "not a single use by the specified LinalgOp"; + return emitSilenceableError() << "not produced by the specified LinalgOp"; + } + // If we have an UnpackOp, we also need to fetch the symmetrical PackOp. 
+ if (unPackOp) { + assert(!packOp && "packOp must be null on entry when unPackOp is not null"); + OpOperand *packUse = linalgOp.getDpsInitOperand( + unPackOp.getSource().cast().getResultNumber()); + packOp = dyn_cast_or_null(packUse->get().getDefiningOp()); + if (!packOp || !packOp.getResult().hasOneUse()) + return emitSilenceableError() << "could not find matching pack op"; + } + // Fail if any permutation does not validate. + for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) { + ArrayRef perm = + (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm(); + auto errorMsg = (permType == OuterOrInnerPerm::Outer) + ? StringLiteral{"invalid outer_perm"} + : StringLiteral{"invalid inner_perm"}; + if (!isValidPackingPermutation(packOp, perm, permType) || + !isValidPackingPermutation(unPackOp, perm, permType)) { + // TODO: remove this unnecessary "set to empty" once crashes are fixed. + transformResults.set(getPackedOp().cast(), {}); + transformResults.set(getPackOp().cast(), {}); + transformResults.set(getUnPackOp().cast(), {}); + Operation *packOrUnpackOp = + unPackOp ? unPackOp.getOperation() : packOp.getOperation(); + return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp; + } + } + + // From here on, packOp and linalgOp are always present, unPackOp may or may + // not be present. + assert(packOp && linalgOp && "unexpected null op"); + + // Step 3. Actually transpose the ops. + IRRewriter rewriter(getContext()); + + // Step 3.a. Transpose packOp. + rewriter.setInsertionPoint(packOp); + tensor::PackOp transposedPackOp = packOp.createTransposedClone( + rewriter, packOp.getLoc(), getInnerPerm(), getOuterPerm()); + + // Step 3.b. Transpose linalgOp. + assert(packOp.getResult().hasOneUse() && "expect single use"); + // packOp.getOuterDimsPerm() may be empty, in which case it is identity. + // Don't rely on it. 
+ int64_t numLeadingDims = packOp.getSourceRank(); + int64_t numTrailingDims = packOp.getInnerDimsPos().size(); + // Step 3.b.i. Compute the permutation on the whole operand. + // Leading part just reuse the outerPerm. + SmallVector permutation(getOuterPerm()); + if (permutation.empty()) + llvm::append_range(permutation, llvm::seq(0, numLeadingDims)); + // Trailing part needs to reindex positions by `numLeadingDims`. + if (getInnerPerm().empty()) { + llvm::append_range( + permutation, + llvm::seq(numLeadingDims, numLeadingDims + numTrailingDims)); + } else { + llvm::append_range(permutation, + llvm::map_range(getInnerPerm(), [&](int64_t pos) { + return numLeadingDims + pos; + })); + } + assert(isPermutationVector(permutation) && "invalid permutation"); + // Step 3.b.ii. Actually perform the transposition. + OpOperand &packUse = *(packOp.getResult().getUses().begin()); + rewriter.setInsertionPoint(linalgOp); + linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperand( + rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult()); + + // Step 3.c. Maybe transpose unPackOp. + tensor::UnPackOp transposedUnPackOp; + if (unPackOp) { + OpResult transposedResult = linalgOp.getTiedOpResult(&packUse); + rewriter.setInsertionPoint(unPackOp); + transposedUnPackOp = unPackOp.createTransposedClone( + rewriter, packOp.getLoc(), transposedResult, getInnerPerm(), + getOuterPerm()); + } + + // Step 4. Replace and return results. 
+ rewriter.replaceOp(packOp, transposedPackOp->getResults()); + rewriter.replaceOp(linalgOp, transposedLinalgOp->getResults()); + transformResults.set(getPackOp().cast(), {transposedPackOp}); + transformResults.set(getPackedOp().cast(), {transposedLinalgOp}); + if (unPackOp) { + rewriter.replaceOp(unPackOp, transposedUnPackOp->getResults()); + transformResults.set(getUnPackOp().cast(), {transposedUnPackOp}); + } else { + transformResults.set(getUnPackOp().cast(), {}); + } + return DiagnosedSilenceableFailure::success(); +} + //===---------------------------------------------------------------------===// // PadOp //===---------------------------------------------------------------------===// @@ -1299,7 +1569,7 @@ if (!target->hasTrait() && target->getNumRegions() > 0) return emitDefiniteFailure() - << "expected target that is isloated from above"; + << "expected target that is isolated from above"; } // Clone and replace. @@ -1478,9 +1748,9 @@ if (second.size() != first.size() && !second.empty()) { results.set(getFirst().cast(), {}); results.set(getSecond().cast(), {}); - auto diag = - emitSilenceableError() - << "splitting does not produce the second part for a subset of targets"; + auto diag = emitSilenceableError() + << "splitting does not produce the second part for a subset " + "of targets"; diag.attachNote() << "expected splitting to produce the second part of all " "or none of the targets"; diag.attachNote(noSecondPart->getLoc()) @@ -1732,8 +2002,8 @@ SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); // Call the default builder which sets up the proper operands segment sizes - // attributes for multiple variadic operands. In the absence of this, horrible - // bugs ensue. + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue. 
MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); @@ -1946,8 +2216,8 @@ SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); // Call the default builder which sets up the proper operands segment sizes - // attributes for multiple variadic operands. In the absence of this, horrible - // bugs ensue. + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue. MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); @@ -1983,8 +2253,8 @@ dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, staticNumThreads); // Call the default builder which sets up the proper operands segment sizes - // attributes for multiple variadic operands. In the absence of this, horrible - // bugs ensue. + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue. MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); @@ -2011,27 +2281,27 @@ // Transform all targets one by one. 
for (Operation *target : targets) { - auto tilableOp = dyn_cast(target); - if (!tilableOp) { + auto tileableOp = dyn_cast(target); + if (!tileableOp) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() << "only TilingInterface ops are supported"; diag.attachNote(target->getLoc()) << "target op"; return diag; } - rewriter.setInsertionPoint(tilableOp); + rewriter.setInsertionPoint(tileableOp); FailureOr tilingResult = failure(); if (!mixedNumThreads.empty()) { - tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp, + tilingResult = linalg::tileToForeachThreadOp(rewriter, tileableOp, mixedNumThreads, mapping); } else { tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( - rewriter, tilableOp, mixedTileSizes, mapping); + rewriter, tileableOp, mixedTileSizes, mapping); } if (failed(tilingResult)) - return transformOp.emitDefaultSilenceableFailure(tilableOp); - rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults()); + return transformOp.emitDefaultSilenceableFailure(tileableOp); + rewriter.replaceOp(tileableOp, tilingResult->tileOp->getResults()); tileOps.push_back(tilingResult->tileOp); tiledOps.push_back(tilingResult->tiledOp); @@ -2117,8 +2387,8 @@ return emitOpError( "tile_sizes and packed_tile_sizes are mutually exclusive"); if (numThreadsSpec == 0 && tileSizesSpec == 0) - return emitOpError( - "either (packed_)num_threads or (packed_)tile_sizes must be specified"); + return emitOpError("either (packed_)num_threads or (packed_)tile_sizes " + "must be specified"); return success(); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3231,7 +3231,6 @@ return true; } return shape == constTileSize.value(); - })) { return op->emitError("mismatch in inner tile sizes specified and shaped of " "tiled dimension in the packed type"); @@ -3239,6 +3238,57 @@ return success(); } 
+namespace { +/// Subset of PackOp/UnPackOp fields used to compute the result of applying +/// various permutations to the op. +// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse +// these. These may or may not become true foldings / canonicalizations +// depending on how aggressive we want to be in automatically folding +// transposes. +struct PackOrUnPackTransposeResult { + SmallVector innerDimsPos; + SmallVector innerTiles; + SmallVector outerDimsPerm; +}; +} // namespace + +template +static PackOrUnPackTransposeResult +commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, + ArrayRef innerPermutation, + ArrayRef outerPermutation) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + assert((!innerPermutation.empty() || !outerPermutation.empty()) && + "some permutation must be non-empty"); + PackOrUnPackTransposeResult metadata; + metadata.innerDimsPos = + SmallVector(packOrUnPackOp.getInnerDimsPos()); + metadata.innerTiles = + SmallVector(packOrUnPackOp.getMixedTiles()); + int64_t numOuterDims = std::is_same::value + ? packOrUnPackOp.getSourceRank() + : packOrUnPackOp.getDestRank(); + metadata.outerDimsPerm = + packOrUnPackOp.getOuterDimsPerm().empty() + ? 
llvm::to_vector(llvm::seq(0, numOuterDims)) + : SmallVector(packOrUnPackOp.getOuterDimsPerm()); + if (!innerPermutation.empty()) { + assert(innerPermutation.size() == metadata.innerDimsPos.size() && + isPermutationVector(innerPermutation) && + "invalid inner permutation"); + applyPermutationToVector(metadata.innerDimsPos, innerPermutation); + applyPermutationToVector(metadata.innerTiles, innerPermutation); + } + if (!outerPermutation.empty()) { + assert(outerPermutation.size() == metadata.outerDimsPerm.size() && + isPermutationVector(outerPermutation) && + "invalid outer permutation"); + applyPermutationToVector(metadata.outerDimsPerm, outerPermutation); + } + return metadata; +} + //===----------------------------------------------------------------------===// // PackOp //===----------------------------------------------------------------------===// @@ -3284,8 +3334,9 @@ return getStaticTilesImpl(*this); } -/// Check if we have enough static information to catch undefined behavior when -/// the tile size does not divide perfectly the dimension of the input tensor. +/// Check if we have enough static information to catch undefined behavior +/// when the tile size does not divide perfectly the dimension of the input +/// tensor. static bool areNotFullTiles(ArrayRef inputShape, DenseMap const &dimAndTileMapping) { @@ -3310,8 +3361,8 @@ return failure(); // Verify padding value, and bail out if the tile does not divide the - // dimension fully. In the case of dynamic tile factors or dimensions, having - // a partial tile is undefined behavior. + // dimension fully. In the case of dynamic tile factors or dimensions, + // having a partial tile is undefined behavior. auto paddingValue = getPaddingValue(); if (paddingValue && paddingValue.getType() != getSourceType().getElementType()) { @@ -3329,8 +3380,8 @@ return success(); } -/// Get the expected packed type based on source type, tile factors, position of -/// the inner tiles and permutation of the outer tiled loop. 
+/// Get the expected packed type based on source type, tile factors, position +/// of the inner tiles and permutation of the outer tiled loop. ShapedType PackOp::inferPackedType(ShapedType sourceType, ArrayRef innerTileSizes, ArrayRef innerDimsPos, @@ -3386,6 +3437,19 @@ return b.create(loc, mixedSizes, elemType); } +PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc, + ArrayRef innerPermutation, + ArrayRef outerPermutation) { + PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp( + *this, innerPermutation, outerPermutation); + Value transposedDest = + createDestinationTensor(b, loc, getSource(), metadata.innerTiles, + metadata.innerDimsPos, metadata.outerDimsPerm); + return b.create(loc, getSource(), transposedDest, + metadata.innerDimsPos, metadata.innerTiles, + getPaddingValue(), metadata.outerDimsPerm); +} + /// Returns true if the tiles and the tiled dims are constant. template bool areTilesAndTiledDimsAllConstant(OpTy op) { @@ -3408,9 +3472,9 @@ if (auto paddingValue = getPaddingValue()) return Speculation::Speculatable; - // The verifier rejects already operations if we can statically prove that the - // sizes of the tiles do not divide perfectly the dimension; thus, check only - // to have constant tiles and tiled inner dimensions. + // The verifier rejects already operations if we can statically prove that + // the sizes of the tiles do not divide perfectly the dimension; thus, check + // only to have constant tiles and tiled inner dimensions. 
if (!areTilesAndTiledDimsAllConstant(*this)) return Speculation::NotSpeculatable; @@ -3508,6 +3572,17 @@ builder.getDenseI64ArrayAttr(staticTileSizes)); } +UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc, + Value transposedSource, + ArrayRef innerPermutation, + ArrayRef outerPermutation) { + PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp( + *this, innerPermutation, outerPermutation); + return b.create(loc, transposedSource, getDest(), + metadata.innerDimsPos, metadata.innerTiles, + metadata.outerDimsPerm); +} + /// pack(unpack(x)) -> x LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, PatternRewriter &rewriter) { diff --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir --- a/mlir/test/Dialect/Linalg/transform-op-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir @@ -49,21 +49,21 @@ iterator_types = ["reduction", "parallel"] } -// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)> // CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d1)> // CHECK-LABEL: @col_reduction_2d_static // CHECK-SAME: %[[T0:.+]]: tensor<7x3xf16>, // CHECK-SAME: %[[T1:.+]]: tensor<3xf16> func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> { - // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x4xf16> + // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<3x2x4xf16> // CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16) - // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<2x3x4xf16> + // CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<3x2x4xf16> // CHECK-NOT: tensor.pack // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]] // CHECK-SAME: 
iterator_types = ["reduction", "parallel", "reduction"] - // CHECK-SAME: ins(%{{.*}} : tensor<2x3x4xf16>) + // CHECK-SAME: ins(%{{.*}} : tensor<3x2x4xf16>) // CHECK-SAME: outs(%{{.*}} : tensor<3xf16>) %2 = linalg.generic #col_reduction_2d_trait ins(%t0 : tensor<7x3xf16>) outs(%t1 : tensor<3xf16>) { ^bb0(%in: f16, %out: f16): @@ -78,8 +78,15 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - transform.structured.pack %0 packed_sizes = [4, 0] + %1 = transform.structured.pack %0 packed_sizes = [4, 0] : (!pdl.operation) -> (!transform.op<"linalg.generic">) + %pack = transform.get_producer_of_operand %1[0] + : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">) + %2, %pack_2, %empty_unpack_2 = + transform.structured.pack_transpose %pack with_compute_op(%1) + outer_perm = [1, 0] + : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">) + -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !pdl.operation) } // ----- @@ -183,7 +190,7 @@ // K N n k // CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)> // M N m n -// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> +// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d4, d3)> // CHECK-LABEL: @matmul // CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: tensor, @@ -196,19 +203,19 @@ // CHECK-SAME: : tensor -> tensor // CHECK: %[[PACK_B:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [3, 4] // CHECK-SAME: : tensor -> tensor - // CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3] - // CHECK-SAME: : tensor -> tensor + // CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2] + // CHECK-SAME: : tensor -> tensor // CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], 
#[[$PACKED_MAP_2]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} // CHECK-SAME: ins(%{{.*}} : tensor, tensor) - // CHECK-SAME: outs(%{{.*}} : tensor) + // CHECK-SAME: outs(%{{.*}} : tensor) %0 = linalg.matmul ins(%A, %B: tensor, tensor) outs(%C: tensor) -> tensor - // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3] - // CHECK-SAME: : tensor -> tensor + // CHECK: tensor.unpack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2] + // CHECK-SAME: : tensor -> tensor return %0 : tensor } @@ -218,6 +225,14 @@ // M N K %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4] : (!pdl.operation) -> (!transform.op<"linalg.generic">) + + %unpack = transform.get_consumers_of_result %1[0] + : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">) + %2, %pack_2, %unpack_2 = + transform.structured.pack_transpose %unpack with_compute_op(%1) + outer_perm = [1, 0] inner_perm = [1, 0] + : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">) + -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">) } // -----