diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2788,6 +2788,16 @@
                                     ArrayRef<OpFoldResult> staticOffsets,
                                     ArrayRef<OpFoldResult> staticSizes,
                                     ArrayRef<OpFoldResult> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           MemRefType sourceMemRefType,
+                                           ArrayRef<int64_t> staticOffsets,
+                                           ArrayRef<int64_t> staticSizes,
+                                           ArrayRef<int64_t> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           MemRefType sourceMemRefType,
+                                           ArrayRef<OpFoldResult> staticOffsets,
+                                           ArrayRef<OpFoldResult> staticSizes,
+                                           ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -2914,6 +2924,16 @@
                                     ArrayRef<OpFoldResult> staticOffsets,
                                     ArrayRef<OpFoldResult> staticSizes,
                                     ArrayRef<OpFoldResult> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           RankedTensorType sourceRankedTensorType,
+                                           ArrayRef<int64_t> staticOffsets,
+                                           ArrayRef<int64_t> staticSizes,
+                                           ArrayRef<int64_t> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           RankedTensorType sourceRankedTensorType,
+                                           ArrayRef<OpFoldResult> staticOffsets,
+                                           ArrayRef<OpFoldResult> staticSizes,
+                                           ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -3027,6 +3047,7 @@
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
   }];
+  let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2891,8 +2891,68 @@
   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                              staticStrides, ShapedType::kDynamicStrideOrOffset);
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
-                                    staticSizes, staticStrides)
-      .cast<MemRefType>();
+                                    staticSizes, staticStrides);
+}
+
+static void
+getPositionsOfShapeOne(unsigned rank, ArrayRef<int64_t> shape,
+                       llvm::SmallDenseSet<unsigned> &dimsToProject) {
+  dimsToProject.reserve(rank);
+  for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
+    if (shape[pos] == 1) {
+      dimsToProject.insert(pos);
+      --rank;
+    }
+  }
+}
+
+Type SubViewOp::inferRankReducedResultType(
+    unsigned resultRank, MemRefType sourceRankedTensorType,
+    ArrayRef<int64_t> leadingStaticOffsets,
+    ArrayRef<int64_t> leadingStaticSizes,
+    ArrayRef<int64_t> leadingStaticStrides) {
+  auto inferredType =
+      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
+                      leadingStaticSizes, leadingStaticStrides)
+          .cast<MemRefType>();
+  assert(inferredType.getRank() >= resultRank && "expected ");
+  int rankDiff = inferredType.getRank() - resultRank;
+  if (rankDiff > 0) {
+    auto shape = inferredType.getShape();
+    llvm::SmallDenseSet<unsigned> dimsToProject;
+    getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
+    SmallVector<int64_t> projectedShape;
+    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
+      if (!dimsToProject.contains(pos))
+        projectedShape.push_back(shape[pos]);
+
+    AffineMap map;
+    auto maps = inferredType.getAffineMaps();
+    if (!maps.empty() && maps.front())
+      map = getProjectedMap(maps.front(), dimsToProject);
+    inferredType =
+        MemRefType::get(projectedShape, inferredType.getElementType(), map,
+                        inferredType.getMemorySpace());
+  }
+  return inferredType;
+}
+
+Type SubViewOp::inferRankReducedResultType(
+    unsigned resultRank, MemRefType sourceRankedTensorType,
+    ArrayRef<OpFoldResult> leadingStaticOffsets,
+    ArrayRef<OpFoldResult> leadingStaticSizes,
+    ArrayRef<OpFoldResult> leadingStaticStrides) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  return SubViewOp::inferRankReducedResultType(
+      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+      staticStrides);
 }
 
 // Build a SubViewOp with mixed static and dynamic entries and custom result
@@ -3407,29 +3467,11 @@
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    auto resultType = SubViewOp::inferResultType(
-                          castOp.source().getType().cast<MemRefType>(),
-                          extractFromI64ArrayAttr(subViewOp.static_offsets()),
-                          extractFromI64ArrayAttr(subViewOp.static_sizes()),
-                          extractFromI64ArrayAttr(subViewOp.static_strides()))
-                          .cast<MemRefType>();
-    uint32_t rankDiff =
-        subViewOp.getSourceType().getRank() - subViewOp.getType().getRank();
-    if (rankDiff > 0) {
-      auto shape = resultType.getShape();
-      auto projectedShape = shape.drop_front(rankDiff);
-      AffineMap map;
-      auto maps = resultType.getAffineMaps();
-      if (!maps.empty() && maps.front()) {
-        auto optionalUnusedDimsMask =
-            computeRankReductionMask(shape, projectedShape);
-        llvm::SmallDenseSet<unsigned> dimsToProject =
-            optionalUnusedDimsMask.getValue();
-        map = getProjectedMap(maps.front(), dimsToProject);
-      }
-      resultType = MemRefType::get(projectedShape, resultType.getElementType(),
-                                   map, resultType.getMemorySpace());
-    }
+    auto resultType = SubViewOp::inferRankReducedResultType(
+        subViewOp.getType().getRank(),
+        castOp.source().getType().cast<MemRefType>(),
+        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
+        subViewOp.getMixedStrides());
     Value newSubView = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
@@ -3439,12 +3481,40 @@
     return success();
   }
 };
+
+// Rewrite memref_cast + subview as subview + memref_cast.
+struct SubViewFromMemRefCast : public OpRewritePattern<SubViewOp> {
+  using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubViewOp subViewOp,
+                                PatternRewriter &rewriter) const final {
+    auto memrefCastOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
+    if (!memrefCastOp || !canFoldIntoConsumerOp(memrefCastOp))
+      return failure();
+
+    // Let type inference do the work for us.
+    auto srcMemrefType = memrefCastOp.source().getType().cast<MemRefType>();
+    auto dstMemrefType = subViewOp.getType();
+    auto subviewMemRefType =
+        SubViewOp::inferRankReducedResultType(
+            dstMemrefType.getRank(), srcMemrefType, subViewOp.getMixedOffsets(),
+            subViewOp.getMixedSizes(), subViewOp.getMixedStrides())
+            .cast<MemRefType>();
+    Value newSubview = rewriter.create<SubViewOp>(
+        subViewOp.getLoc(), subviewMemRefType, memrefCastOp.source(),
+        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
+        subViewOp.getMixedStrides());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
+                                              newSubview);
+    return success();
+  }
+};
 } // namespace
 
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
   results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubViewOp>,
-                 SubViewOpMemRefCastFolder>(context);
+                 SubViewFromMemRefCast>(context);
 }
 
 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
@@ -3492,8 +3562,52 @@
   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                              staticStrides, ShapedType::kDynamicStrideOrOffset);
   return SubTensorOp::inferResultType(sourceRankedTensorType, staticOffsets,
-                                      staticSizes, staticStrides)
-      .cast<RankedTensorType>();
+                                      staticSizes, staticStrides);
+}
+
+/// A subtensor result type can be fully inferred from the source type and the
+/// static representation of offsets, sizes and strides. Special sentinels
+/// encode the dynamic case.
+Type SubTensorOp::inferRankReducedResultType(
+    unsigned resultRank, RankedTensorType sourceRankedTensorType,
+    ArrayRef<int64_t> leadingStaticOffsets,
+    ArrayRef<int64_t> leadingStaticSizes,
+    ArrayRef<int64_t> leadingStaticStrides) {
+  auto inferredType =
+      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
+                      leadingStaticSizes, leadingStaticStrides)
+          .cast<RankedTensorType>();
+  int rankDiff = inferredType.getRank() - resultRank;
+  if (rankDiff > 0) {
+    auto shape = inferredType.getShape();
+    llvm::SmallDenseSet<unsigned> dimsToProject;
+    getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
+    SmallVector<int64_t> projectedShape;
+    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
+      if (!dimsToProject.contains(pos))
+        projectedShape.push_back(shape[pos]);
+    inferredType =
+        RankedTensorType::get(projectedShape, inferredType.getElementType());
+  }
+  return inferredType;
+}
+
+Type SubTensorOp::inferRankReducedResultType(
+    unsigned resultRank, RankedTensorType sourceRankedTensorType,
+    ArrayRef<OpFoldResult> leadingStaticOffsets,
+    ArrayRef<OpFoldResult> leadingStaticSizes,
+    ArrayRef<OpFoldResult> leadingStaticStrides) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  return SubTensorOp::inferRankReducedResultType(
+      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+      staticStrides);
 }
 
 // Build a SubTensorOp with mixed static and dynamic entries and custom result
@@ -3571,11 +3685,67 @@
   return produceSubViewErrorMsg(result, op, expectedType);
 }
 
+namespace {
+/// Pattern to rewrite a subtensor op with tensor::CastOp arguments.
+/// This essentially pushes a tensor.cast past its consuming subtensor when
+/// `canFoldIntoConsumerOp` is true.
+///
+/// Example:
+/// ```
+///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
+///   %1 = subtensor %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///   %0 = subtensor %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
+///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<?x?xf32>
+/// ```
+class SubTensorOpCastFolder final : public OpRewritePattern<SubTensorOp> {
+public:
+  using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Any constant operand, just return to let the constant folding pattern
+    // kick in.
+    if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure();
+
+    auto castOp = subTensorOp.source().getDefiningOp<tensor::CastOp>();
+    if (!castOp)
+      return failure();
+
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    /// Deduce the resultType of SubTensorOp with `inferRankReducedResultType`
+    /// on the cast source operand type and the SubTensorOp static information.
+    /// This is the resulting type if the tensor::CastOp were folded and
+    /// rank-reduced to the desired result rank.
+    auto resultType = SubTensorOp::inferRankReducedResultType(
+        subTensorOp.getType().getRank(),
+        castOp.source().getType().cast<RankedTensorType>(),
+        subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
+        subTensorOp.getMixedStrides());
+    Value newSubTensor = rewriter.create<SubTensorOp>(
+        subTensorOp.getLoc(), resultType, castOp.source(),
+        subTensorOp.offsets(), subTensorOp.sizes(), subTensorOp.strides(),
+        subTensorOp.static_offsets(), subTensorOp.static_sizes(),
+        subTensorOp.static_strides());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        subTensorOp, subTensorOp.getType(), newSubTensor);
+    return success();
+  }
+};
+} // namespace
+
 void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results
-      .insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>>(
-          context);
+  results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>,
+                 SubTensorOpCastFolder>(context);
 }
 
 //
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -146,13 +146,13 @@
 // CHECK-LABEL: func @subview_of_memcast
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-// CHECK: %[[S:.+]] = subview %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+// CHECK: %[[S:.+]] = subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
 // CHECK: %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
 // CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
 func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
   memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
   %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-  %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
+  %1 = subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
     memref<?x?x16x32xi8> to
     memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
   return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
@@ -176,3 +176,14 @@
   return %0 : tensor<4x6x16x32xi8>
 }
 
+// CHECK-LABEL: func @rank_reducing_tensor_of_cast
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK: %[[S:.+]] = subtensor %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
+// Tensor cast is moved after subtensor and then gets canonicalized away.
+// CHECK-NOT: tensor.cast
+// CHECK: return %[[S]] : tensor<16x32xi8>
+func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> {
+  %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = subtensor %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
+  return %1 : tensor<16x32xi8>
+}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1034,8 +1034,8 @@
 
 // CHECK-LABEL: func @subtensor
 // CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
-func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
-  -> tensor<?x?x?xf32>
+func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
+  -> tensor<?x?x?xf32>
 {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -1045,16 +1045,18 @@
   // CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
   // CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
-  // CHECK: tensor.cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
+  // tensor.cast gets folded away in consumer.
+  // CHECK-NOT: tensor.cast
   %1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
     : tensor<8x16x4xf32> to tensor<?x?x?xf32>
 
   // Test: subtensor with one dynamic operand can also be folded.
   // CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
-  // CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
+  // CHECK-SAME: tensor<7x11x2xf32> to tensor<2x?x2xf32>
   // CHECK: tensor.cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
   %2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
     : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 
   return %2 : tensor<?x?x?xf32>
 }
+
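
The combined effect of the rank-reduced type inference and the new SubTensorOpCastFolder pattern can be summarized with one more small sketch in the same FileCheck style as the tests above. This example is illustrative only and not part of the patch: the function name @rank_reducing_subtensor_of_cast_sketch and the shapes are invented here, and it assumes the subtensor/tensor.cast syntax used elsewhere in this diff.

// Hypothetical example (not in the patch's test suite): the tensor.cast that
// erases the static shape is expected to fold into the rank-reducing
// subtensor. inferRankReducedResultType recomputes tensor<1x16x4xf32> from the
// static source and drops the unit dimension, giving tensor<16x4xf32>; since
// that matches the original result type, the trailing tensor.cast inserted by
// the pattern is trivial and should canonicalize away.
// CHECK-LABEL: func @rank_reducing_subtensor_of_cast_sketch
// CHECK: %[[S:.+]] = subtensor %arg0[0, 0, 0] [1, 16, 4] [1, 1, 1] : tensor<8x16x4xf32> to tensor<16x4xf32>
// CHECK-NOT: tensor.cast
// CHECK: return %[[S]] : tensor<16x4xf32>
func @rank_reducing_subtensor_of_cast_sketch(%arg : tensor<8x16x4xf32>) -> tensor<16x4xf32> {
  %0 = tensor.cast %arg : tensor<8x16x4xf32> to tensor<?x16x4xf32>
  %1 = subtensor %0[0, 0, 0] [1, 16, 4] [1, 1, 1] : tensor<?x16x4xf32> to tensor<16x4xf32>
  return %1 : tensor<16x4xf32>
}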