diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1708,6 +1708,53 @@
   return nullptr;
 }
 
+/// Given values \p A and \p B and their potential alternative constant
+/// representations \p constantA and \p constantB, check if \p A is equal to
+/// \p B, either because they are the same SSA value or the same constant.
+/// \p isDynamic is used to check whether \p constantA and \p constantB are
+/// valid constants.
+static bool sameValueOrConstant(OpFoldResult A, int64_t constantA,
+                                OpFoldResult B, int64_t constantB,
+                                llvm::function_ref<bool(int64_t)> isDynamic) {
+  Value valA = A.is<Value>() ? A.get<Value>() : Value();
+  Value valB = B.is<Value>() ? B.get<Value>() : Value();
+  if (valA != Value() && valA == valB)
+    return true;
+
+  auto getConstant = [isDynamic](OpFoldResult val,
+                                 int64_t constant) -> Optional<int64_t> {
+    Optional<int64_t> maybeConstant;
+    if (isDynamic(constant)) {
+      if (val.is<Attribute>())
+        maybeConstant = val.get<Attribute>().cast<IntegerAttr>().getInt();
+      else
+        maybeConstant = getConstantIntValue(val.get<Value>());
+    } else {
+      maybeConstant = constant;
+    }
+    return maybeConstant;
+  };
+
+  Optional<int64_t> maybeConstantA = getConstant(A, constantA);
+  Optional<int64_t> maybeConstantB = getConstant(B, constantB);
+  if (!maybeConstantA || !maybeConstantB)
+    return false;
+  return *maybeConstantA == *maybeConstantB;
+}
+
+/// Check if the \p idx-th entries of (\p valuesA, \p constantsA) and
+/// (\p valuesB, \p constantsB) are equal. \see sameValueOrConstant.
+static bool sameValueOrConstantAtIdx(
+    ArrayRef<OpFoldResult> valuesA, ArrayRef<int64_t> constantsA,
+    ArrayRef<OpFoldResult> valuesB, ArrayRef<int64_t> constantsB,
+    llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
+  OpFoldResult A = valuesA[idx];
+  OpFoldResult B = valuesB[idx];
+  int64_t constantA = constantsA[idx];
+  int64_t constantB = constantsB[idx];
+  return sameValueOrConstant(A, constantA, B, constantB, isDynamic);
+}
+
 namespace {
 /// Replace reinterpret_cast(extract_strided_metadata memref) -> memref.
struct ReinterpretCastOpExtractStridedMetadataFolder @@ -1725,24 +1772,61 @@ // properties as the extract strided metadata. // First, check that the strides are the same. - if (extractStridedMetadata.getStrides().size() != op.getStrides().size()) + SmallVector extractStridesOfr = + getAsOpFoldResult(extractStridedMetadata.getStrides()); + SmallVector reinterpretStridesOfr = op.getMixedStrides(); + if (extractStridesOfr.size() != reinterpretStridesOfr.size()) return failure(); - for (auto [extractStride, reinterpretStride] : - llvm::zip(extractStridedMetadata.getStrides(), op.getStrides())) - if (extractStride != reinterpretStride) + + SmallVector reinterpretStrides; + int64_t reinterpretOffset; + bool hasKnownStridesAndOffset = succeeded(getStridesAndOffset( + op.getType(), reinterpretStrides, reinterpretOffset)); + (void)hasKnownStridesAndOffset; + assert(hasKnownStridesAndOffset && + "getStridesAndOffset must work on valid reinterpret_cast"); + SmallVector extractStrides; + int64_t extractOffset; + auto sourceType = + extractStridedMetadata.getSource().getType().cast(); + hasKnownStridesAndOffset = succeeded( + getStridesAndOffset(extractStridedMetadata.getSource().getType(), + extractStrides, extractOffset)); + (void)hasKnownStridesAndOffset; + assert(hasKnownStridesAndOffset && "getStridesAndOffset must work on valid " + "extract_strided_metadata source"); + + unsigned rank = op.getType().getRank(); + for (unsigned i = 0; i < rank; ++i) + if (!sameValueOrConstantAtIdx(extractStridesOfr, extractStrides, + reinterpretStridesOfr, reinterpretStrides, + ShapedType::isDynamicStrideOrOffset, i)) return failure(); // Second, check the sizes. 
- if (extractStridedMetadata.getSizes().size() != op.getSizes().size()) - return failure(); - for (auto [extractSize, reinterpretSize] : - llvm::zip(extractStridedMetadata.getSizes(), op.getSizes())) - if (extractSize != reinterpretSize) - return failure(); + assert(extractStridedMetadata.getSizes().size() == + op.getMixedSizes().size() && + "Strides and sizes rank must match"); + ArrayRef reinterpretSizes = op.getType().getShape(); + ArrayRef extractSizes = sourceType.getShape(); + SmallVector extractSizesOfr = + getAsOpFoldResult(extractStridedMetadata.getSizes()); + SmallVector reinterpretSizesOfr = op.getMixedSizes(); + for (unsigned i = 0; i < rank; ++i) { + if (!sameValueOrConstantAtIdx(extractSizesOfr, extractSizes, + reinterpretSizesOfr, reinterpretSizes, + ShapedType::isDynamic, i)) + return failure(); + } // Finally, check the offset. - if (op.getOffsets().size() != 1 && - extractStridedMetadata.getOffset() != *op.getOffsets().begin()) + assert(op.getMixedOffsets().size() == 1 && + "reinterpret_cast with more than one offset should have been " + "rejected by the verifier"); + if (!sameValueOrConstant( + getAsOpFoldResult(extractStridedMetadata.getOffset()), + extractOffset, op.getMixedOffsets()[0], reinterpretOffset, + ShapedType::isDynamicStrideOrOffset)) return failure(); // At this point, we know that the back and forth between extract strided diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -770,6 +770,23 @@ // ----- +// Similar to reinterpret_of_extract_strided_metadata_w_type_mistach except that +// we check that the match happens when the static information has been folded. +// E.g., in this case, we know that size of dim 0 is 8 and size of dim 1 is 2. +// So even if we don't use the values sizes#0, and sizes#1, as long as they have +// the same constant value, the match is valid. 
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
+// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK: return %[[CAST]]
+func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
+  %c8 = arith.constant 8: index
+  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+// -----
+
 // Check that a reinterpret cast of an equivalent extract strided metadata
 // is completely removed when the original memref has the same type.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type