diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
@@ -35,7 +35,7 @@
                             llvm::SmallDenseSet<unsigned> &dimsToProject);
 
 /// Pattern to rewrite a subview op with constant arguments.
-template <typename OpType, typename CastOpFunc>
+template <typename OpType, typename ResultTypeFunc, typename CastOpFunc>
 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     : public OpRewritePattern<OpType> {
 public:
@@ -59,8 +59,12 @@
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
-    auto newOp = rewriter.create<OpType>(op.getLoc(), op.source(), mixedOffsets,
-                                         mixedSizes, mixedStrides);
+    ResultTypeFunc resultTypeFunc;
+    auto resultType =
+        resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
+    auto newOp =
+        rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
+                                mixedOffsets, mixedSizes, mixedStrides);
     CastOpFunc func;
     func(rewriter, op, newOp);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -544,77 +544,76 @@
     return success();
   }
 };
+} // namespace
 
-/// Pattern to fold subtensors that are just taking a slice of unit-dimension
-/// tensor. For example
-///
-/// %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
-///     : tensor<1x?x1xf32> to tensor<1x?x1xf32>
-///
-/// can be replaced with
-///
-/// %0 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-///     : tensor<1x?x1xf32> into tensor<?xf32>
-/// %1 = subtensor %0[%o1] [%s1] [1] : tensor<?xf32> to tensor<?xf32>
-/// %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-///     : tensor<?xf32> into tensor<1x?x1xf32>
-///
-/// The additional tensor_reshapes will hopefully get canonicalized away with
-/// other reshapes that drop unit dimensions. Three condiitions to fold a
-/// dimension
-/// - The offset must be 0
-/// - The size must be 1
-/// - The dimension of the source type must be 1.
-struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
+/// Get the reassociation maps to convert a `type` to its rank-reduced version.
+static Optional<SmallVector<ReassociationIndices>>
+getReassociationMapForFoldingUnitDims(ShapedType type) {
+  auto shape = type.getShape();
+  SmallVector<ReassociationIndices> reassociation;
+  ReassociationIndices curr;
+  for (auto pos : llvm::enumerate(shape)) {
+    curr.push_back(pos.index());
+    if (pos.value() == 1)
+      continue;
+    reassociation.emplace_back(std::move(curr));
+    curr.clear();
+  }
+  if (!curr.empty())
+    reassociation.back().append(curr.begin(), curr.end());
+  return reassociation;
+}
+
+namespace {
+/// Convert `subtensor` operations to rank-reduced versions.
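+///
+/// For example (the shapes below are only illustrative):
+///
+///   %0 = subtensor %arg0[0, %o, 0] [1, %s, 1] [1, 1, 1]
+///       : tensor<1x?x1xf32> to tensor<1x?x1xf32>
+///
+/// is rewritten to a rank-reduced subtensor followed by a reshape that
+/// restores the unit dimensions of the original result type:
+///
+///   %0 = subtensor %arg0[0, %o, 0] [1, %s, 1] [1, 1, 1]
+///       : tensor<1x?x1xf32> to tensor<?xf32>
+///   %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+///       : tensor<?xf32> into tensor<1x?x1xf32>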
+struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
   using OpRewritePattern<SubTensorOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<OpFoldResult> mixedOffsets = subTensorOp.getMixedOffsets();
-    SmallVector<OpFoldResult> mixedSizes = subTensorOp.getMixedSizes();
-    SmallVector<OpFoldResult> mixedStrides = subTensorOp.getMixedStrides();
-    auto hasValue = [](OpFoldResult valueOrAttr, int64_t val) {
-      auto attr = valueOrAttr.dyn_cast<Attribute>();
-      return attr && attr.cast<IntegerAttr>().getInt() == val;
-    };
-
-    if (llvm::any_of(mixedStrides, [&](OpFoldResult valueOrAttr) {
-          return !hasValue(valueOrAttr, 1);
-        }))
+    RankedTensorType resultType = subTensorOp.getType();
+    auto reassociation = getReassociationMapForFoldingUnitDims(resultType);
+    if (!reassociation ||
+        reassociation->size() == static_cast<size_t>(resultType.getRank()))
       return failure();
+    Location loc = subTensorOp.getLoc();
+    SmallVector<OpFoldResult> offsets = subTensorOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = subTensorOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = subTensorOp.getMixedStrides();
+    auto rankReducedType =
+        SubTensorOp::inferRankReducedResultType(reassociation->size(),
+                                                subTensorOp.getSourceType(),
+                                                offsets, sizes, strides)
+            .cast<RankedTensorType>();
+    Value newSubTensor = rewriter.create<SubTensorOp>(
+        loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
+    rewriter.replaceOpWithNewOp<TensorReshapeOp>(subTensorOp, resultType,
+                                                 newSubTensor, *reassociation);
+    return success();
+  }
+};
 
-    // Find the expanded unit dimensions.
-    SmallVector<ReassociationIndices> reassociation;
-    SmallVector<OpFoldResult> newOffsets, newSizes;
-    ArrayRef<int64_t> sourceShape = subTensorOp.getSourceType().getShape();
-    ReassociationIndices curr;
-    for (int64_t dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
-      curr.push_back(dim);
-      if (sourceShape[dim] == 1 && hasValue(mixedOffsets[dim], 0) &&
-          hasValue(mixedSizes[dim], 1)) {
-        continue;
-      }
-      newOffsets.push_back(mixedOffsets[dim]);
-      newSizes.push_back(mixedSizes[dim]);
-      reassociation.emplace_back(ReassociationIndices{});
-      std::swap(reassociation.back(), curr);
-    }
-    if (newOffsets.size() == mixedOffsets.size())
+/// Convert `subtensor_insert` operations to rank-reduced versions.
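+///
+/// For example (the shapes below are only illustrative):
+///
+///   %0 = subtensor_insert %arg0 into %arg1[0, %o, 0] [1, %s, 1] [1, 1, 1]
+///       : tensor<1x?x1xf32> into tensor<2x?x4xf32>
+///
+/// is rewritten so that the source is first reshaped to its rank-reduced
+/// form and then inserted with a rank-reducing subtensor_insert:
+///
+///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+///       : tensor<1x?x1xf32> into tensor<?xf32>
+///   %1 = subtensor_insert %0 into %arg1[0, %o, 0] [1, %s, 1] [1, 1, 1]
+///       : tensor<?xf32> into tensor<2x?x4xf32>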
+struct UseRankReducedSubTensorInsertOp
+    : public OpRewritePattern<SubTensorInsertOp> {
+  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType sourceType = insertOp.getSourceType();
+    auto reassociation = getReassociationMapForFoldingUnitDims(sourceType);
+    if (!reassociation ||
+        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
       return failure();
-    reassociation.back().append(curr.begin(), curr.end());
-    SmallVector<OpFoldResult> newStrides(newOffsets.size(),
-                                         rewriter.getI64IntegerAttr(1));
-    Location loc = subTensorOp->getLoc();
-    auto srcReshape = rewriter.create<TensorReshapeOp>(
-        loc, subTensorOp.source(), reassociation);
-    auto newSubTensorOp = rewriter.create<SubTensorOp>(
-        loc, srcReshape, newOffsets, newSizes, newStrides);
-    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
-        subTensorOp, subTensorOp.getType(), newSubTensorOp, reassociation);
+    Location loc = insertOp.getLoc();
+    auto reshapedSource = rewriter.create<TensorReshapeOp>(
+        loc, insertOp.source(), *reassociation);
+    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
+        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
+        insertOp.getMixedSizes(), insertOp.getMixedStrides());
     return success();
   }
 };
-
 } // namespace
 
 /// Patterns that are used to canonicalize the use of unit-extent dims for
@@ -623,8 +622,10 @@
                                                  RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
   patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
-               FoldUnitDimSubTensorOp, ReplaceUnitExtentTensors<GenericOp>,
-               ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
+               ReplaceUnitExtentTensors<GenericOp>,
+               ReplaceUnitExtentTensors<IndexedGenericOp>,
+               UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
+      context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldReshapeOpWithUnitExtent>(context);
 }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1882,6 +1882,26 @@
   return res;
 }
 
+/// Infer the canonical type of the result of a subview operation. Returns a
+/// type with rank `resultRank` that is either the rank-reduced type or the
+/// non-rank-reduced type.
+static MemRefType
+getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
+                              ArrayRef<OpFoldResult> mixedOffsets,
+                              ArrayRef<OpFoldResult> mixedSizes,
+                              ArrayRef<OpFoldResult> mixedStrides) {
+  auto resultType =
+      SubViewOp::inferRankReducedResultType(
+          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
+          .cast<MemRefType>();
+  if (resultType.getRank() != resultRank) {
+    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
+                                            mixedSizes, mixedStrides)
+                     .cast<MemRefType>();
+  }
+  return resultType;
+}
+
 namespace {
 /// Pattern to rewrite a subview op with MemRefCast arguments.
 /// This essentially pushes memref.cast past its consuming subview when
@@ -1921,7 +1941,7 @@
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    auto resultType = SubViewOp::inferRankReducedResultType(
+    auto resultType = getCanonicalSubViewResultType(
         subViewOp.getType().getRank(),
         castOp.source().getType().cast<MemRefType>(),
         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
@@ -1937,6 +1957,17 @@
 };
 } // namespace
 
+/// Return the canonical type of the result of a subview.
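+/// This is the `ResultTypeFunc` passed to
+/// OpWithOffsetSizesAndStridesConstantArgumentFolder for subview ops: it
+/// returns the type inferred by `inferRankReducedResultType` when that type
+/// has rank `resultRank`, and falls back to the non-rank-reduced
+/// `inferResultType` otherwise.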
+struct SubViewReturnTypeCanonicalizer {
+  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
+                        ArrayRef<OpFoldResult> mixedSizes,
+                        ArrayRef<OpFoldResult> mixedStrides) {
+    return getCanonicalSubViewResultType(op.getType().getRank(),
+                                         op.getSourceType(), mixedOffsets,
+                                         mixedSizes, mixedStrides);
+  }
+};
+
 /// A canonicalizer wrapper to replace SubViewOps.
 struct SubViewCanonicalizer {
   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
@@ -1946,9 +1977,10 @@
 
 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
-                  SubViewOp, SubViewCanonicalizer>,
-              SubViewOpMemRefCastFolder>(context);
+  results
+      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
+           SubViewOpMemRefCastFolder>(context);
 }
 
 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1917,6 +1917,25 @@
   return produceSubTensorErrorMsg(result, op, expectedType);
 }
 
+/// Infer the canonical type of the result of a subtensor operation. Returns a
+/// type with rank `resultRank` that is either the rank-reduced type or the
+/// non-rank-reduced type.
+static RankedTensorType getCanonicalSubTensorResultType(
+    unsigned resultRank, RankedTensorType sourceType,
+    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
+    ArrayRef<OpFoldResult> mixedStrides) {
+  auto resultType =
+      SubTensorOp::inferRankReducedResultType(
+          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
+          .cast<RankedTensorType>();
+  if (resultType.getRank() != resultRank) {
+    resultType = SubTensorOp::inferResultType(sourceType, mixedOffsets,
+                                              mixedSizes, mixedStrides)
+                     .cast<RankedTensorType>();
+  }
+  return resultType;
+}
+
 namespace {
 /// Pattern to rewrite a subtensor op with tensor::Cast arguments.
 /// This essentially pushes memref_cast past its consuming subtensor when
@@ -1955,9 +1974,11 @@
     /// on the cast source operand type and the SubTensorOp static information.
     /// This is the resulting type if the tensor::CastOp were folded and
    /// rank-reduced to the desired result rank.
-    auto resultType = SubTensorOp::inferRankReducedResultType(
-        subTensorOp.getType().getRank(),
-        castOp.source().getType().cast<RankedTensorType>(),
+
+    // If the rank of the inferred result differs from the rank of the
+    // subtensor op, the `tensor.cast` below would be illegal.
+    RankedTensorType resultType = getCanonicalSubTensorResultType(
+        subTensorOp.getType().getRank(), subTensorOp.getSourceType(),
         subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
         subTensorOp.getMixedStrides());
     Value newSubTensor = rewriter.create<SubTensorOp>(
@@ -1972,6 +1993,18 @@
 };
 } // namespace
 
+/// Return the canonical type of the result of a subtensor.
+struct SubTensorReturnTypeCanonicalizer {
+  RankedTensorType operator()(SubTensorOp op,
+                              ArrayRef<OpFoldResult> mixedOffsets,
+                              ArrayRef<OpFoldResult> mixedSizes,
+                              ArrayRef<OpFoldResult> mixedStrides) {
+    return getCanonicalSubTensorResultType(op.getType().getRank(),
+                                           op.getSourceType(), mixedOffsets,
+                                           mixedSizes, mixedStrides);
+  }
+};
+
 /// A canonicalizer wrapper to replace SubTensorOps.
 struct SubTensorCanonicalizer {
   void operator()(PatternRewriter &rewriter, SubTensorOp op,
@@ -1987,7 +2020,8 @@
 void SubTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
-                  SubTensorOp, SubTensorCanonicalizer>,
+                  SubTensorOp, SubTensorReturnTypeCanonicalizer,
+                  SubTensorCanonicalizer>,
               SubTensorOpCastFolder>(context);
 }
 
@@ -2093,22 +2127,9 @@
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
-    Value source = subTensorInsertOp.source();
-    RankedTensorType sourceType = source.getType().cast<RankedTensorType>();
-    SmallVector<int64_t, 4> shape = llvm::to_vector<4>(
-        llvm::map_range(mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t {
-          if (auto attr = valueOrAttr.dyn_cast<Attribute>())
-            return attr.cast<IntegerAttr>().getInt();
-          return ShapedType::kDynamicSize;
-        }));
-    RankedTensorType newSourceType =
-        RankedTensorType::get(shape, sourceType.getElementType());
-    Location loc = subTensorInsertOp.getLoc();
-    if (sourceType != newSourceType)
-      source = rewriter.create<tensor::CastOp>(loc, newSourceType, source);
     rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
-        subTensorInsertOp, source, subTensorInsertOp.dest(), mixedOffsets,
-        mixedSizes, mixedStrides);
+        subTensorInsertOp, subTensorInsertOp.source(), subTensorInsertOp.dest(),
+        mixedOffsets, mixedSizes, mixedStrides);
     return success();
   }
 };
@@ -2213,7 +2234,6 @@
                          SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
                          SmallVectorImpl<Type> &caseOperandTypes,
                          DenseIntElementsAttr &caseOperandOffsets) {
-
   if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
       failed(parser.parseSuccessor(defaultDestination)))
     return failure();
@@ -2457,7 +2477,6 @@
 ///    ]
 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
                                                PatternRewriter &rewriter) {
-
   SmallVector<Block *> newCaseDests;
   SmallVector<ValueRange> newCaseOperands;
   SmallVector<SmallVector<Value>> argStorage;
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -476,67 +476,32 @@
 // -----
 
 func @fold_subtensor(
-  %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : index, %arg2 : index,
-  %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index)
-  -> tensor<1x?x?x1x?x1x1xf32> {
-  %0 = subtensor %arg0[0, %arg1, %arg2, 0, %arg3, 0, 0]
-      [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
+  %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : tensor<1x?x?x?x?x1x1xf32>,
+  %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index,
+  %arg6 : index, %arg7 : index) -> (tensor<1x?x?x1x?x1x1xf32>, tensor<1x?x?x1x?x1x1xf32>) {
+  %0 = subtensor %arg0[0, %arg2, %arg3, 0, %arg4, 0, 0]
+      [1, %arg5, %arg6, 1, %arg7, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
       tensor<1x?x?x1x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
-  return %0 : tensor<1x?x?x1x?x1x1xf32>
+  %1 = subtensor %arg1[%arg2, 0, %arg3, 0, 0, %arg4, 0]
+      [1, %arg5, %arg6, 1, %arg7, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
+      tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
+  return %0, %1 : tensor<1x?x?x1x?x1x1xf32>, tensor<1x?x?x1x?x1x1xf32>
 }
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
 // CHECK: func @fold_subtensor
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x?x1x?x1x1xf32>
-// CHECK-SAME:   %[[ARG1:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG2:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG3:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG4:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG5:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG6:[a-z0-9]+]]: index
-// CHECK: %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:   %[[ARG1:.+]]: tensor<1x?x?x?x?x1x1xf32>
+// CHECK: %[[SUBTENSOR1:.+]] = subtensor %[[ARG0]]
+// CHECK-SAME:   to tensor<?x?x?xf32>
+// CHECK: %[[RESULT1:.+]] = linalg.tensor_reshape %[[SUBTENSOR1]]
 // CHECK-SAME:   [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]]
-// CHECK-SAME:   [%[[ARG1]], %[[ARG2]], %[[ARG3]]]
-// CHECK-SAME:   [%[[ARG4]], %[[ARG5]], %[[ARG6]]]
-// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
+// CHECK: %[[SUBTENSOR2:.+]] = subtensor %[[ARG1]]
+// CHECK-SAME:   to tensor<?x?x?xf32>
+// CHECK: %[[RESULT2:.+]] = linalg.tensor_reshape %[[SUBTENSOR2]]
 // CHECK-SAME:   [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK: return %[[RESULT_RESHAPE]]
-
-// -----
-
-func @no_fold_subtensor(
-  %arg0 : tensor<1x?x?x?x?x1x1xf32>, %arg1 : index, %arg2 : index,
-  %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index)
-  -> tensor<1x?x?x1x?x1x1xf32> {
-  %0 = subtensor %arg0[%arg1, 0, %arg2, 0, 0, %arg3, 0]
-      [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
-      tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
-  return %0 : tensor<1x?x?x1x?x1x1xf32>
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6)>
-// CHECK: func @no_fold_subtensor
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x?x?x?x1x1xf32>
-// CHECK-SAME:   %[[ARG1:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG2:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG3:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG4:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG5:[a-z0-9]+]]: index
-// CHECK-SAME:   %[[ARG6:[a-z0-9]+]]: index
-// CHECK: %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
-// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]]
-// CHECK-SAME: [%[[ARG1]], 0, %[[ARG2]], 0, 0, %[[ARG3]]]
-// CHECK-SAME: [1, %[[ARG4]], %[[ARG5]], 1, %[[ARG6]], 1]
-// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
-// CHECK: return %[[RESULT_RESHAPE]]
+// CHECK: return %[[RESULT1]], %[[RESULT2]]
 
 // -----
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -62,3 +62,70 @@
   %1 = memref.buffer_cast %0 : memref
   return %1 : memref
 }
+
+// -----
+
+// CHECK-LABEL: func @subview_of_memcast
+// CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+// CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+// CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+// CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
+    memref<?x?x16x32xi8> to
+    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}
+
+// -----
+
+// CHECK-LABEL: func @subview_of_static_full_size
+// CHECK-SAME:   %[[ARG0:.+]]: memref<4x6x16x32xi8>
+// CHECK-NOT:   memref.subview
+// CHECK:   return %[[ARG0]] : memref<4x6x16x32xi8>
+func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
+  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
+  return %0 : memref<4x6x16x32xi8>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> memref<?x?x?xf32, #map0>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map0>
+  return %0 : memref<?x?x?xf32, #map0>
+}
+// CHECK-LABEL: func @subview_canonicalize
+// CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
+// CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME:   [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME:   : memref<?x?x?xf32> to memref<4x1x?xf32
+// CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
+// CHECK:   return %[[RESULT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> memref<?x?xf32, #map0>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, #map0>
+  return %0 : memref<?x?xf32, #map0>
+}
+// CHECK-LABEL: func @rank_reducing_subview_canonicalize
+// CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
+// CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME:   [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME:   : memref<?x?x?xf32> to memref<4x?xf32
+// CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
+// CHECK:   return %[[RESULT]]
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -154,30 +154,41 @@
 
 // -----
 
-// CHECK-LABEL: func @subview_of_memcast
-// CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-// CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
-// CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
-// CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
-func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
-  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
-  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-  %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
-    memref<?x?x16x32xi8> to
-    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
-  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+func @subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
 }
+// CHECK-LABEL: func @subtensor_canonicalize
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME:   [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME:   : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+// CHECK:   %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]]
+// CHECK:   return %[[RESULT]]
 
 // -----
 
-// CHECK-LABEL: func @subview_of_static_full_size
-// CHECK-SAME:   %[[ARG0:.+]]: memref<4x6x16x32xi8>
-// CHECK-NOT:   memref.subview
-// CHECK:   return %[[ARG0]] : memref<4x6x16x32xi8>
-func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
-  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
-  return %0 : memref<4x6x16x32xi8>
+func @rank_reducing_subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> tensor<?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
 }
+// CHECK-LABEL: func @rank_reducing_subtensor_canonicalize
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME:   [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME:   : tensor<?x?x?xf32> to tensor<4x?xf32>
+// CHECK:   %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]]
+// CHECK:   return %[[RESULT]]
 
 // -----
 
@@ -232,7 +243,89 @@
 
 // -----
 
-func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
+func @subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @subtensor_insert_canonicalize
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK:   %[[RESULT:.+]] = subtensor_insert %[[ARG0]]
+// CHECK-SAME:   [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME:   : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+// CHECK:   return %[[RESULT]]
+
+// -----
+
+func @subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @subtensor_to_subtensor_insert_canonicalize
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]]
+// CHECK-SAME:   [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME:   : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+// CHECK:   %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]]
+// CHECK-SAME:   [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME:   : tensor<4x1x?xf32> into tensor<?x?x?xf32>
+// CHECK:   return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_subtensor_insert_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @rank_reducing_subtensor_insert_canonicalize
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK:   %[[RESULT:.+]] = subtensor_insert %[[ARG0]]
+// CHECK-SAME:   [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME:   : tensor<?x?xf32> into tensor<?x?x?xf32>
+// CHECK:   return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+  %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @rank_reducing_subtensor_to_subtensor_insert_canonicalize
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]]
+// CHECK-SAME:   [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME:   : tensor<?x?x?xf32> to tensor<4x?xf32>
+// CHECK:   %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]] into %[[ARG3]]
+// CHECK-SAME:   [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME:   : tensor<4x?xf32> into tensor<?x?x?xf32>
+// CHECK:   return %[[RESULT]]
+
+// -----
+
+func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
     %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -247,7 +340,7 @@
   %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
   return %3 : tensor<?x?xi32>
 }
-// CHECK-LABEL: func @subtensor_canonicalize
+// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
 // CHECK:   %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
 // CHECK-SAME:   tensor<2x?xi32> into tensor<?x?xi32>
 // CHECK:   %[[CAST:.+]] = tensor.cast %[[UPDATED]]