diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -428,6 +428,23 @@
           result.getResultNumber());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the index in the indexingMaps vector that corresponds to this
+        `opOperand`.
+      }],
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getIndexingMapIndex",
+      /*args=*/(ins "OpOperand *":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto operandNumber = opOperand->getOperandNumber();
+        auto dpsIface =
+            cast<DestinationStyleOpInterface>(*this->getOperation());
+        if (dpsIface.isDpsInput(opOperand))
+          return operandNumber;
+        assert(dpsIface.isDpsInit(opOperand) && "expected init operand");
+        auto [start, end] = dpsIface.getDpsInitsPositionRange();
+        return dpsIface.getNumDpsInputs() + operandNumber - start;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the value yielded by the region corresponding to an output
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -341,6 +341,9 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
 def PackOp : Op<Transform_Dialect, "structured.pack", [
                 DeclareOpInterfaceMethods<TransformOpInterface>,
                 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
@@ -425,6 +428,74 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PackTransposeOp
+//===----------------------------------------------------------------------===//
+def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
+                      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+                      TransformOpInterface]> {
+  let description = [{
+    Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and
+    update the `linalg.generic` op that consumes (resp. produces) the operation.
+
+    This transform allows composing a simple `structured.pack` with additional
+    transpositions to e.g. match the data format required by a specific library
+    call or ISA instruction.
+
+    The transpose spec must specify at least one of `outer_perm` or `inner_perm`
+    attributes, which will act upon the `outer_dims_perm` or `inner_dims_pos` of
+    the specified `tensor.pack` or `tensor.unpack` op.
+
+    If the `target` of this op is a `tensor.pack` then a new `tensor.empty` will
+    be created along with transposed versions of the `tensor.pack` and the
+    consuming `linalg.generic`, which is expected to be the sole consumer.
+
+    If the `target` of this op is a `tensor.unpack` then the whole pack / compute
+    / unpack chain will be transposed and transposed clones of `tensor.pack`,
+    the consuming `linalg.generic` and the tail `tensor.unpack` will be created.
+
+    #### Return modes
+
+    This operation targets a single `tensor.pack` / `tensor.unpack` op and a
+    single matching `linalg.generic` that consumes / produces the op. Otherwise,
+    it produces a silenceable failure.
+
+    This operation may also produce a silenceable failure if the transpose spec
+    is ill-formed (i.e. `outer_perm` or `inner_perm` are not permutations of the
+    proper rank) or if the transposition of all involved operations fails for
+    any reason.
+
+    This operation returns 3 handles: one to the transformed LinalgOp, one to
+    the transformed `tensor.pack` and one to the transformed `tensor.unpack`.
+    The last handle for `tensor.unpack` is empty if `target_pack_or_unpack_op`
+    was not itself a `tensor.unpack`.
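+
+    #### Example
+
+    A sketch of intended usage; the handles `%pack` (mapped to a `tensor.pack`)
+    and `%generic` (mapped to its consuming `linalg.generic`) are hypothetical
+    and would be produced by earlier transform ops:
+
+    ```mlir
+    %packed, %pack2, %unpack2 =
+      transform.structured.pack_transpose %pack with_compute_op(%generic)
+      outer_perm = [1, 0]
+      : (!transform.any_op, !transform.any_op)
+      -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    ```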
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target_pack_or_unpack_op,
+                       TransformHandleTypeInterface:$target_linalg_op,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_perm,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$inner_perm);
+  let results = (outs TransformHandleTypeInterface:$packed_op,
+                      TransformHandleTypeInterface:$pack_op,
+                      TransformHandleTypeInterface:$un_pack_op);
+  let assemblyFormat = [{
+    $target_pack_or_unpack_op
+    `with_compute_op` `(` $target_linalg_op `)`
+    (`outer_perm` `=` $outer_perm^ )?
+    (`inner_perm` `=` $inner_perm^ )?
+    attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        transform::TransformResults &transformResults,
+        transform::TransformState &state);
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // PadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1776,6 +1776,21 @@
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+
+    /// Build and return a new PackOp that is a clone of the current PackOp
+    /// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) permuted by
+    /// innerPermutation (resp. outerPermutation).
+    /// A new `tensor.empty` of the proper shape is built in the process.
+    /// Asserts that:
+    ///  - At least one of innerPermutation or outerPermutation is non-empty.
+    ///  - If not empty, innerPermutation is a valid permutation of size
+    ///    matching innerDimsPos.
+    ///  - If not empty, outerPermutation is a valid permutation of size
+    ///    matching outerDimsPerm.
+    PackOp createTransposedClone(OpBuilder &b,
+                                 Location loc,
+                                 ArrayRef<int64_t> innerPermutation,
+                                 ArrayRef<int64_t> outerPermutation);
   }];
 
   let hasCanonicalizeMethod = 1;
@@ -1832,7 +1847,23 @@
                       CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
   ];
 
-  let extraClassDeclaration = commonExtraClassDeclaration;
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    /// Build and return a new UnPackOp that is a clone of the current UnPackOp
+    /// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) permuted by
+    /// innerPermutation (resp. outerPermutation).
+    /// Asserts that:
+    ///  - At least one of innerPermutation or outerPermutation is non-empty.
+    ///  - If not empty, innerPermutation is a valid permutation of size
+    ///    matching innerDimsPos.
+    ///  - If not empty, outerPermutation is a valid permutation of size
+    ///    matching outerDimsPerm.
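+    /// The clone reads from `transposedSource`, which is expected to already
+    /// carry the transposed shape (e.g. the result of the transposed compute
+    /// op); the existing destination tensor is reused.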
+    UnPackOp createTransposedClone(OpBuilder &b,
+                                   Location loc,
+                                   Value transposedSource,
+                                   ArrayRef<int64_t> innerPermutation,
+                                   ArrayRef<int64_t> outerPermutation);
+  }];
+  let hasCanonicalizeMethod = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -17,16 +17,20 @@
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Debug.h"
@@ -1100,7 +1104,7 @@
   // Fail on multi-op handles.
   auto linalgOp = dyn_cast<LinalgOp>(targetOps.front());
   if (targetOps.size() != 1 || !linalgOp) {
-    // TODO: remove this unnecessary set to empty once crashes are fixed.
+    // TODO: remove this unnecessary "set to empty" once crashes are fixed.
     transformResults.set(getPackedOp().cast<OpResult>(), {});
     return emitSilenceableError()
            << "requires target to map to exactly 1 LinalgOp (got "
@@ -1108,7 +1112,7 @@
   }
   // Fail on mismatched number of pack sizes.
   if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
-    // TODO: remove this unnecessary set to empty once crashes are fixed.
+    // TODO: remove this unnecessary "set to empty" once crashes are fixed.
     transformResults.set(getPackedOp().cast<OpResult>(), {});
     return emitSilenceableError()
            << "requires number of packed sizes match the number of loops ("
@@ -1135,6 +1139,248 @@
   transform::producesHandle(getPackedOp(), effects);
 }
 
+//===---------------------------------------------------------------------===//
+// PackTransposeOp
+//===---------------------------------------------------------------------===//
+
+void transform::PackTransposeOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getTargetPackOrUnpackOp(), effects);
+  transform::consumesHandle(getTargetLinalgOp(), effects);
+  transform::producesHandle(getPackedOp(), effects);
+  transform::producesHandle(getPackOp(), effects);
+  transform::producesHandle(getUnPackOp(), effects);
+}
+
+namespace {
+enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
+} // namespace
+
+/// Return true if either `op` or `permutation` is empty, to allow a simpler
+/// polymorphic implementation.
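+/// Otherwise, return true if `permutation` is a valid permutation of the
+/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
+/// (case OuterOrInnerPerm::Inner) of `op`, i.e. it has the matching size and
+/// is itself a permutation vector.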
+template <typename RelayoutOpTy>
+bool isValidPackingPermutation(
+    RelayoutOpTy op, ArrayRef<int64_t> permutation,
+    OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
+  static_assert(
+      llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
+      "applies to only pack or unpack operations");
+  if (!op || permutation.empty())
+    return true;
+  if (outerOrInnerPerm == OuterOrInnerPerm::Outer)
+    return permutation.size() == op.getOuterDimsPerm().size() &&
+           isPermutationVector(permutation);
+  return permutation.size() == op.getInnerDimsPos().size() &&
+         isPermutationVector(permutation);
+}
+
+/// Should be a new method of MemRef/RankedTensor/VectorType::Builder but this
+/// would introduce a dependence on Dialect in IR.
+// TODO: Restructure.
+static RankedTensorType permuteShape(RankedTensorType tensorType,
+                                     ArrayRef<int64_t> transposition) {
+  SmallVector<int64_t> shape = llvm::to_vector(tensorType.getShape());
+  applyPermutationToVector(shape, transposition);
+  return RankedTensorType::Builder(tensorType).setShape(shape);
+}
+
+/// Return a new GenericOp obtained by transposing `opOperand` by the
+/// `transposition` vector:
+///   - the corresponding indexing map is transposed by `transposition`
+///   - the corresponding operand value is replaced by `transposedValue`
+/// Asserts that `transposedValue` is of the proper transposed ShapedType.
+static LinalgOp transposeOneLinalgOperand(RewriterBase &rewriter,
+                                          LinalgOp linalgOp,
+                                          OpOperand &opOperand,
+                                          ArrayRef<int64_t> transposition,
+                                          Value transposedValue) {
+  // Sanity check the operand.
+  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
+
+  // Sanity check of the expected transposed tensor type.
+  auto tensorType = permuteShape(
+      opOperand.get().getType().cast<RankedTensorType>(), transposition);
+  assert(tensorType == transposedValue.getType() &&
+         "unexpected tensor type mismatch");
+
+  // Compute the transposed indexing map.
+  // Sigh unsigned pollution.
+  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
+      llvm::map_range(transposition, [](int64_t i) -> unsigned { return i; }));
+  AffineMap permutationMap =
+      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
+  AffineMap transposedMap =
+      linalgOp.getMatchingIndexingMap(&opOperand).compose(permutationMap);
+
+  // Set the transposed indexing map in the proper position.
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
+  // Set the transposedValue in the proper operand position.
+  SmallVector<Value> operands = linalgOp->getOperands();
+  operands[opOperand.getOperandNumber()] = transposedValue;
+
+  ValueRange operandsRef(operands);
+  auto transposedGeneric = rewriter.create<linalg::GenericOp>(
+      /*location=*/linalgOp->getLoc(),
+      /*resultTensorTypes=*/
+      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
+      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
+      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
+      /*indexingMaps=*/indexingMaps,
+      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
+  // The scalar computation is unchanged; move the body over as-is.
+  transposedGeneric.getRegion().takeBody(linalgOp->getRegion(0));
+
+  return cast<LinalgOp>(transposedGeneric.getOperation());
+}
+
+LogicalResult transform::PackTransposeOp::verify() {
+  if (!isPermutationVector(getInnerPerm())) {
+    return emitOpError() << getInnerPermAttrName()
+                         << " is not a valid permutation";
+  }
+  if (!isPermutationVector(getOuterPerm())) {
+    return emitOpError() << getOuterPermAttrName()
+                         << " is not a valid permutation";
+  }
+  if (getInnerPerm().empty() && getOuterPerm().empty()) {
+    return emitOpError() << "at least one of " << getInnerPermAttrName()
+                         << " or " << getOuterPermAttrName()
+                         << " must be specified";
+  }
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
+                                  transform::TransformState &state) {
+  ArrayRef<Operation *> packOrUnpackOps =
+      state.getPayloadOps(getTargetPackOrUnpackOp());
+  ArrayRef<Operation *> linalgOps = state.getPayloadOps(getTargetLinalgOp());
+  // Step 1. If nothing to pack, propagate success.
+  if (packOrUnpackOps.empty()) {
+    transformResults.set(getPackedOp().cast<OpResult>(), {});
+    transformResults.set(getPackOp().cast<OpResult>(), {});
+    transformResults.set(getUnPackOp().cast<OpResult>(), {});
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Step 2. Bunch of runtime sanity checks and error messages.
+  // Fail on multi-op handles.
+  if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) {
+    return emitSilenceableError()
+           << "requires target to map to exactly 1 "
+              "tensor.pack or tensor.unpack and exactly 1 LinalgOp (got "
+           << packOrUnpackOps.size() << " and " << linalgOps.size() << ")";
+  }
+  // Fail on wrong type.
+  auto packOp = dyn_cast<tensor::PackOp>(packOrUnpackOps.front());
+  auto unpackOp = dyn_cast<tensor::UnPackOp>(packOrUnpackOps.front());
+  if (!packOp && !unpackOp) {
+    // TODO: remove this unnecessary "set to empty" once crashes are fixed.
+    transformResults.set(getPackedOp().cast<OpResult>(), {});
+    transformResults.set(getPackOp().cast<OpResult>(), {});
+    transformResults.set(getUnPackOp().cast<OpResult>(), {});
+    return emitSilenceableError() << "requires target to map to exactly 1 "
+                                     "tensor.pack or tensor.unpack (got "
+                                  << packOrUnpackOps.size() << ")";
+  }
+  // Fail if we can't get the producer / consumer Linalg op.
+  LinalgOp linalgOp;
+  if (packOp && packOp.getResult().hasOneUse())
+    linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
+  else if (unpackOp)
+    linalgOp = unpackOp.getSource().getDefiningOp<LinalgOp>();
+  if (!linalgOp || linalgOp != linalgOps.front()) {
+    // TODO: remove this unnecessary "set to empty" once crashes are fixed.
+    transformResults.set(getPackedOp().cast<OpResult>(), {});
+    transformResults.set(getPackOp().cast<OpResult>(), {});
+    transformResults.set(getUnPackOp().cast<OpResult>(), {});
+    if (packOp)
+      return emitSilenceableError()
+             << "not a single use by the specified LinalgOp";
+    return emitSilenceableError() << "not produced by the specified LinalgOp";
+  }
+  // If we have an UnPackOp, we also need to fetch the symmetrical PackOp.
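+  // The candidate pack op is the producer of the init operand tied to the
+  // linalg op result that feeds unpackOp: the pack / compute / unpack chain
+  // is walked backwards from the unpack source.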
+  if (unpackOp) {
+    assert(!packOp && "packOp must be null on entry when unpackOp is not null");
+    OpOperand *packUse = linalgOp.getDpsInitOperand(
+        unpackOp.getSource().cast<OpResult>().getResultNumber());
+    packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
+    if (!packOp || !packOp.getResult().hasOneUse())
+      return emitSilenceableError() << "could not find matching pack op";
+  }
+  // Fail if any permutation does not validate.
+  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
+    auto perm =
+        (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
+    auto errorMsg = (permType == OuterOrInnerPerm::Outer)
+                        ? StringLiteral{"invalid outer_perm"}
+                        : StringLiteral{"invalid inner_perm"};
+    if (!isValidPackingPermutation(packOp, perm, permType) ||
+        !isValidPackingPermutation(unpackOp, perm, permType)) {
+      // TODO: remove this unnecessary "set to empty" once crashes are fixed.
+      transformResults.set(getPackedOp().cast<OpResult>(), {});
+      transformResults.set(getPackOp().cast<OpResult>(), {});
+      transformResults.set(getUnPackOp().cast<OpResult>(), {});
+      return emitSilenceableError() << errorMsg;
+    }
+  }
+
+  // From here on, packOp and linalgOp are always present, unpackOp may or may
+  // not be present.
+  assert(packOp && linalgOp && "unexpected null op");
+
+  // Step 3. Actually transpose the ops.
+  IRRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(packOp);
+
+  // Step 3.a. Transpose packOp.
+  tensor::PackOp transposedPackOp = packOp.createTransposedClone(
+      rewriter, packOp.getLoc(), getInnerPerm(), getOuterPerm());
+
+  // Step 3.b. Transpose linalgOp.
+  assert(packOp.getResult().hasOneUse() && "expect single use");
+  int64_t numLeadingDims = packOp.getOuterDimsPerm().size();
+  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
+  // Step 3.b.i. Compute the permutation on the whole operand.
+  // The leading part just reuses the outerPerm.
+  SmallVector<int64_t> permutation = llvm::to_vector(getOuterPerm());
+  if (permutation.empty())
+    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
+  // The trailing part needs to reindex positions by `numLeadingDims`.
+  if (getInnerPerm().empty()) {
+    llvm::append_range(
+        permutation,
+        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
+  } else {
+    llvm::append_range(permutation,
+                       llvm::map_range(getInnerPerm(), [&](int64_t pos) {
+                         return numLeadingDims + pos;
+                       }));
+  }
+  // Step 3.b.ii. Actually perform the transposition.
+  rewriter.setInsertionPoint(linalgOp);
+  OpOperand &packUse = *(packOp.getResult().getUses().begin());
+  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperand(
+      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
+
+  // Step 3.c. Maybe transpose unpackOp.
+  tensor::UnPackOp transposedUnPackOp;
+  if (unpackOp) {
+    rewriter.setInsertionPoint(unpackOp);
+    OpResult transposedResult = linalgOp.getTiedOpResult(&packUse);
+    transposedUnPackOp = unpackOp.createTransposedClone(
+        rewriter, unpackOp.getLoc(), transposedResult, getInnerPerm(),
+        getOuterPerm());
+  }
+
+  // Step 4. Replace and return results.
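+  // All 3 results are set unconditionally; when the target was a tensor.pack,
+  // the unpack handle is simply set to the empty payload.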
+  rewriter.replaceOp(linalgOp, transposedLinalgOp->getResults());
+  transformResults.set(getPackOp().cast<OpResult>(), {transposedPackOp});
+  transformResults.set(getPackedOp().cast<OpResult>(), {transposedLinalgOp});
+  if (unpackOp) {
+    rewriter.replaceOp(unpackOp, transposedUnPackOp->getResults());
+    transformResults.set(getUnPackOp().cast<OpResult>(), {transposedUnPackOp});
+  } else {
+    transformResults.set(getUnPackOp().cast<OpResult>(), {});
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===---------------------------------------------------------------------===//
 // PadOp
 //===---------------------------------------------------------------------===//
@@ -1299,7 +1545,7 @@
     if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
         target->getNumRegions() > 0)
       return emitDefiniteFailure()
-             << "expected target that is isloated from above";
+             << "expected target that is isolated from above";
   }
 
   // Clone and replace.
@@ -1478,9 +1724,9 @@
   if (second.size() != first.size() && !second.empty()) {
     results.set(getFirst().cast<OpResult>(), {});
     results.set(getSecond().cast<OpResult>(), {});
-    auto diag =
-        emitSilenceableError()
-        << "splitting does not produce the second part for a subset of targets";
+    auto diag = emitSilenceableError()
+                << "splitting does not produce the second part for a subset "
+                   "of targets";
     diag.attachNote() << "expected splitting to produce the second part of all "
                          "or none of the targets";
     diag.attachNote(noSecondPart->getLoc())
@@ -1732,8 +1978,8 @@
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
   // Call the default builder which sets up the proper operands segment sizes
-  // attributes for multiple variadic operands. In the absence of this, horrible
-  // bugs ensue.
+  // attributes for multiple variadic operands. In the absence of this,
+  // horrible bugs ensue.
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
@@ -1946,8 +2192,8 @@
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
   // Call the default builder which sets up the proper operands segment sizes
-  // attributes for multiple variadic operands. In the absence of this, horrible
-  // bugs ensue.
+  // attributes for multiple variadic operands. In the absence of this,
+  // horrible bugs ensue.
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
@@ -1983,8 +2229,8 @@
   dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
                              staticNumThreads);
   // Call the default builder which sets up the proper operands segment sizes
-  // attributes for multiple variadic operands. In the absence of this, horrible
-  // bugs ensue.
+  // attributes for multiple variadic operands. In the absence of this,
+  // horrible bugs ensue.
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
@@ -2011,27 +2257,27 @@
   // Transform all targets one by one.
   for (Operation *target : targets) {
-    auto tilableOp = dyn_cast<TilingInterface>(target);
-    if (!tilableOp) {
+    auto tileableOp = dyn_cast<TilingInterface>(target);
+    if (!tileableOp) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "only TilingInterface ops are supported";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
-    rewriter.setInsertionPoint(tilableOp);
+    rewriter.setInsertionPoint(tileableOp);
     FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
     if (!mixedNumThreads.empty()) {
-      tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
+      tilingResult = linalg::tileToForeachThreadOp(rewriter, tileableOp,
                                                    mixedNumThreads, mapping);
     } else {
       tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
-          rewriter, tilableOp, mixedTileSizes, mapping);
+          rewriter, tileableOp, mixedTileSizes, mapping);
     }
 
     if (failed(tilingResult))
-      return transformOp.emitDefaultSilenceableFailure(tilableOp);
-    rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
+      return transformOp.emitDefaultSilenceableFailure(tileableOp);
+    rewriter.replaceOp(tileableOp, tilingResult->tileOp->getResults());
     tileOps.push_back(tilingResult->tileOp);
     tiledOps.push_back(tilingResult->tiledOp);
@@ -2117,8 +2363,8 @@
     return emitOpError(
         "tile_sizes and packed_tile_sizes are mutually exclusive");
   if (numThreadsSpec == 0 && tileSizesSpec == 0)
-    return emitOpError(
-        "either (packed_)num_threads or (packed_)tile_sizes must be specified");
+    return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
+                       "must be specified");
   return success();
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3231,7 +3231,6 @@
           return true;
         }
         return shape == constTileSize.value();
-
       })) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
@@ -3239,6 +3238,45 @@
   return success();
 }
 
+namespace {
+/// Subset of PackOp/UnPackOp fields that are affected by the application of
+/// inner / outer permutations.
+struct PackOrUnPackTransposeResult {
+  SmallVector<int64_t> innerDimsPos;
+  SmallVector<OpFoldResult> innerTiles;
+  SmallVector<int64_t> outerDimsPerm;
+};
+} // namespace
+
+template <typename OpTy>
+static PackOrUnPackTransposeResult
+commonPermutationOfPackAndUnPackOp(OpTy packOrUnPack,
+                                   ArrayRef<int64_t> innerPermutation,
+                                   ArrayRef<int64_t> outerPermutation) {
+  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+                "applies to only pack or unpack operations");
+  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
+         "some permutation must be non-empty");
+  PackOrUnPackTransposeResult metadata;
+  metadata.innerDimsPos = llvm::to_vector(packOrUnPack.getInnerDimsPos());
+  metadata.innerTiles = packOrUnPack.getMixedTiles();
+  metadata.outerDimsPerm = llvm::to_vector(packOrUnPack.getOuterDimsPerm());
+  if (!innerPermutation.empty()) {
+    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
+           isPermutationVector(innerPermutation) &&
+           "invalid inner permutation");
+    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
+    applyPermutationToVector(metadata.innerTiles, innerPermutation);
+  }
+  if (!outerPermutation.empty()) {
+    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
+           isPermutationVector(outerPermutation) &&
+           "invalid outer permutation");
+    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
+  }
+  return metadata;
+}
+
 //===----------------------------------------------------------------------===//
 // PackOp
 //===----------------------------------------------------------------------===//
@@ -3284,8 +3322,9 @@
   return getStaticTilesImpl(*this);
 }
 
-/// Check if we have enough static information to catch undefined behavior when
-/// the tile size does not divide perfectly the dimension of the input tensor.
+/// Check if we have enough static information to catch undefined behavior
+/// when the tile size does not divide perfectly the dimension of the input
+/// tensor.
 static bool
 areNotFullTiles(ArrayRef<int64_t> inputShape,
                 DenseMap<int64_t, OpFoldResult> const &dimAndTileMapping) {
@@ -3310,8 +3349,8 @@
     return failure();
 
   // Verify padding value, and bail out if the tile does not divide the
-  // dimension fully. In the case of dynamic tile factors or dimensions, having
-  // a partial tile is undefined behavior.
+  // dimension fully. In the case of dynamic tile factors or dimensions,
+  // having a partial tile is undefined behavior.
   auto paddingValue = getPaddingValue();
   if (paddingValue &&
       paddingValue.getType() != getSourceType().getElementType()) {
@@ -3329,8 +3368,8 @@
   return success();
 }
 
-/// Get the expected packed type based on source type, tile factors, position of
-/// the inner tiles and permutation of the outer tiled loop.
+/// Get the expected packed type based on source type, tile factors, position
+/// of the inner tiles and permutation of the outer tiled loop.
 ShapedType PackOp::inferPackedType(ShapedType sourceType,
                                    ArrayRef<int64_t> innerTileSizes,
                                    ArrayRef<int64_t> innerDimsPos,
@@ -3386,6 +3425,19 @@
   return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
 }
 
+PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
+                                     ArrayRef<int64_t> innerPermutation,
+                                     ArrayRef<int64_t> outerPermutation) {
+  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
+      *this, innerPermutation, outerPermutation);
+  Value transposedDest =
+      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
+                              metadata.innerDimsPos, metadata.outerDimsPerm);
+  return b.create<PackOp>(loc, getSource(), transposedDest,
+                          metadata.innerDimsPos, metadata.innerTiles,
+                          getPaddingValue(), metadata.outerDimsPerm);
+}
+
 /// Returns true if the tiles and the tiled dims are constant.
 template <typename OpTy>
 bool areTilesAndTiledDimsAllConstant(OpTy op) {
@@ -3408,9 +3460,9 @@
   if (auto paddingValue = getPaddingValue())
     return Speculation::Speculatable;
 
-  // The verifier rejects already operations if we can statically prove that the
-  // sizes of the tiles do not divide perfectly the dimension; thus, check only
-  // to have constant tiles and tiled inner dimensions.
+  // The verifier already rejects operations if we can statically prove that
+  // the sizes of the tiles do not divide perfectly the dimension; thus, check
+  // only for constant tiles and tiled inner dimensions.
   if (!areTilesAndTiledDimsAllConstant(*this))
     return Speculation::NotSpeculatable;
@@ -3508,6 +3560,17 @@
                 builder.getDenseI64ArrayAttr(staticTileSizes));
 }
 
+UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
+                                         Value transposedSource,
+                                         ArrayRef<int64_t> innerPermutation,
+                                         ArrayRef<int64_t> outerPermutation) {
+  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
+      *this, innerPermutation, outerPermutation);
+  return b.create<UnPackOp>(loc, transposedSource, getDest(),
+                            metadata.innerDimsPos, metadata.innerTiles,
+                            metadata.outerDimsPerm);
+}
+
 /// pack(unpack(x)) -> x
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {