diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -46,19 +46,22 @@
   SmallVector<int64_t> outerDimsOnDomainPerm;
 };
 
-static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap,
-                                           tensor::PackOp packOp) {
+template <typename OpTy>
+static PackInfo getPackingInfoFromOperand(AffineMap indexingMap,
+                                          OpTy packOrUnPackOp) {
+  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
+                "applies to only pack or unpack operations");
   LLVM_DEBUG(
-      { llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; });
+      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
   PackInfo packInfo;
   int64_t origNumDims = indexingMap.getNumDims();
   SmallVector<AffineExpr> exprs(indexingMap.getResults());
-  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
   for (auto [index, innerDimPos, tileSize] :
        llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
-                       innerDimsPos, packOp.getMixedTiles())) {
+                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
     int64_t domainDimPos =
-        exprs[innerDimPos].cast<AffineDimExpr>().getPosition();
+        exprs[innerDimPos].template cast<AffineDimExpr>().getPosition();
     packInfo.tiledDimsPos.push_back(domainDimPos);
     packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
     packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
@@ -71,7 +74,7 @@
     });
   }
 
-  for (auto dim : packOp.getOuterDimsPerm())
+  for (auto dim : packOrUnPackOp.getOuterDimsPerm())
     packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
   if (!packInfo.outerDimsOnDomainPerm.empty()) {
     LLVM_DEBUG({
@@ -209,6 +212,35 @@
   return std::make_tuple(packedOperand, indexingMap);
 }
 
+/// Pack an element-wise genericOp and return it.
+static GenericOp packElementWiseOp(RewriterBase &rewriter, GenericOp genericOp,
+                                   Value dest, AffineMap packedOutIndexingMap,
+                                   const PackInfo &packInfo) {
+  Location loc = genericOp.getLoc();
+  SmallVector<Value> inputOperands;
+  SmallVector<AffineMap> indexingMaps;
+  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
+        rewriter, loc, packInfo, genericOp, inputOperand);
+    inputOperands.push_back(packedOperand);
+    indexingMaps.push_back(packedIndexingMap);
+  }
+
+  int64_t numInnerLoops = packInfo.getNumTiledLoops();
+  SmallVector<utils::IteratorType> iterTypes =
+      genericOp.getIteratorTypesArray();
+  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
+
+  indexingMaps.push_back(packedOutIndexingMap);
+
+  auto newGenericOp = rewriter.create<linalg::GenericOp>(
+      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+  return newGenericOp;
+}
+
 /// Bubbles up tensor.pack op through elementwise generic op. This
 /// swap pack(generic) to generic(pack). The new generic op works on packed
 /// domain; pack ops are created for input and output operands. E.g.,
@@ -275,29 +307,13 @@
     return failure();
 
   OpOperand *opOperand = genericOp.getDpsInitOperand(0);
-  auto packInfo = getPackingInfoFromConsumer(
+  auto packInfo = getPackingInfoFromOperand(
       genericOp.getMatchingIndexingMap(opOperand), packOp);
 
-  Location loc = packOp.getLoc();
-  SmallVector<Value> inputOperands;
-  SmallVector<AffineMap> indexingMaps;
-  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
-    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
-        rewriter, loc, packInfo, genericOp, inputOperand);
-    inputOperands.push_back(packedOperand);
-    indexingMaps.push_back(packedIndexingMap);
-  }
-
-  int64_t numInnerLoops = packInfo.getNumTiledLoops();
-  SmallVector<utils::IteratorType> iterTypes =
-      genericOp.getIteratorTypesArray();
-  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
-
   // Rebuild the indexing map for the corresponding init operand.
   auto [packedOutOperand, packedOutIndexingMap] =
-      getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
-                                     opOperand);
-  indexingMaps.push_back(packedOutIndexingMap);
+      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo,
+                                     genericOp, opOperand);
 
   // We'll replace the init operand with the destination of pack op if the init
   // operand has not users in the body of the linalg.generic (pure elementwise).
@@ -306,15 +322,12 @@
   Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
                    ? packOp.getDest()
                    : packedOutOperand;
-  auto newGenericOp = rewriter.create<linalg::GenericOp>(
-      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
-      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
-  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
-                             newGenericOp.getRegion().begin());
-  return newGenericOp;
+
+  return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap,
+                           packInfo);
 }
 
-// Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method.
+/// Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method.
 struct BubbleUpPackOpThroughElemGenericOpPattern
     : public OpRewritePattern<tensor::PackOp> {
   using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
@@ -328,10 +341,134 @@
     return success();
   }
 };
+
+// TODO: Relax this restriction. We should also be able to push down the
+// unpack through an elementwise op with multiple unpack ops as producers.
+/// Return the unpacked operand, if present, for the current generic op.
+static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
+  OpOperand *unPackedOperand = nullptr;
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
+    if (!unPackOp)
+      continue;
+    if (unPackedOperand)
+      return failure();
+    unPackedOperand = &operand;
+  }
+  if (!unPackedOperand)
+    return failure();
+  return unPackedOperand;
+}
+
+/// Push down a tensor.unpack op through an elementwise generic op.
+/// The new generic op works on the packed domain; pack ops are created for
+/// the input and output operands. A tensor.unpack op is inserted right after
+/// the packed generic. E.g.
+///
+/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+///
+/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
+///
+/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
+/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
+///                          inner_dims_pos = [3] inner_tiles = [32] into %0
+/// %2 = linalg.generic {indexing_maps = [#map],
+///      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+///      outs(%1 : tensor<12x56x56x64xf32>) {
+///      ^bb0(%out : f32):
+///        linalg.yield %out : f32
+///      } -> tensor<12x56x56x64xf32>
+///
+/// will be converted to
+///
+/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+///
+/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
+/// %1 = linalg.generic {indexing_maps = [#map],
+///      iterator_types = ["parallel", "parallel", "parallel",
+///                        "parallel", "parallel"]}
+///      outs(%arg0 : tensor<12x2x56x56x32xf32>) {
+///      ^bb0(%out : f32):
+///        linalg.yield %out : f32
+///      } -> tensor<12x2x56x56x32xf32>
+/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
+///                       inner_dims_pos = [3] inner_tiles = [32] into %0
+///
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownUnPackOpThroughElemGenericOp(RewriterBase &rewriter,
+                                     GenericOp genericOp) {
+  if (!isElementwise(genericOp))
+    return failure();
+  if (genericOp.getNumResults() != 1)
+    return failure();
+
+  // Collect the unPacked operand, if present.
+  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
+  if (failed(maybeUnPackedOperand))
+    return failure();
+  OpOperand *unPackedOperand = *(maybeUnPackedOperand);
+
+  // Extract packing information.
+  tensor::UnPackOp producerUnPackOp =
+      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
+  assert(producerUnPackOp && "expect a valid UnPackOp");
+  auto packInfo = getPackingInfoFromOperand(
+      genericOp.getMatchingIndexingMap(unPackedOperand), producerUnPackOp);
+
+  // Rebuild the indexing map for the corresponding init operand.
+  auto [packedOutOperand, packedOutIndexingMap] =
+      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo,
+                                     genericOp, genericOp.getDpsInitOperand(0));
+
+  // If the dps init operand of the generic is a tensor.empty, do not pack it
+  // and forward the new tensor.empty as a destination.
+  Value dest = packedOutOperand;
+  if (auto initTensor = genericOp.getDpsInitOperand(0)
+                            ->get()
+                            .getDefiningOp<tensor::EmptyOp>()) {
+    if (auto packOp = packedOutOperand.getDefiningOp<tensor::PackOp>())
+      dest = packOp.getDest();
+  }
+
+  // Pack the genericOp.
+  GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest,
+                                             packedOutIndexingMap, packInfo);
+
+  auto unPackOp = unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
+  // Insert an unPackOp right after the packed generic.
+  Value unPackOpRes =
+      rewriter
+          .create<tensor::UnPackOp>(
+              genericOp.getLoc(),
+              newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+              unPackOp.getDest(), producerUnPackOp.getInnerDimsPos(),
+              producerUnPackOp.getMixedTiles(),
+              producerUnPackOp.getOuterDimsPerm())
+          .getResult();
+
+  return std::make_tuple(newGenericOp, unPackOpRes);
+}
+
+// Wrapper pattern that applies pushDownUnPackOpThroughElemGenericOp method.
+struct PushDownUnPackOpThroughElemGenericOp
+    : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownUnPackOpThroughElemGenericOp(rewriter, genericOp);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern>(
-      patterns.getContext());
+  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern,
+                  PushDownUnPackOpThroughElemGenericOp>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -352,15 +352,123 @@
 // CHECK: func.func @elem_pack_transpose_outer_dims
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[ARG0_EMPTY]]
 // CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
 // CHECK: %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
 // CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
 // CHECK-SAME: into %[[ARG1_EMPTY]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
 // CHECK-SAME: ins(%[[PACKED_ARG0]]
 // CHECK-SAME: outs(%[[PACKED_ARG1]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%1 : tensor<12x56x56x64xf32>) {
+  ^bb0(%out: f32):
+    %3 = arith.addf %out, %out : f32
+    linalg.yield %3 : f32
+  } -> tensor<12x56x56x64xf32>
+  return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @unpack_on_output
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
+// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_EMPTY_PACK]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: outs(%[[PACKED_ARG0]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> {
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %3 = arith.addf %in, %out : f32
+    linalg.yield %3 : f32
+  } -> tensor<12x56x56x64xf32>
+  return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @unpack_on_input
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
+// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: ins(%[[ARG0_PACK]]
+// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
+  %init = tensor.empty() : tensor<12x56x56x64xf32>
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %3 = arith.addf %in, %in : f32
+    linalg.yield %3 : f32
+  } -> tensor<12x56x56x64xf32>
+  return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @forward_tensor_empty
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
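For reference, a minimal sketch of how these patterns can be driven, not part of the patch: only populateDataLayoutPropagationPatterns comes from the change above, while the surrounding helper name and the use of the greedy rewrite driver are illustrative assumptions.

// Sketch (assumption, not part of the patch): register the propagation
// patterns and apply them with the greedy pattern rewrite driver.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult runDataLayoutPropagation(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Registers BubbleUpPackOpThroughElemGenericOpPattern and the new
  // PushDownUnPackOpThroughElemGenericOp pattern added by this patch.
  mlir::linalg::populateDataLayoutPropagationPatterns(patterns);
  // Apply to a fixed point; returns failure if the driver did not converge.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}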