diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1326,8 +1326,14 @@ RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion); +/// Callback controlling propagation of tensor.pack/unpack ops; it receives the +/// op to propagate through and returning false blocks the rewrite. +using ControlPropagationFn = std::function<bool(Operation *op)>; + /// Patterns to bubble up or down data layout ops across other operations. -void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns); +void populateDataLayoutPropagationPatterns( + RewritePatternSet &patterns, + const ControlPropagationFn &controlPackUnPackPropagation); /// Pattern to remove dead operands and results of `linalg.generic` operations. /// This is effectively DCE for a linalg op. diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1874,6 +1874,13 @@ ]; let extraClassDeclaration = commonExtraClassDeclaration # [{ + /// Create a tensor.empty shaped like the unpacked form of `source`: the + /// leading (outer) dims of `source` are un-permuted with `outerDimsPerm`, + /// then each dim in `innerDimsPos` is scaled by its `innerTileSizes` entry. + static Value createDestinationTensor(OpBuilder &b, Location loc, + Value source, ArrayRef<OpFoldResult> innerTileSizes, + ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm); + /// Build and return a new UnPackOp that is a clone of the current UnPackOp /// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by /// innerPermutation (resp. outerPermutation).
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -32,6 +32,13 @@ namespace { +static bool hasGatherSemantics(linalg::GenericOp genericOp) { + for (Operation &op : genericOp.getBody()->getOperations()) + if (isa<tensor::ExtractOp, linalg::IndexOp>(op)) + return true; + return false; +} + // The struct contains the infomation about mapping packing information to // the iteration domain of Linalg ops. struct PackInfo { @@ -48,12 +55,19 @@ }; template <typename OpTy> -static PackInfo getPackingInfoFromOperand(AffineMap indexingMap, - OpTy packOrUnPackOp) { +static FailureOr<PackInfo> +getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp, + OpTy packOrUnPackOp) { static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value, "applies to only pack or unpack operations"); LLVM_DEBUG( { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; }); + + AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); + SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); + SmallVector<utils::IteratorType> iterators = + genericOp.getIteratorTypesArray(); + PackInfo packInfo; int64_t origNumDims = indexingMap.getNumDims(); SmallVector<AffineExpr> exprs(indexingMap.getResults()); @@ -61,8 +75,13 @@ for (auto [index, innerDimPos, tileSize] : llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()), innerDimsPos, packOrUnPackOp.getMixedTiles())) { + auto expr = exprs[innerDimPos]; + if (!expr.template isa<AffineDimExpr>()) + return failure(); int64_t domainDimPos = exprs[innerDimPos].template cast<AffineDimExpr>().getPosition(); + if (!isParallelIterator(iterators[domainDimPos])) + return failure(); packInfo.tiledDimsPos.push_back(domainDimPos); packInfo.domainDimAndTileMapping[domainDimPos] = tileSize; packInfo.tileToPointMapping[domainDimPos] = origNumDims + index; @@ -75,9 +94,54 @@ }); } - for (auto dim : packOrUnPackOp.getOuterDimsPerm()) -
packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim)); - if (!packInfo.outerDimsOnDomainPerm.empty()) { + // Bail out if a tiled dimension is present in a map but not as an affine dim + // expression (e.g. a tiled `d0` that only appears inside `d0 + d1` in some + // indexing map of the generic). + for (AffineMap map : indexingMaps) { + for (int64_t i : packInfo.tiledDimsPos) { + for (AffineExpr expr : map.getResults()) { + if (expr.isFunctionOfDim(i) && !expr.isa<AffineDimExpr>()) + return failure(); + } + } + } + + // Get the outer dims perm on the iteration domain. Start by identifying the + // set of domain dims affected by the outer permutation along with the + // permuted ordering for those dims. Then the full outer dims permutation can + // be constructed by replacing the affected dims with the permuted result in a + // numLoops-rank identity. e.g. + // outerDimsPerm = [1, 2, 0] + // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3) + // + // permutedOuterDims = [4, 3, 1] + // outerDimsOnDomainPerm = [0, 4, 2, 3, 1] + // + // Non-affine dim expressions must not be permuted by the outer dims + // permutation. + SmallVector<int64_t> permutedOuterDims; + for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) { + auto permutedExpr = indexingMap.getResult(dim); + if (auto dimExpr = permutedExpr.template dyn_cast<AffineDimExpr>()) { + auto dimPos = dimExpr.getPosition(); + permutedOuterDims.push_back(dimPos); + continue; + } + + // TODO: Allow propagation with transposes on non affine dim expressions, + // e.g. d0 + d1 which implies transposing both dims simultaneously while + // maintaining the relative position between them. + if (static_cast<int64_t>(index) != dim) + return failure(); + } + if (!permutedOuterDims.empty()) { + int64_t outerDimIndex = 0; + llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(), + permutedOuterDims.end()); + for (int i = 0, e = indexingMap.getNumDims(); i < e; i++) + packInfo.outerDimsOnDomainPerm.push_back( + permutedDomainDims.contains(i) ?
permutedOuterDims[outerDimIndex++] + : i); LLVM_DEBUG({ llvm::dbgs() << "map outer dimsDimsPerm to "; for (auto dim : packInfo.outerDimsOnDomainPerm) @@ -107,8 +171,13 @@ SmallVector<int64_t> outerDimsPerm; DenseMap<int64_t, int64_t> currentPositionTileLoops; for (auto [pos, expr] : llvm::enumerate(exprs)) { - unsigned posInDomain = expr.cast<AffineDimExpr>().getPosition(); - currentPositionTileLoops[posInDomain] = pos; + // Non-affine dim expressions are keyed by their own result position; this + // relies on the outer dims permutation never permuting them. NOTE(review): + // that key can collide with a real `d<pos>` result key; confirm unreachable. + if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) + currentPositionTileLoops[dimExpr.getPosition()] = pos; + else + currentPositionTileLoops[pos] = pos; } for (int64_t loopIdx : perm) { if (currentPositionTileLoops.count(loopIdx)) @@ -169,8 +238,6 @@ domainDimToOperandDim[dimPos] = index; continue; } - assert(expr.isa<AffineConstantExpr>() && - "Found non-constant and non-affine dim expression"); } SmallVector<int64_t> innerDimsPos; SmallVector<OpFoldResult> innerTileSizes; @@ -212,7 +279,7 @@ auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext()); // The operand does not have dimensions that relates to pack op. - if (innerDimsPos.empty()) + if (innerDimsPos.empty() && outerDimsPerm.empty()) return std::make_tuple(opOperand->get(), indexingMap); auto empty = tensor::PackOp::createDestinationTensor( @@ -252,7 +319,7 @@ return newGenericOp; } -/// Bubbles up tensor.pack op through elementwise generic op. This +/// Bubbles up tensor.pack op through a producer generic op. This /// swap pack(generic) to generic(pack). The new generic op works on packed /// domain; pack ops are created for input and output operands.
E.g., /// @@ -296,10 +363,20 @@ /// linalg.yield %4 : f32 /// } -> tensor static FailureOr<GenericOp> -bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter, - tensor::PackOp packOp) { +bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp, + const ControlPropagationFn &controlFn) { auto genericOp = packOp.getSource().getDefiningOp<GenericOp>(); - if (!genericOp || !isElementwise(genericOp)) + if (!genericOp) + return failure(); + + // Let the user-provided callback veto propagation through this generic. + if (!controlFn(genericOp)) + return failure(); + + // TODO: Enable propagation in the presence of linalg.index and + // tensor.extract, likely as a separate pattern as the pack information and + // propagation decision needs to be inferred from the region of the generic. + if (hasGatherSemantics(genericOp)) return failure(); // TODO: Relax the restriction. We are able to bubble up the pack op through @@ -309,6 +386,8 @@ // Bail-out if the result of the generic has multiple uses, as bubbling up // creates recomputation if the generic has multiple users. + // TODO: Enable the case where every use is an identical pack op as no + // recomputation is needed in that case. if (!genericOp->getResult(0).hasOneUse()) return failure(); @@ -343,12 +422,13 @@ return failure(); OpOperand *opOperand = genericOp.getDpsInitOperand(0); - auto packInfo = getPackingInfoFromOperand( - genericOp.getMatchingIndexingMap(opOperand), packOp); + auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp); + if (failed(packInfo)) + return failure(); // Rebuild the indexing map for the corresponding init operand.
auto [packedOutOperand, packedOutIndexingMap] = - getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo, + getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, genericOp, opOperand); // We'll replace the init operand with the destination of pack op if the init @@ -360,22 +440,29 @@ : packedOutOperand; return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap, - packInfo); + *packInfo); } -/// Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method. -struct BubbleUpPackOpThroughElemGenericOpPattern +/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method. +struct BubbleUpPackOpThroughGenericOpPattern : public OpRewritePattern<tensor::PackOp> { - using OpRewritePattern<tensor::PackOp>::OpRewritePattern; +public: + BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context, + ControlPropagationFn fun) + : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {} LogicalResult matchAndRewrite(tensor::PackOp packOp, PatternRewriter &rewriter) const override { - auto genericOp = bubbleUpPackOpThroughElemGenericOp(rewriter, packOp); + auto genericOp = + bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn); if (failed(genericOp)) return failure(); rewriter.replaceOp(packOp, genericOp->getResults()); return success(); } + +private: + ControlPropagationFn controlFn; }; // TODO: Relax this restriction. We should unpack an elementwise also @@ -431,13 +518,13 @@ /// inner_dims_pos = [3] inner_tiles = [32] into %0 /// static FailureOr<std::tuple<GenericOp, Value>> -pushDownUnPackOpThroughElemGenericOp(RewriterBase &rewriter, - GenericOp genericOp) { - if (!isElementwise(genericOp)) - return failure(); +pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) { if (genericOp.getNumResults() != 1) return failure(); + + if (hasGatherSemantics(genericOp)) + return failure(); + // Collect the unPacked operand, if present.
auto maybeUnPackedOperand = getUnPackedOperand(genericOp); if (failed(maybeUnPackedOperand)) @@ -448,13 +535,16 @@ tensor::UnPackOp producerUnPackOp = unPackedOperand->get().getDefiningOp<tensor::UnPackOp>(); assert(producerUnPackOp && "expect a valid UnPackOp"); - auto packInfo = getPackingInfoFromOperand( - genericOp.getMatchingIndexingMap(unPackedOperand), producerUnPackOp); + auto packInfo = + getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp); + if (failed(packInfo)) + return failure(); // Rebuild the indexing map for the corresponding init operand. auto [packedOutOperand, packedOutIndexingMap] = - getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo, + getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, genericOp, genericOp.getDpsInitOperand(0)); + auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>(); // If the dps init operand of the generic is a tensor.empty, do not pack it // and forward the new tensor.empty as a destination. @@ -462,66 +552,76 @@ if (auto initTensor = genericOp.getDpsInitOperand(0) ->get() .getDefiningOp<tensor::EmptyOp>()) { - if (auto packOp = packedOutOperand.getDefiningOp<tensor::PackOp>()) - dest = packOp.getDest(); + if (destPack) + dest = destPack.getDest(); } // Pack the genericOp. GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest, - packedOutIndexingMap, packInfo); + packedOutIndexingMap, *packInfo); + Value newResult = + newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); + + // If the output is unaffected, no need to unpack. + if (!destPack) + return std::make_tuple(newGenericOp, newResult); - // If the output element type for the generic differs from the source - // unpack op, we need to create a new destination tensor. + auto mixedTiles = destPack.getMixedTiles(); + auto innerDimsPos = destPack.getInnerDimsPos(); + auto outerDimsPerm = destPack.getOuterDimsPerm(); + + // If the output type for the generic differs from the source + // unpack op, we need to create a new destination tensor.
In the + // dynamic case we always need a new destination. auto loc = genericOp.getLoc(); Value unPackDest = producerUnPackOp.getDest(); - auto genericOutElementType = getElementTypeOrSelf(genericOp.getResult(0)); - if (producerUnPackOp.getDestType().getElementType() != - genericOutElementType) { - SmallVector<OpFoldResult> unPackMixedSizes; - if (auto unPackEmpty = unPackDest.getDefiningOp<tensor::EmptyOp>()) - unPackMixedSizes = unPackEmpty.getMixedSizes(); - else - unPackMixedSizes = tensor::getMixedSizes(rewriter, loc, unPackDest); - - unPackDest = rewriter.create<tensor::EmptyOp>(loc, unPackMixedSizes, - genericOutElementType); + auto genericOutType = + genericOp.getDpsInitOperand(0)->get().getType().cast<RankedTensorType>(); + if (producerUnPackOp.getDestType() != genericOutType || + !genericOutType.hasStaticShape()) { + unPackDest = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm); } // Insert an unPackOp right after the packed generic. Value unPackOpRes = rewriter - .create<tensor::UnPackOp>( - loc, - newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)), - unPackDest, producerUnPackOp.getInnerDimsPos(), - producerUnPackOp.getMixedTiles(), - producerUnPackOp.getOuterDimsPerm()) + .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos, + mixedTiles, outerDimsPerm) .getResult(); return std::make_tuple(newGenericOp, unPackOpRes); } -// Wrapper pattern that applies pushDownUnPackOpThroughElemGenericOp method. -struct PushDownUnPackOpThroughElemGenericOp - : public OpRewritePattern<GenericOp> { - using OpRewritePattern<GenericOp>::OpRewritePattern; +// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
+struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> { +public: + PushDownUnPackOpThroughGenericOp(MLIRContext *context, + ControlPropagationFn fun) + : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {} LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - auto genericAndRepl = - pushDownUnPackOpThroughElemGenericOp(rewriter, genericOp); + if (!controlFn(genericOp)) + return failure(); + + auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp); if (failed(genericAndRepl)) return failure(); rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); return success(); } + +private: + ControlPropagationFn controlFn; }; /// Propagate a tensor.unpack operation through a tensor.pad. The idea is to /// add as many zero padding dimensions in `high` and `low` based on the number /// of point loops. struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> { - using OpRewritePattern<tensor::PadOp>::OpRewritePattern; + PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) + : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {} LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override { @@ -530,6 +630,9 @@ if (!unpackOp) return failure(); + if (!controlFn(padOp)) + return failure(); + Location loc = padOp.getLoc(); // Bail out if one of the padded dimension is a tiled one.
llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); @@ -572,14 +675,17 @@ rewriter.replaceOp(padOp, replacement); return success(); } + +private: + ControlPropagationFn controlFn; }; } // namespace void mlir::linalg::populateDataLayoutPropagationPatterns( - RewritePatternSet &patterns) { - patterns - .insert<BubbleUpPackOpThroughElemGenericOpPattern, - PushDownUnPackOpThroughElemGenericOp, PushDownUnPackThroughPadOp>( - patterns.getContext()); + RewritePatternSet &patterns, + const ControlPropagationFn &controlPackUnPackPropagation) { + patterns.insert<BubbleUpPackOpThroughGenericOpPattern, + PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>( + patterns.getContext(), controlPackUnPackPropagation); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3760,6 +3760,41 @@ builder.getDenseI64ArrayAttr(staticTileSizes)); } +Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc, + Value source, + ArrayRef<OpFoldResult> innerTileSizes, + ArrayRef<int64_t> innerDimsPos, + ArrayRef<int64_t> outerDimsPerm) { + AffineExpr sym0, sym1; + bindSymbols(b.getContext(), sym0, sym1); + auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult { + return makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2}); + }; + + SmallVector<OpFoldResult> mixedSizes; + auto srcType = source.getType().cast<RankedTensorType>(); + for (auto i : + llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) { + if (srcType.isDynamicDim(i)) + mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult()); + else + mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i))); + } + if (!outerDimsPerm.empty()) { + applyPermutationToVector<OpFoldResult>( + mixedSizes, invertPermutationVector(outerDimsPerm)); + } + + for (auto [dimPos, tileSize] : + llvm::zip_equal(innerDimsPos, innerTileSizes)) { + // The unpacked extent of a tiled dim is (outer size) * (tile size). + mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize); + } + + auto elemType = srcType.getElementType(); + return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType); +} + UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc, Value
transposedSource, ArrayRef<int64_t> innerPermutation, diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -323,9 +323,6 @@ // ----- -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> -#map2 = affine_map<(d0, d1) -> (d1)> func.func @transpose_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32> { %init_transpose = tensor.empty() : tensor<100x200x128x256xi32> @@ -679,3 +676,139 @@ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[ARG0]] // CHECK-SAME: outs(%[[EMPTY]] + +// ----- + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{ + %init = tensor.empty() : tensor<128x256xi32> + %elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0 : tensor<128x256x32xi32>) + outs(%init : tensor<128x256xi32>) { + ^bb0(%arg3: i32, %arg4: i32): + %4 = arith.addi %arg3, %arg4 : i32 + linalg.yield %4 : i32 + } -> tensor<128x256xi32> + %pack = tensor.pack %elem + inner_dims_pos = [1, 0] + inner_tiles = [16, 32] + into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32> + return %pack : tensor<4x16x16x32xi32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> +// CHECK: func.func @reduction_pack_transpose_inner_dims +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[ORIG_INIT:.+]] =
tensor.empty() : tensor<128x256xi32> +// CHECK: %[[INIT_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32> +// CHECK: %[[PACK_INIT:.+]] = tensor.pack %[[ORIG_INIT]] +// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16x32xi32> +// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32] +// CHECK-SAME: into %[[ARG0_EMPTY]] +// CHECK: %[[RED:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel"] +// CHECK-SAME: ins(%[[PACK_ARG0]] +// CHECK-SAME: outs(%[[PACK_INIT]] +// CHECK: return %[[RED]] : tensor<4x16x16x32xi32> + +// ----- + +func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>) -> tensor<4x16x100x16x32xi32> +{ + %init_reduction = tensor.empty() : tensor<100x128x256xi32> + %reduction = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0)>, + affine_map<(d0, d1, d2, d3) -> (d1)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>], + iterator_types = ["parallel", "parallel", "reduction", "parallel"]} + ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>) + outs(%init_reduction : tensor<100x128x256xi32>) { + ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32): + %0 = arith.addi %b0, %b1 : i32 + %1 = arith.addi %0, %b2 : i32 + %2 = arith.addi %1, %b3 : i32 + linalg.yield %2 : i32 + } -> tensor<100x128x256xi32> + %init_pack = tensor.empty() : tensor<4x16x100x16x32xi32> + %4 = tensor.pack %reduction + outer_dims_perm = [1, 2, 0] + inner_dims_pos = [2, 1] + inner_tiles = [16, 32] + into %init_pack : tensor<100x128x256xi32> -> tensor<4x16x100x16x32xi32> + return %4 : tensor<4x16x100x16x32xi32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-DAG: 
#[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d5)> +// CHECK: func.func @reduction_pack_with_outer_dims +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<100x128x256xi32> +// CHECK: %[[INIT_EMPTY:.+]] = tensor.empty() : tensor<4x16x100x16x32xi32> +// CHECK: %[[PACKED_INIT:.+]] = tensor.pack %[[INIT]] +// CHECK-SAME: outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 32] +// CHECK-SAME: into %[[INIT_EMPTY]] +// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x200x100x16x32xi32> +// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: outer_dims_perm = [1, 3, 2, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32] +// CHECK-SAME: into %[[ARG0_EMPTY]] +// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32> +// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32] +// CHECK-SAME: into %[[ARG2_EMPTY]] +// CHECK: %[[RES:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] +// CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]] +// CHECK-SAME: outs(%[[PACKED_INIT]] + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)> +func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>) -> tensor<16x540x960xi32>{ + %init = tensor.empty() : tensor<16x540x960xi32> + %filter = tensor.empty() : tensor<2x2xi32> + %empty = tensor.empty() : tensor<1x16x1080x1920xi32> + %unpack = tensor.unpack %arg0 + inner_dims_pos = [1] + inner_tiles = [16] + into %empty : 
tensor<1x1x1080x1920x16xi32> -> tensor<1x16x1080x1920xi32> + %pool = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} + ins(%unpack, %filter : tensor<1x16x1080x1920xi32>, tensor<2x2xi32>) + outs(%init : tensor<16x540x960xi32>) { + ^bb0(%in: i32, %in_1: i32, %out: i32): + %max = arith.maxui %in, %out : i32 + linalg.yield %max : i32 + } -> tensor<16x540x960xi32> + return %pool : tensor<16x540x960xi32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5, d6)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d3, d6)> +// CHECK: func.func @unpack_different_destination_shape +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[FILTER:.+]] = tensor.empty() : tensor<2x2xi32> +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32> +// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32> +// CHECK: %[[PACK_ARG0:.+]] = tensor.pack +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16] +// CHECK-SAME: into %[[PACK_EMPTY]] +// CHECK: %[[POOL:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] +// CHECK-SAME: ins(%[[PACK_ARG0]], %[[FILTER]] +// CHECK-SAME: outs(%[[INIT]] +// CHECK: %[[UNPACK_NEW_DEST:.+]] = tensor.empty() : tensor<16x540x960xi32> +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[POOL]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16] +// CHECK-SAME: into %[[UNPACK_NEW_DEST]] +// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32> diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp --- 
a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp @@ -32,7 +32,8 @@ void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - linalg::populateDataLayoutPropagationPatterns(patterns); + linalg::populateDataLayoutPropagationPatterns( + patterns, [](Operation *op) { return true; }); if (failed( applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure();