diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2788,6 +2788,16 @@
                                 ArrayRef<OpFoldResult> staticOffsets,
                                 ArrayRef<OpFoldResult> staticSizes,
                                 ArrayRef<OpFoldResult> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           MemRefType sourceMemRefType,
+                                           ArrayRef<int64_t> staticOffsets,
+                                           ArrayRef<int64_t> staticSizes,
+                                           ArrayRef<int64_t> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           MemRefType sourceMemRefType,
+                                           ArrayRef<OpFoldResult> staticOffsets,
+                                           ArrayRef<OpFoldResult> staticSizes,
+                                           ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -2914,6 +2924,16 @@
                                 ArrayRef<OpFoldResult> staticOffsets,
                                 ArrayRef<OpFoldResult> staticSizes,
                                 ArrayRef<OpFoldResult> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           RankedTensorType sourceRankedTensorType,
+                                           ArrayRef<int64_t> staticOffsets,
+                                           ArrayRef<int64_t> staticSizes,
+                                           ArrayRef<int64_t> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           RankedTensorType sourceRankedTensorType,
+                                           ArrayRef<OpFoldResult> staticOffsets,
+                                           ArrayRef<OpFoldResult> staticSizes,
+                                           ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -3027,6 +3047,7 @@
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
   }];
+  let hasFolder = 1;
 }
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2891,8 +2891,68 @@
   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                              staticStrides, ShapedType::kDynamicStrideOrOffset);
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
-                                    staticSizes, staticStrides)
-      .cast<MemRefType>();
+                                    staticSizes, staticStrides);
+}
+
+static void
+getPositionsOfShapeOne(unsigned rank, ArrayRef<int64_t> shape,
+                       llvm::SmallDenseSet<unsigned> &dimsToProject) {
+  dimsToProject.reserve(rank);
+  for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
+    if (shape[pos] == 1) {
+      dimsToProject.insert(pos);
+      --rank;
+    }
+  }
+}
+
+Type SubViewOp::inferRankReducedResultType(
+    unsigned resultRank, MemRefType sourceRankedTensorType,
+    ArrayRef<int64_t> leadingStaticOffsets,
+    ArrayRef<int64_t> leadingStaticSizes,
+    ArrayRef<int64_t> leadingStaticStrides) {
+  auto inferredType =
+      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
+                      leadingStaticSizes, leadingStaticStrides)
+          .cast<MemRefType>();
+  assert(inferredType.getRank() >= resultRank && "expected ");
+  int rankDiff = inferredType.getRank() - resultRank;
+  if (rankDiff > 0) {
+    auto shape = inferredType.getShape();
+    llvm::SmallDenseSet<unsigned> dimsToProject;
+    getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
+    SmallVector<int64_t> projectedShape;
+    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
+      if (!dimsToProject.contains(pos))
+        projectedShape.push_back(shape[pos]);
+
+    AffineMap map;
+    auto maps = inferredType.getAffineMaps();
+    if (!maps.empty() && maps.front())
+      map = getProjectedMap(maps.front(), dimsToProject);
+    inferredType =
+        MemRefType::get(projectedShape, inferredType.getElementType(), map,
+                        inferredType.getMemorySpace());
+  }
+  return inferredType;
+}
+
+Type SubViewOp::inferRankReducedResultType(
+    unsigned resultRank, MemRefType sourceRankedTensorType,
+    ArrayRef<OpFoldResult> leadingStaticOffsets,
+    ArrayRef<OpFoldResult> leadingStaticSizes,
+    ArrayRef<OpFoldResult> leadingStaticStrides) {
+  SmallVector<int64_t, 4> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value, 4> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  return SubViewOp::inferRankReducedResultType(
+      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+      staticStrides);
 }
 
 // Build a SubViewOp with mixed static and dynamic entries and custom result
@@ -3439,12 +3499,40 @@
     return success();
   }
 };
+
+// Rewrite memref_cast + subview as subview + memref_cast.
+struct SubViewFromMemRefCast : public OpRewritePattern<SubViewOp> {
+  using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubViewOp subViewOp,
+                                PatternRewriter &rewriter) const final {
+    auto memrefCastOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
+    if (!memrefCastOp || !canFoldIntoConsumerOp(memrefCastOp))
+      return failure();
+
+    // Let type inference do the work for us.
+    auto srcMemrefType = memrefCastOp.source().getType().cast<MemRefType>();
+    auto dstMemrefType = subViewOp.getType();
+    auto subviewMemRefType =
+        SubViewOp::inferRankReducedResultType(
+            dstMemrefType.getRank(), srcMemrefType, subViewOp.getMixedOffsets(),
+            subViewOp.getMixedSizes(), subViewOp.getMixedStrides())
+            .cast<MemRefType>();
+    Value newSubview = rewriter.create<SubViewOp>(
+        subViewOp.getLoc(), subviewMemRefType, memrefCastOp.source(),
+        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
+        subViewOp.getMixedStrides());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
+                                              newSubview);
+    return success();
+  }
+};
 } // namespace
 
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
   results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubViewOp>,
-                 SubViewOpMemRefCastFolder>(context);
+                 SubViewFromMemRefCast>(context);
 }
 
 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
@@ -3492,8 +3580,52 @@
   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                              staticStrides, ShapedType::kDynamicStrideOrOffset);
   return SubTensorOp::inferResultType(sourceRankedTensorType, staticOffsets,
-                                      staticSizes, staticStrides)
-      .cast<RankedTensorType>();
+                                      staticSizes, staticStrides);
+}
+
+/// A subtensor result type can be fully inferred from the source type and the
+/// static representation of offsets, sizes and strides. Special sentinels
+/// encode the dynamic case.
+Type SubTensorOp::inferRankReducedResultType(
+    unsigned resultRank, RankedTensorType sourceRankedTensorType,
+    ArrayRef<int64_t> leadingStaticOffsets,
+    ArrayRef<int64_t> leadingStaticSizes,
+    ArrayRef<int64_t> leadingStaticStrides) {
+  auto inferredType =
+      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
+                      leadingStaticSizes, leadingStaticStrides)
+          .cast<RankedTensorType>();
+  int rankDiff = inferredType.getRank() - resultRank;
+  if (rankDiff > 0) {
+    auto shape = inferredType.getShape();
+    llvm::SmallDenseSet<unsigned> dimsToProject;
+    getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
+    SmallVector<int64_t> projectedShape;
+    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
+      if (!dimsToProject.contains(pos))
+        projectedShape.push_back(shape[pos]);
+    inferredType =
+        RankedTensorType::get(projectedShape, inferredType.getElementType());
+  }
+  return inferredType;
+}
+
+Type SubTensorOp::inferRankReducedResultType(
+    unsigned resultRank, RankedTensorType sourceRankedTensorType,
+    ArrayRef<OpFoldResult> leadingStaticOffsets,
+    ArrayRef<OpFoldResult> leadingStaticSizes,
+    ArrayRef<OpFoldResult> leadingStaticStrides) {
+  SmallVector<int64_t, 4> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value, 4> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  return SubTensorOp::inferRankReducedResultType(
+      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+      staticStrides);
 }
 
 // Build a SubTensorOp with mixed static and dynamic entries and custom result
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -176,3 +176,16 @@
   return %0 : tensor<4x6x16x32xi8>
 }
 
+// CHECK-LABEL: func @rank_reducing_subview_of_memcast
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+// CHECK: %[[S:.+]] = subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+// CHECK: %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @rank_reducing_subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+  %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
+    memref<?x?x16x32xi8> to
+    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}
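
For reviewers, a minimal sketch (not part of the patch) of how the new ArrayRef<int64_t> overload of SubViewOp::inferRankReducedResultType could be exercised to reproduce the rank-reduced type that the canonicalized subview in the test above produces. The wrapper function buildRankReducedSubviewType and the standalone Builder setup are assumptions made only for this example, and it presumes subview type inference at this revision still completes leading (shorter-than-rank) offset/size/stride lists, as the test relies on.

// Illustrative sketch only; `buildRankReducedSubviewType` is a hypothetical
// helper, not code introduced by this change.
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static Type buildRankReducedSubviewType(MLIRContext *ctx) {
  Builder b(ctx);
  // Source type from the test: memref<4x6x16x32xi8>.
  MemRefType sourceType = MemRefType::get({4, 6, 16, 32}, b.getIntegerType(8));
  // Leading offsets/sizes/strides as written in the test; the trailing
  // dimension is taken whole, and the two unit dimensions of the inferred
  // rank-4 shape <1x1x16x32> are projected away.
  SmallVector<int64_t, 4> offsets{0, 1, 0}, sizes{1, 1, 16}, strides{1, 1, 1};
  // Expected to yield a rank-2 strided type, i.e. memref<16x32xi8, #map>.
  return SubViewOp::inferRankReducedResultType(/*resultRank=*/2, sourceType,
                                               offsets, sizes, strides);
}

This mirrors what SubViewFromMemRefCast does in the pattern above: it asks type inference, keyed by the destination rank, for the rank-reduced subview type on the cast's source, rebuilds the subview there, and then casts back to the original result type.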