diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h --- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h @@ -35,7 +35,7 @@ llvm::SmallDenseSet &dimsToProject); /// Pattern to rewrite a subview op with constant arguments. -template +template class OpWithOffsetSizesAndStridesConstantArgumentFolder final : public OpRewritePattern { public: @@ -59,8 +59,12 @@ canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); // Create the new op in canonical form. - auto newOp = rewriter.create(op.getLoc(), op.source(), mixedOffsets, - mixedSizes, mixedStrides); + ResultTypeFunc resultTypeFunc; + auto resultType = + resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides); + auto newOp = + rewriter.create(op.getLoc(), resultType, op.source(), + mixedOffsets, mixedSizes, mixedStrides); CastOpFunc func; func(rewriter, op, newOp); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1859,6 +1859,26 @@ return res; } +/// Infer the canonical type of the result of a subview operation: the type +/// inferred by `inferRankReducedResultType` if it has rank `resultRank`, +/// otherwise the non-rank-reduced type computed by `inferResultType`. +static MemRefType +getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType, + ArrayRef mixedOffsets, + ArrayRef mixedSizes, + ArrayRef mixedStrides) { + auto resultType = + SubViewOp::inferRankReducedResultType( + resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides) + .cast(); + if (resultType.getRank() != resultRank) { + resultType = SubViewOp::inferResultType(sourceType, mixedOffsets, + mixedSizes, mixedStrides) + .cast(); + } + return resultType; +} + namespace { /// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when @@ -1898,7 +1918,7 @@ /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on /// the cast source operand type and the SubViewOp static information. This /// is the resulting type if the MemRefCastOp were folded. - auto resultType = SubViewOp::inferRankReducedResultType( + auto resultType = getCanonicalSubViewResultType( subViewOp.getType().getRank(), castOp.source().getType().cast(), subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), @@ -1914,6 +1934,17 @@ }; } // namespace +/// Return the canonical (possibly rank-reduced) result type of a subview. +struct SubViewReturnTypeCanonicalizer { + MemRefType operator()(SubViewOp op, ArrayRef mixedOffsets, + ArrayRef mixedSizes, + ArrayRef mixedStrides) { + return getCanonicalSubViewResultType(op.getType().getRank(), + op.getSourceType(), mixedOffsets, + mixedSizes, mixedStrides); + } +}; + /// A canonicalizer wrapper to replace SubViewOps. struct SubViewCanonicalizer { void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { @@ -1923,9 +1954,10 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - SubViewOpMemRefCastFolder>(context); + results + .add, + SubViewOpMemRefCastFolder>(context); } OpFoldResult SubViewOp::fold(ArrayRef operands) { diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp @@ -45,10 +45,13 @@ // the subview op with load even if the offsets have been canonicalized // away. SmallVector opRanges = subViewOp.getOrCreateRanges(rewriter, loc); + if (opRanges.size() != indices.size()) { + // Rank-reducing subviews are not folded yet, so bail out. Folding them + // is only possible when offset is 0, size is 1 and stride is 1.
+ return failure(); + } auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; }); auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; }); - assert(opRanges.size() == indices.size() && - "expected as many indices as rank of subview op result type"); // New indices for the load are the current indices * subview_stride + // subview_offset. diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1917,6 +1917,25 @@ return produceSubTensorErrorMsg(result, op, expectedType); } +/// Infer the canonical type of the result of a subtensor operation: the type +/// inferred by `inferRankReducedResultType` if it has rank `resultRank`, +/// otherwise the non-rank-reduced type computed by `inferResultType`. +static RankedTensorType getCanonicalSubTensorResultType( + unsigned resultRank, RankedTensorType sourceType, + ArrayRef mixedOffsets, ArrayRef mixedSizes, + ArrayRef mixedStrides) { + auto resultType = + SubTensorOp::inferRankReducedResultType( + resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides) + .cast(); + if (resultType.getRank() != resultRank) { + resultType = SubTensorOp::inferResultType(sourceType, mixedOffsets, + mixedSizes, mixedStrides) + .cast(); + } + return resultType; +} + namespace { /// Pattern to rewrite a subtensor op with tensor::Cast arguments. /// This essentially pushes memref_cast past its consuming subtensor when @@ -1955,9 +1974,11 @@ /// on the cast source operand type and the SubTensorOp static information. /// This is the resulting type if the tensor::CastOp were folded and /// rank-reduced to the desired result rank. - auto resultType = SubTensorOp::inferRankReducedResultType( - subTensorOp.getType().getRank(), - castOp.source().getType().cast(), + + // If the rank of the inferred result and subtensor op are different, the + // `tensor.cast` below will be illegal.
+ RankedTensorType resultType = getCanonicalSubTensorResultType( + subTensorOp.getType().getRank(), subTensorOp.getSourceType(), subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(), subTensorOp.getMixedStrides()); Value newSubTensor = rewriter.create( @@ -1972,6 +1993,18 @@ }; } // namespace +/// Return the canonical (possibly rank-reduced) result type of a subtensor. +struct SubTensorReturnTypeCanonicalizer { + RankedTensorType operator()(SubTensorOp op, + ArrayRef mixedOffsets, + ArrayRef mixedSizes, + ArrayRef mixedStrides) { + return getCanonicalSubTensorResultType(op.getType().getRank(), + op.getSourceType(), mixedOffsets, + mixedSizes, mixedStrides); + } +}; + /// A canonicalizer wrapper to replace SubTensorOps. struct SubTensorCanonicalizer { void operator()(PatternRewriter &rewriter, SubTensorOp op, @@ -1987,7 +2020,8 @@ void SubTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, + SubTensorOp, SubTensorReturnTypeCanonicalizer, + SubTensorCanonicalizer>, SubTensorOpCastFolder>(context); } @@ -2093,22 +2127,9 @@ canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); // Create the new op in canonical form.
- Value source = subTensorInsertOp.source(); - RankedTensorType sourceType = source.getType().cast(); - SmallVector shape = llvm::to_vector<4>( - llvm::map_range(mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t { - if (auto attr = valueOrAttr.dyn_cast()) - return attr.cast().getInt(); - return ShapedType::kDynamicSize; - })); - RankedTensorType newSourceType = - RankedTensorType::get(shape, sourceType.getElementType()); - Location loc = subTensorInsertOp.getLoc(); - if (sourceType != newSourceType) - source = rewriter.create(loc, newSourceType, source); rewriter.replaceOpWithNewOp( - subTensorInsertOp, source, subTensorInsertOp.dest(), mixedOffsets, - mixedSizes, mixedStrides); + subTensorInsertOp, subTensorInsertOp.source(), subTensorInsertOp.dest(), + mixedOffsets, mixedSizes, mixedStrides); return success(); } }; @@ -2213,7 +2234,6 @@ SmallVectorImpl &caseOperands, SmallVectorImpl &caseOperandTypes, DenseIntElementsAttr &caseOperandOffsets) { - if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) || failed(parser.parseSuccessor(defaultDestination))) return failure(); @@ -2457,7 +2477,6 @@ /// ] static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter) { - SmallVector newCaseDests; SmallVector newCaseOperands; SmallVector> argStorage; diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -62,3 +62,70 @@ %1 = memref.buffer_cast %0 : memref return %1 : memref } + +// ----- + +// CHECK-LABEL: func @subview_of_memcast +// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8> +// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}> +// CHECK: %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}> +// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}> +func
@subview_of_memcast(%arg : memref<4x6x16x32xi8>) -> + memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{ + %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref + %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : + memref to + memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>> + return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>> +} + +// ----- + +// CHECK-LABEL: func @subview_of_static_full_size +// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8> +// CHECK-NOT: memref.subview +// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8> +func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> { + %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8> + return %0 : memref<4x6x16x32xi8> +} + +// ----- + +#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> +func @subview_canonicalize(%arg0 : memref, %arg1 : index, + %arg2 : index) -> memref +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4 = constant 4 : index + %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : memref to memref + return %0 : memref +} +// CHECK-LABEL: func @subview_canonicalize +// CHECK-SAME: %[[ARG0:.+]]: memref +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1] +// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1] +// CHECK-SAME: : memref to memref<4x1x?xf32 +// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]] +// CHECK: return %[[RESULT]] + +// ----- + +#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +func @rank_reducing_subview_canonicalize(%arg0 : memref, %arg1 : index, + %arg2 : index) -> memref +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4 = constant 4 : index + %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref to memref + return %0 : memref +} +// CHECK-LABEL: func
@rank_reducing_subview_canonicalize +// CHECK-SAME: %[[ARG0:.+]]: memref +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1] +// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1] +// CHECK-SAME: : memref to memref<4x?xf32 +// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]] +// CHECK: return %[[RESULT]] diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -154,30 +154,41 @@ // ----- -// CHECK-LABEL: func @subview_of_memcast -// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8> -// CHECK: %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}> -// CHECK: %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}> -// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}> -func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) -> - memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{ - %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref - %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : - memref to - memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>> - return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>> +func @subtensor_canonicalize(%arg0 : tensor, %arg1 : index, + %arg2 : index) -> tensor +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4 = constant 4 : index + %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor to tensor + return %0 : tensor } +// CHECK-LABEL: func @subtensor_canonicalize +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1] +// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1] +// CHECK-SAME: : tensor to tensor<4x1x?xf32> +// CHECK: %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]] +// CHECK: return %[[RESULT]] // ----- -//
CHECK-LABEL: func @subview_of_static_full_size -// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8> -// CHECK-NOT: memref.subview -// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8> -func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> { - %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8> - return %0 : memref<4x6x16x32xi8> +func @rank_reducing_subtensor_canonicalize(%arg0 : tensor, %arg1 : index, + %arg2 : index) -> tensor +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4 = constant 4 : index + %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor to tensor + return %0 : tensor } +// CHECK-LABEL: func @rank_reducing_subtensor_canonicalize +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1] +// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1] +// CHECK-SAME: : tensor to tensor<4x?xf32> +// CHECK: %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]] +// CHECK: return %[[RESULT]] // ----- @@ -232,7 +243,89 @@ // ----- -func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor, +func @subtensor_insert_canonicalize(%arg0 : tensor, %arg1 : index, + %arg2 : index, %arg3 : tensor) -> tensor +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4 = constant 4 : index + %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor into tensor + return %0 : tensor +} +// CHECK-LABEL: func @subtensor_insert_canonicalize +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]] +// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] +// CHECK-SAME: : tensor into tensor +// CHECK: return %[[RESULT]] + +// ----- + +func @subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor, %arg1 : index, + %arg2 : index, %arg3 : tensor) -> tensor +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4
= constant 4 : index + %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor to tensor + %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: func @subtensor_to_subtensor_insert_canonicalize +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]] +// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] +// CHECK-SAME: : tensor to tensor<4x1x?xf32> +// CHECK: %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]] +// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] +// CHECK-SAME: : tensor<4x1x?xf32> into tensor +// CHECK: return %[[RESULT]] + +// ----- + +func @rank_reducing_subtensor_insert_canonicalize(%arg0 : tensor, %arg1 : index, + %arg2 : index, %arg3 : tensor) -> tensor +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4 = constant 4 : index + %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor into tensor + return %0 : tensor +} +// CHECK-LABEL: func @rank_reducing_subtensor_insert_canonicalize +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]] +// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] +// CHECK-SAME: : tensor into tensor +// CHECK: return %[[RESULT]] + +// ----- + +func @rank_reducing_subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor, %arg1 : index, + %arg2 : index, %arg3 : tensor) -> tensor +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4 = constant 4 : index + %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor to tensor + %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: func @rank_reducing_subtensor_to_subtensor_insert_canonicalize +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +//
CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]] +// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] +// CHECK-SAME: : tensor to tensor<4x?xf32> +// CHECK: %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]] into %[[ARG3]] +// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] +// CHECK-SAME: : tensor<4x?xf32> into tensor +// CHECK: return %[[RESULT]] + +// ----- + +func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor, %arg2 : index, %arg3 : index) -> tensor { %c0 = constant 0 : index %c1 = constant 1 : index @@ -247,7 +340,7 @@ %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor return %3 : tensor } -// CHECK-LABEL: func @subtensor_canonicalize +// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast // CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1] // CHECK-SAME: tensor<2x?xi32> into tensor // CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]