diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3467,29 +3467,11 @@
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    auto resultType = SubViewOp::inferResultType(
-                          castOp.source().getType().cast<MemRefType>(),
-                          extractFromI64ArrayAttr(subViewOp.static_offsets()),
-                          extractFromI64ArrayAttr(subViewOp.static_sizes()),
-                          extractFromI64ArrayAttr(subViewOp.static_strides()))
-                          .cast<MemRefType>();
-    uint32_t rankDiff =
-        subViewOp.getSourceType().getRank() - subViewOp.getType().getRank();
-    if (rankDiff > 0) {
-      auto shape = resultType.getShape();
-      auto projectedShape = shape.drop_front(rankDiff);
-      AffineMap map;
-      auto maps = resultType.getAffineMaps();
-      if (!maps.empty() && maps.front()) {
-        auto optionalUnusedDimsMask =
-            computeRankReductionMask(shape, projectedShape);
-        llvm::SmallDenseSet<unsigned> dimsToProject =
-            optionalUnusedDimsMask.getValue();
-        map = getProjectedMap(maps.front(), dimsToProject);
-      }
-      resultType = MemRefType::get(projectedShape, resultType.getElementType(),
-                                   map, resultType.getMemorySpace());
-    }
+    auto resultType = SubViewOp::inferRankReducedResultType(
+        subViewOp.getType().getRank(),
+        castOp.source().getType().cast<MemRefType>(),
+        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
+        subViewOp.getMixedStrides());
     Value newSubView = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
@@ -3703,11 +3685,67 @@
   return produceSubViewErrorMsg(result, op, expectedType);
 }
 
+namespace {
+/// Pattern to rewrite a subtensor op with tensor::CastOp arguments.
+/// This essentially pushes the tensor.cast past its consuming subtensor when
+/// `canFoldIntoConsumerOp` is true.
+///
+/// Example:
+/// ```
+///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
+///   %1 = subtensor %0[0, 0][3, 4][1, 1] :
+///     tensor<?x?xf32> to tensor<3x4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///   %0 = subtensor %V[0, 0][3, 4][1, 1] :
+///     tensor<16x16xf32> to tensor<3x4xf32>
+///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
+/// ```
+class SubTensorOpCastFolder final : public OpRewritePattern<SubTensorOp> {
+public:
+  using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Any constant operand, just return to let the constant folder kick in.
+    if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure();
+
+    auto castOp = subTensorOp.source().getDefiningOp<tensor::CastOp>();
+    if (!castOp)
+      return failure();
+
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    /// Deduce the resultType of SubTensorOp with `inferRankReducedResultType`
+    /// on the cast source operand type and the SubTensorOp static information.
+    /// This is the resulting type if the tensor::CastOp were folded and
+    /// rank-reduced to the desired result rank.
+    auto resultType = SubTensorOp::inferRankReducedResultType(
+        subTensorOp.getType().getRank(),
+        castOp.source().getType().cast<RankedTensorType>(),
+        subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
+        subTensorOp.getMixedStrides());
+    Value newSubTensor = rewriter.create<SubTensorOp>(
+        subTensorOp.getLoc(), resultType, castOp.source(),
+        subTensorOp.offsets(), subTensorOp.sizes(), subTensorOp.strides(),
+        subTensorOp.static_offsets(), subTensorOp.static_sizes(),
+        subTensorOp.static_strides());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        subTensorOp, subTensorOp.getType(), newSubTensor);
+    return success();
+  }
+};
+} // namespace
+
 void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results
-      .insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>>(
-          context);
+  results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>,
+                 SubTensorOpCastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -146,13 +146,13 @@
 // CHECK-LABEL: func @subview_of_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-//       CHECK:   %[[S:.+]] = subview %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   %[[S:.+]] = subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
 //       CHECK:   %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
 //       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
 func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
   memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
   %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-  %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
+  %1 = subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
     memref<?x?x16x32xi8> to
     memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
   return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
@@ -176,16 +176,14 @@
   return %0 : tensor<4x6x16x32xi8>
 }
 
-// CHECK-LABEL: func @rank_reducing_subview_of_memcast
-//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-//       CHECK:   %[[S:.+]] = subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
-//       CHECK:   %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
-//       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
-func @rank_reducing_subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
-  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
-  %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-  %1 = subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
-    memref<?x?x16x32xi8> to
-    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
-  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+// CHECK-LABEL: func @rank_reducing_tensor_of_cast
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+//       CHECK:   %[[S:.+]] = subtensor %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
+// Tensor cast is moved after subtensor and then gets canonicalized away.
+//   CHECK-NOT: tensor.cast
+//       CHECK:   return %[[S]] : tensor<16x32xi8>
+func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> {
+  %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = subtensor %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
+  return %1 : tensor<16x32xi8>
 }
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1034,8 +1034,8 @@
 // CHECK-LABEL: func @subtensor
 //  CHECK-SAME:   %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
-func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index) 
-  -> tensor<?x?x?xf32>
+func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
+  -> tensor<?x?x?xf32>
 {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -1045,16 +1045,18 @@
   // CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
   // CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
-  // CHECK: tensor.cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
+  // tensor.cast gets folded away in the consumer.
+  //  CHECK-NOT: tensor.cast
   %1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
     : tensor<8x16x4xf32> to tensor<?x?x?xf32>
 
   // Test: subtensor with one dynamic operand can also be folded.
   // CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
-  // CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
+  // CHECK-SAME: tensor<7x11x2xf32> to tensor<2x?x2xf32>
   // CHECK: tensor.cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
   %2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
     : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 
   return %2 : tensor<?x?x?xf32>
 }
+
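
Note on the CHECK-NOT lines in the updated tests: SubTensorOpCastFolder does not erase the cast by itself. It recreates the subtensor on the cast's source with the rank-reduced inferred type and then emits a tensor.cast back to the original result type; when that cast is a no-op it folds away in a later canonicalization step. A worked sketch of the intermediate IR for @rank_reducing_tensor_of_cast (illustrative only; the %cast name is made up and the intermediate state is never checked by the test):

  // Input:
  %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = subtensor %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>

  // After SubTensorOpCastFolder (cast pushed below the rank-reduced subtensor):
  %1 = subtensor %arg[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
  %cast = tensor.cast %1 : tensor<16x32xi8> to tensor<16x32xi8>

  // After the no-op tensor.cast folds, matching the CHECK / CHECK-NOT lines:
  %1 = subtensor %arg[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>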