diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h --- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h @@ -63,6 +63,8 @@ ResultTypeFunc resultTypeFunc; auto resultType = resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides); + if (!resultType) + return failure(); auto newOp = rewriter.create(op.getLoc(), resultType, op.source(), mixedOffsets, mixedSizes, mixedStrides); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -511,14 +511,16 @@ /// dimension is dropped the stride must be dropped too. static llvm::Optional> computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, - ArrayAttr staticSizes) { + ArrayRef sizes) { llvm::SmallDenseSet unusedDims; if (originalType.getRank() == reducedType.getRank()) return unusedDims; - for (auto dim : llvm::enumerate(staticSizes)) - if (dim.value().cast().getInt() == 1) - unusedDims.insert(dim.index()); + for (auto dim : llvm::enumerate(sizes)) + if (auto attr = dim.value().dyn_cast()) + if (attr.cast().getInt() == 1) + unusedDims.insert(dim.index()); + SmallVector originalStrides, candidateStrides; int64_t originalOffset, candidateOffset; if (failed( @@ -574,7 +576,7 @@ MemRefType sourceType = getSourceType(); MemRefType resultType = getType(); llvm::Optional> unusedDims = - computeMemRefRankReductionMask(sourceType, resultType, static_sizes()); + computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes()); assert(unusedDims && "unable to find unused dims of subview"); return *unusedDims; } @@ -1712,7 +1714,7 @@ /// not matching dimension must be 1. static SubViewVerificationResult isRankReducedType(Type originalType, Type candidateReducedType, - ArrayAttr staticSizes, std::string *errMsg = nullptr) { + ArrayRef sizes, std::string *errMsg = nullptr) { if (originalType == candidateReducedType) return SubViewVerificationResult::Success; if (!originalType.isa()) @@ -1737,7 +1739,7 @@ MemRefType candidateReduced = candidateReducedType.cast(); auto optionalUnusedDimsMask = - computeMemRefRankReductionMask(original, candidateReduced, staticSizes); + computeMemRefRankReductionMask(original, candidateReduced, sizes); // Sizes cannot be matched in case empty vector is returned. if (!optionalUnusedDimsMask.hasValue()) @@ -1807,7 +1809,7 @@ std::string errMsg; auto result = - isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg); + isRankReducedType(expectedType, subViewType, op.getMixedSizes(), &errMsg); return produceSubViewErrorMsg(result, op, expectedType, errMsg); } @@ -1848,21 +1850,29 @@ /// Infer the canonical type of the result of a subview operation. Returns a /// type with rank `resultRank` that is either the rank of the rank-reduced /// type, or the non-rank-reduced type. -static MemRefType -getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType, - ArrayRef mixedOffsets, - ArrayRef mixedSizes, - ArrayRef mixedStrides) { - auto resultType = - SubViewOp::inferRankReducedResultType( - resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides) - .cast(); - if (resultType.getRank() != resultRank) { - resultType = SubViewOp::inferResultType(sourceType, mixedOffsets, - mixedSizes, mixedStrides) - .cast(); +static MemRefType getCanonicalSubViewResultType( + MemRefType currentResultType, MemRefType sourceType, + ArrayRef mixedOffsets, ArrayRef mixedSizes, + ArrayRef mixedStrides) { + auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, + mixedSizes, mixedStrides) + .cast(); + llvm::Optional> unusedDims = + computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes); + // Return nullptr as failure mode. + if (!unusedDims) + return nullptr; + SmallVector shape; + for (auto sizes : llvm::enumerate(nonRankReducedType.getShape())) { + if (unusedDims->count(sizes.index())) + continue; + shape.push_back(sizes.value()); } - return resultType; + AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap(); + if (!layoutMap.isIdentity()) + layoutMap = getProjectedMap(layoutMap, unusedDims.getValue()); + return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap, + nonRankReducedType.getMemorySpace()); } namespace { @@ -1905,8 +1915,7 @@ /// the cast source operand type and the SubViewOp static information. This /// is the resulting type if the MemRefCastOp were folded. auto resultType = getCanonicalSubViewResultType( - subViewOp.getType().getRank(), - castOp.source().getType().cast(), + subViewOp.getType(), castOp.source().getType().cast(), subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), subViewOp.getMixedStrides()); Value newSubView = rewriter.create( @@ -1925,9 +1934,9 @@ MemRefType operator()(SubViewOp op, ArrayRef mixedOffsets, ArrayRef mixedSizes, ArrayRef mixedStrides) { - return getCanonicalSubViewResultType(op.getType().getRank(), - op.getSourceType(), mixedOffsets, - mixedSizes, mixedStrides); + return getCanonicalSubViewResultType(op.getType(), op.getSourceType(), + mixedOffsets, mixedSizes, + mixedStrides); } }; diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -47,7 +47,7 @@ // ----- -#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> func @rank_reducing_subview_canonicalize(%arg0 : memref, %arg1 : index, %arg2 : index) -> memref { @@ -395,3 +395,25 @@ %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref into memref return %collapsed : memref } + +// ----- + +func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index) + -> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> { + %c0 = arith.constant 0 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %0 = memref.subview %arg0[%arg1, %arg1, %arg1, 0] [%c1, %c4, %c1, 1] [1, 1, 1, 1] + : memref<2x5x7x1xf32> to memref (d0 * 35 + s0 + d1 * 7 + d2)>> + %1 = memref.cast %0 + : memref (d0 * 35 + s0 + d1 * 7 + d2)>> to + memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> + return %1 : memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> +} + +// CHECK-LABEL: func @reduced_memref +// CHECK: %[[RESULT:.+]] = memref.subview +// CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}> +// CHECK: return %[[RESULT]]