Review notes (free-form preamble; git apply / patch skip any text that
precedes the first "diff --git" header).

Summary of the change, as shown by the hunks below:
  * GeneralizeOuterUnitDimsPackOpPattern / ...UnPackOpPattern now only require
    the *tiled* outer dims (those named by inner_dims_pos) to be 1, instead of
    all outer dims. Untiled outer dims are extracted in full (static size, or a
    dim op for dynamic sizes). Non-unit untiled dims are kept in the
    rank-reduced extract_slice type.
  * getPackUnpackNormalizedInnerPerm is renamed getPackUnpackNormalizedPerm.
    The invertPermutationVector call that callers used to apply afterwards now
    lives inside the helper.
  * New helper getPackUnpackRankReducedPerm computes the linalg.transpose
    permutation for the rank-reduced tile and also consumes outer_dims_perm.
    The unpack pattern inverts its result.
  * Pack: insert_slice write sizes now come from
    tensor::getMixedSizes(packOp.getDest()).
  * New lit tests: pack with outer_dims_perm = [0, 2, 3, 1]; unpack with a
    non-unit untiled outer dim; unpack with outer_dims_perm and no tiles.

NOTE(review): in this copy the patch's newlines are collapsed and template
argument lists appear to be missing (e.g. `SmallVector vec(rank, ...)`,
`rewriter.create(loc, input, i)`, `tile.is()`), so it cannot be applied or
compiled as-is. Confirm against the original patch before acting on it.

NOTE(review): both new match checks index the outer dims with inner_dims_pos
directly:
  * pack: `destShape[index]`
  * unpack: `srcShape[index]`
My reading of the outer_dims_perm semantics (confirm against the tensor.pack
docs) is that the outer dim for unpacked dim `index` sits at the position j
where outer_dims_perm[j] == index, so these checks can inspect the wrong dim.
Example: a pack of tensor<1x16xf32> with inner_dims_pos = [1],
inner_tiles = [8], outer_dims_perm = [1, 0] gives a 2x1x8 dest. `destShape[1]`
is 1, so the check would pass, although the tiled outer dim (dest position 0)
is 2. The new tests do not catch this because all their tiled outer dims
are 1.

NOTE(review): getPackUnpackRankReducedPerm pushes `outerDimsPerm[i]`, which
indexes outer_dims_perm by the unpacked dim i rather than collecting the
permutation entries that refer to kept (non-unit, untiled) dims. This agrees
with the expected result when no unit dim is dropped. When a unit dim is
dropped, the permutation looks wrong. Hand trace of a pack of
tensor<1x3x4x8xf32> with inner_dims_pos = [3], inner_tiles = [8],
outer_dims_perm = [2, 0, 1, 3] (please confirm):
  * The helper returns [0, 1, 2].
  * The dest tensor<4x1x3x1x8xf32> needs [1, 0, 2].
  * The resulting insert_slice of a 3x4x8 tile into sizes [4, 1, 3, 1, 8]
    would not type-check.

NOTE(review): in the unpack pattern the read loop iterates i over
[0, destRank). It tests `dimAndTileMapping.count(i)` (a dest dim) but reads
`srcShape[i]` (a source outer position). These coincide only when
outer_dims_perm is empty or identity. The same mixing shows up when
`srcShape.take_front(destRank)` is passed together with inner_dims_pos to
getPackUnpackRankReducedPerm.

NOTE(review): the unpack failure message reads "...of the result are all 1s",
but that check inspects the source (packed) shape.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1241,66 +1241,124 @@ /*nofold=*/false, loc, builder); } +// Normalizes a permutation on a higher rank space to its actual size, e.g. +// perm = [1, 4, 2] +// becomes +// norm = [0, 2, 1] static SmallVector -getPackUnpackNormalizedInnerPerm(int rank, ArrayRef innerDimsPos) { +getPackUnpackNormalizedPerm(int rank, ArrayRef perm) { constexpr int64_t kNonTiledMarker = -1; SmallVector vec(rank, kNonTiledMarker); - for (auto [index, value] : llvm::enumerate(innerDimsPos)) + for (auto [index, value] : llvm::enumerate(perm)) vec[value] = index; - SmallVector perm = llvm::to_vector(llvm::make_filter_range( + SmallVector normalizedPerm = llvm::to_vector(llvm::make_filter_range( vec, [&](int64_t v) { return v != kNonTiledMarker; })); + // This inverts the permutation in addition to normalizing so invert back. + return invertPermutationVector(normalizedPerm); +} + +// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm +// assuming rank reduction of unit outer dims. +static SmallVector +getPackUnpackRankReducedPerm(ArrayRef shape, + ArrayRef innerDimsPos, + ArrayRef outerDimsPerm) { + SmallVector rankReducedOuterDimsPerm; + SmallVector outerDims; + SmallVector innerDims; + int64_t dim = 0; + int64_t unpackedRank = shape.size(); + for (auto i : llvm::seq(0, unpackedRank)) { + if (llvm::is_contained(innerDimsPos, i)) { + innerDims.push_back(dim++); + continue; + } + if (shape[i] == 1) + continue; + outerDims.push_back(dim++); + if (!outerDimsPerm.empty()) + rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]); + } + + // Get the position of the inner dims after permutation.
+ SmallVector innerPerm = + getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos); + applyPermutationToVector(innerDims, innerPerm); + + // Ditto for the outer dims. + SmallVector perm = outerDims; + + rankReducedOuterDimsPerm = + getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm); + if (!rankReducedOuterDimsPerm.empty()) + applyPermutationToVector(perm, rankReducedOuterDimsPerm); + + // The tile always ends up as the inner most dims after packing. + perm.append(innerDims); + return perm; } LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( tensor::PackOp packOp, PatternRewriter &rewriter) const { - // TODO: support the case that outer dimensions are not all 1s A - // tensor.expand_shape will be generated in this case. - int64_t srcRank = packOp.getSourceRank(); - if (llvm::any_of(packOp.getDestType().getShape().take_front(srcRank), - [](int64_t val) { return val != 1; })) { - return rewriter.notifyMatchFailure( - packOp, "require the outer dimension of the result are all 1s"); - } - if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) { return tile.is(); })) { return rewriter.notifyMatchFailure(packOp, "require inner tile sizes being static"); } - // 1. Use rank-reduced tensor.extract_slice op to extract the tile. + // TODO: support the case that outer dimensions are not all 1s. A + // tensor.expand_shape will be generated in this case. + auto innerDimsPos = packOp.getInnerDimsPos(); + int64_t srcRank = packOp.getSourceRank(); + auto destShape = packOp.getDestType().getShape(); + if (llvm::any_of(innerDimsPos, [destShape](int64_t index) { + return destShape[index] != 1; + })) { + return rewriter.notifyMatchFailure( + packOp, "require the tiled outer dimensions of the result are all 1s"); + } + + // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled + // outer dims.
Location loc = packOp.getLoc(); + Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); + auto inputShape = packOp.getSourceType().getShape(); + DenseMap dimAndTileMapping = + packOp.getDimAndTileMapping(); Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); SmallVector readOffsets(srcRank, zeroIdxAttr); SmallVector readStrides(srcRank, oneIdxAttr); SmallVector readSizes; SmallVector readShape; - DenseMap dimAndTileMapping = - packOp.getDimAndTileMapping(); for (auto i : llvm::seq(0, srcRank)) { - if (!dimAndTileMapping.count(i)) { - readSizes.push_back(oneIdxAttr); + if (dimAndTileMapping.count(i)) { + readShape.push_back(getConstantIntValue(dimAndTileMapping[i]) + .value_or(ShapedType::kDynamic)); + readSizes.push_back(dimAndTileMapping[i]); continue; } - readSizes.push_back(dimAndTileMapping[i]); - readShape.push_back(getConstantIntValue(dimAndTileMapping[i]) - .value_or(ShapedType::kDynamic)); + if (ShapedType::isDynamic(inputShape[i])) { + readSizes.push_back( + rewriter.create(loc, input, i).getResult()); + } else { + readSizes.push_back(rewriter.getIndexAttr(inputShape[i])); + } + if (inputShape[i] != 1) + readShape.push_back(inputShape[i]); } + Type elemType = packOp.getSourceType().getElementType(); auto readType = RankedTensorType::get(readShape, elemType); - Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); Value tile = rewriter.create( loc, readType, input, readOffsets, readSizes, readStrides); // 2. Transpose the tile to match the inner tile order. - SmallVector perm = - getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos()); - // The permutation is inverted when normalizing so invert back to match the - // ordering in the pack op.
- perm = invertPermutationVector(perm); + + SmallVector perm = getPackUnpackRankReducedPerm( + inputShape, innerDimsPos, packOp.getOuterDimsPerm()); LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"; llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL();); @@ -1316,9 +1374,8 @@ int64_t destRank = packOp.getDestRank(); SmallVector writeStrides(destRank, oneIdxAttr); SmallVector writeOffsets(destRank, zeroIdxAttr); - SmallVector writeSizes(srcRank, oneIdxAttr); - for (auto size : transpShape) - writeSizes.push_back(rewriter.getIndexAttr(size)); + SmallVector writeSizes = + tensor::getMixedSizes(rewriter, loc, packOp.getDest()); auto insert = rewriter.create( loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, @@ -1333,35 +1390,59 @@ int64_t srcRank = unpackOp.getSourceRank(); int64_t destRank = unpackOp.getDestRank(); ArrayRef srcShape = unpackOp.getSourceType().getShape(); - if (llvm::any_of(srcShape.take_front(destRank), - [](int64_t val) { return val != 1; })) { + ArrayRef innerDimsPos = unpackOp.getInnerDimsPos(); + if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) { + return srcShape[index] != 1; + })) { return rewriter.notifyMatchFailure( - unpackOp, "require the outer dimension of the result are all 1s"); + unpackOp, + "require the tiled outer dimensions of the result are all 1s"); } // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
Location loc = unpackOp.getLoc(); + Value source = unpackOp.getSource(); + DenseMap dimAndTileMapping = + unpackOp.getDimAndTileMapping(); Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); SmallVector readOffsets(srcRank, zeroIdxAttr); SmallVector readStrides(srcRank, oneIdxAttr); + SmallVector readSizes; + SmallVector readShape; + for (auto i : llvm::seq(0, destRank)) { + if (dimAndTileMapping.count(i)) { + readSizes.push_back(oneIdxAttr); + continue; + } + if (ShapedType::isDynamic(srcShape[i])) { + readSizes.push_back( + rewriter.create(loc, source, i).getResult()); + } else { + readSizes.push_back(rewriter.getIndexAttr(srcShape[i])); + } + if (srcShape[i] != 1) + readShape.push_back(srcShape[i]); + } auto mixedTiles = unpackOp.getMixedTiles(); - SmallVector readSizes(destRank, oneIdxAttr); readSizes.append(mixedTiles.begin(), mixedTiles.end()); // Explicitly create the type for extract_slice op because the inner tile // size could be 1. We want to represent the whole inner tile in this case. - ArrayRef readShape = srcShape.drop_front(destRank); + auto tileShape = srcShape.drop_front(destRank); + // Append the inner tile shape to the permuted and rank-reduced outer shape. + readShape.append(tileShape.begin(), tileShape.end()); Type elemType = unpackOp.getSourceType().getElementType(); auto readType = RankedTensorType::get(readShape, elemType); Value innerTile = rewriter.create( loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides); // 2. Transpose the tile to match the outer corresponding tile order. - ArrayRef innerDimsPos = unpackOp.getInnerDimsPos(); - SmallVector perm = - getPackUnpackNormalizedInnerPerm(srcRank, innerDimsPos); + SmallVector perm = getPackUnpackRankReducedPerm( + srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm()); + // Unpack is a transition out of packed space so we invert the permutation.
+ perm = invertPermutationVector(perm); SmallVector transpShape(readShape); applyPermutationToVector(transpShape, perm); @@ -1375,11 +1456,13 @@ SmallVector tileStrides(numLoops, oneIdxAttr); SmallVector tileOffsets(numLoops, zeroIdxAttr); SmallVector tileSizes; - for (int dim : innerDimsPos) - tileSizes.push_back(getAsOpFoldResult( - rewriter.createOrFold(loc, unpackOp.getDest(), dim))); + ArrayRef destShape = unpackOp.getDestType().getShape(); + for (auto i : llvm::seq(0, destRank)) { + if (dimAndTileMapping.count(i) || destShape[i] != 1) + tileSizes.push_back(getAsOpFoldResult( + rewriter.createOrFold(loc, unpackOp.getDest(), i))); + } - applyPermutationToVector(tileSizes, perm); auto partialTile = rewriter.create( loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides); @@ -1387,10 +1470,8 @@ SmallVector writeSizes; SmallVector writeStrides(destRank, oneIdxAttr); SmallVector writeOffsets(destRank, zeroIdxAttr); - DenseMap dimAndTileMapping = - unpackOp.getDimAndTileMapping(); for (int i = 0, idx = 0; i < destRank; ++i) { - if (dimAndTileMapping.count(i)) + if (dimAndTileMapping.count(i) || destShape[i] != 1) writeSizes.push_back(tileSizes[idx++]); else writeSizes.push_back(oneIdxAttr); diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir --- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir +++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir @@ -76,3 +76,22 @@ // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 5, 7, 3] [1, 1, 1, 1, 1, 1] // CHECK: return %[[INSERT]] + +// ----- + +func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1: tensor<3x1x1x1x8x32xf32>) -> tensor<3x1x1x1x8x32xf32> { + %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32> + return %0 : tensor<3x1x1x1x8x32xf32>
+} +// CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [3, 1, 32, 8] [1, 1, 1, 1] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x8x32xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[TILE]] : tensor<3x32x8xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<3x8x32xf32>) +// CHECK-SAME: permutation = [0, 2, 1] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0, 0, 0] [3, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] +// CHECK: return %[[INSERT]] diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir --- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir +++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir @@ -55,3 +55,42 @@ // They have the same type, so the insert_slice op is folded // away. // CHECK: return %[[TRANSP]] + +// ----- + +func.func @simple_NCHWc_to_NCHW(%arg0: tensor<2x1x16x8x32xf32>, %arg1: tensor<2x32x16x8xf32>) -> tensor<2x32x16x8xf32> { + %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %arg1 : tensor<2x1x16x8x32xf32> -> tensor<2x32x16x8xf32> + return %0 : tensor<2x32x16x8xf32> +} +// CHECK-LABEL: func.func @simple_NCHWc_to_NCHW +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x32x16x8xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[TILE]] : tensor<2x16x8x32xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x32x16x8xf32>) +// CHECK-SAME: permutation = [0, 3, 1, 2] +// They have the same type, so the insert_slice op is folded +// away.
+// CHECK: return %[[TRANSP]] + + +// ----- + +func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x16x8xf32>) -> tensor<1x32x16x8xf32> { + %0 = tensor.unpack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [] inner_tiles = [] into %arg1 : tensor<1x16x8x32xf32> -> tensor<1x32x16x8xf32> + return %0 : tensor<1x32x16x8xf32> +} +// CHECK-LABEL: func.func @simple_NHWC_to_NCHW +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 16, 8, 32] [1, 1, 1, 1] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x8xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[TILE]] : tensor<16x8x32xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x16x8xf32>) +// CHECK-SAME: permutation = [2, 0, 1] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1] +// CHECK: return %[[INSERT]]