diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -913,6 +913,23 @@
     $source `:` type($source) `->` type(results) attr-dict
   }];
 
+  let extraClassDeclaration = [{
+    /// Return a vector of all the static or dynamic sizes of the op, while
+    /// statically inferring the sizes of the dynamic sizes, when possible.
+    /// This is best effort.
+    /// E.g., if `getSizes` returns `[%dyn_size0, %dyn_size1]`, but the
+    /// source memref type is `memref<2x8xi16>`, this method will
+    /// return `[2, 8]`.
+    /// Similarly if the source memref type is `memref<2x?xi16>`, but
+    /// `%dyn_size1` can statically be pinned to a constant value, this
+    /// constant value is returned instead of `%dyn_size1`.
+    SmallVector<OpFoldResult> getConstifiedMixedSizes();
+    /// Similar to `getConstifiedMixedSizes` but for strides.
+    SmallVector<OpFoldResult> getConstifiedMixedStrides();
+    /// Similar to `getConstifiedMixedSizes` but for the offset.
+    OpFoldResult getConstifiedMixedOffset();
+  }];
+
   let hasFolder = 1;
 }
 
@@ -1301,6 +1318,20 @@
     /// Return the number of leading operands before the `offsets`, `sizes` and
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Return a vector of all the static or dynamic sizes of the op, while
+    /// statically inferring the sizes of the dynamic sizes, when possible.
+    /// This is best effort.
+    /// E.g., if `getMixedSizes` returns `[2, %dyn_size]`, but the resulting
+    /// memref type is `memref<2x8xi16>`, this method will return `[2, 8]`.
+    /// Similarly if the resulting memref type is `memref<2x?xi16>`, but
+    /// `%dyn_size` can statically be pinned to a constant value, this
+    /// constant value is returned instead of `%dyn_size`.
+    SmallVector<OpFoldResult> getConstifiedMixedSizes();
+    /// Similar to `getConstifiedMixedSizes` but for strides.
+    SmallVector<OpFoldResult> getConstifiedMixedStrides();
+    /// Similar to `getConstifiedMixedSizes` but for the offset.
+    OpFoldResult getConstifiedMixedOffset();
   }];
 
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -109,6 +109,105 @@
   return NoneType::get(type.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// Utility functions for propagating static information
+//===----------------------------------------------------------------------===//
+
+/// Helper function that infers the constant values from a list of \p values,
+/// a \p memRefTy, and another helper function \p getAttributes.
+/// The inferred constant values replace the related `OpFoldResult` in
+/// \p values.
+///
+/// \note This function shouldn't be used directly, instead, use the
+/// `getConstifiedMixedXXX` methods from the related operations.
+///
+/// \p getAttributes returns a list of potentially constant values, as
+/// determined by \p isDynamic, from the given \p memRefTy. The returned list
+/// must have as many elements as \p values or be empty.
+///
+/// E.g., consider the following example:
+/// ```
+/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
+///     memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+/// ```
+/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
+/// Now using this helper function with:
+/// - `values == [2, %dyn_stride]`,
+/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
+/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
+/// `getStridesAndOffset`), and
+/// - `isDynamic == isDynamicStrideOrOffset`
+/// Will yield: `values == [2, 1]`
+static void constifyIndexValues(
+    SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
+    MLIRContext *ctxt,
+    llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
+    llvm::function_ref<bool(int64_t)> isDynamic) {
+  SmallVector<int64_t> constValues = getAttributes(memRefTy);
+  Builder builder(ctxt);
+  for (const auto &it : llvm::enumerate(constValues)) {
+    int64_t constValue = it.value();
+    if (!isDynamic(constValue))
+      values[it.index()] = builder.getIndexAttr(constValue);
+  }
+  for (OpFoldResult &ofr : values) {
+    if (ofr.is<Attribute>()) {
+      // FIXME: We shouldn't need to do that, but right now, the static indices
+      // are created with the wrong type: `i64` instead of `index`.
+      // As a result, if we were to keep the attribute as is, we may fail to see
+      // that two attributes are equal because one would have the i64 type and
+      // the other the index type.
+      // The alternative would be to create constant indices with getI64Attr in
+      // this and the previous loop, but it doesn't logically make sense (we are
+      // dealing with indices here) and would only strengthen the inconsistency
+      // around how static indices are created (some places use getI64Attr,
+      // others use getIndexAttr).
+      // The workaround here is to stick to the IndexAttr type for all the
+      // values, hence we recreate the attribute even when it is already static
+      // to make sure the type is consistent.
+      ofr = builder.getIndexAttr(
+          ofr.get<Attribute>().cast<IntegerAttr>().getInt());
+      continue;
+    }
+    Optional<int64_t> maybeConstant = getConstantIntValue(ofr.get<Value>());
+    if (maybeConstant)
+      ofr = builder.getIndexAttr(*maybeConstant);
+  }
+}
+
+/// Wrapper around `getShape` that conforms to the function signature
+/// expected for `getAttributes` in `constifyIndexValues`.
+static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
+  ArrayRef<int64_t> sizes = memRefTy.getShape();
+  return SmallVector<int64_t>(sizes.begin(), sizes.end());
+}
+
+/// Wrapper around `getStridesAndOffset` that returns only the offset and
+/// conforms to the function signature expected for `getAttributes` in
+/// `constifyIndexValues`.
+static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  LogicalResult hasStaticInformation =
+      getStridesAndOffset(memrefType, strides, offset);
+  if (failed(hasStaticInformation))
+    return SmallVector<int64_t>();
+  return SmallVector<int64_t>(1, offset);
+}
+
+/// Wrapper around `getStridesAndOffset` that returns only the strides and
+/// conforms to the function signature expected for `getAttributes` in
+/// `constifyIndexValues`.
+static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  LogicalResult hasStaticInformation =
+      getStridesAndOffset(memrefType, strides, offset);
+  if (failed(hasStaticInformation))
+    return SmallVector<int64_t>();
+  return strides;
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -1293,18 +1392,22 @@
 template <typename Container>
 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                   Container values,
-                                  ArrayRef<int64_t> maybeConstants,
-                                  llvm::function_ref<bool(int64_t)> isDynamic) {
+                                  ArrayRef<OpFoldResult> maybeConstants) {
   assert(values.size() == maybeConstants.size() &&
          " expected values and maybeConstants of the same size");
   bool atLeastOneReplacement = false;
   for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
     // Don't materialize a constant if there are no uses: this would indice
     // infinite loops in the driver.
-    if (isDynamic(maybeConstant) || result.use_empty())
+    if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
       continue;
-    Value constantVal =
-        rewriter.create<arith::ConstantIndexOp>(loc, maybeConstant);
+    assert(maybeConstant.template is<Attribute>() &&
+           "The constified value should be either unchanged (i.e., == result) "
+           "or a constant");
+    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
+        loc, maybeConstant.template get<Attribute>()
+                 .template cast<IntegerAttr>()
+                 .getInt());
     for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
       // updateRootInplace: lambda cannot capture structured bindings in C++17
       // yet.
@@ -1319,26 +1422,41 @@
 ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
                                SmallVectorImpl<OpFoldResult> &results) {
   OpBuilder builder(*this);
-  auto memrefType = getSource().getType().cast<MemRefType>();
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
-  (void)res;
-  assert(succeeded(res) && "must be a strided memref type");
 
   bool atLeastOneReplacement = replaceConstantUsesOf(
       builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
-      ArrayRef<int64_t>(offset), ShapedType::isDynamicStrideOrOffset);
-  atLeastOneReplacement |=
-      replaceConstantUsesOf(builder, getLoc(), getSizes(),
-                            memrefType.getShape(), ShapedType::isDynamic);
-  atLeastOneReplacement |=
-      replaceConstantUsesOf(builder, getLoc(), getStrides(), strides,
-                            ShapedType::isDynamicStrideOrOffset);
+      getConstifiedMixedOffset());
+  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
+                                                 getConstifiedMixedSizes());
+  atLeastOneReplacement |= replaceConstantUsesOf(
+      builder, getLoc(), getStrides(), getConstifiedMixedStrides());
 
   return success(atLeastOneReplacement);
 }
 
+SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
+  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
+  constifyIndexValues(values, getSource().getType(), getContext(),
+                      getConstantSizes, ShapedType::isDynamic);
+  return values;
+}
+
+SmallVector<OpFoldResult>
+ExtractStridedMetadataOp::getConstifiedMixedStrides() {
+  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
+  constifyIndexValues(values, getSource().getType(), getContext(),
+                      getConstantStrides, ShapedType::isDynamicStrideOrOffset);
+  return values;
+}
+
+OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
+  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
+  SmallVector<OpFoldResult> values(1, offsetOfr);
+  constifyIndexValues(values, getSource().getType(), getContext(),
+                      getConstantOffset, ShapedType::isDynamicStrideOrOffset);
+  return values[0];
+}
+
 //===----------------------------------------------------------------------===//
 // GenericAtomicRMWOp
 //===----------------------------------------------------------------------===//
@@ -1781,8 +1899,67 @@
   return nullptr;
 }
 
+SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
+  SmallVector<OpFoldResult> values = getMixedSizes();
+  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
+                      ShapedType::isDynamic);
+  return values;
+}
+
+SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
+  SmallVector<OpFoldResult> values = getMixedStrides();
+  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
+                      ShapedType::isDynamicStrideOrOffset);
+  return values;
+}
+
+OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
+  SmallVector<OpFoldResult> values = getMixedOffsets();
+  assert(values.size() == 1 &&
+         "reinterpret_cast must have one and only one offset");
+  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
+                      ShapedType::isDynamicStrideOrOffset);
+  return values[0];
+}
+
 namespace {
-/// Replace reinterpret_cast(extract_strided_metadata memref) -> memref.
+/// Replace the sequence:
+/// ```
+/// base, offset, sizes, strides = extract_strided_metadata src
+/// dst = reinterpret_cast base to offset, sizes, strides
+/// ```
+/// With
+///
+/// ```
+/// dst = memref.cast src
+/// ```
+///
+/// Note: The cast operation is only inserted when the type of dst and src
+/// are not the same. E.g., when going from <4xf32> to <?xf32>.
+///
+/// This pattern also matches when the offset, sizes, and strides don't come
+/// directly from the `extract_strided_metadata`'s results but it can be
+/// statically proven that they would hold the same values.
+///
+/// For instance, the following sequence would be replaced:
+/// ```
+/// base, offset, sizes, strides =
+///   extract_strided_metadata memref : memref<3x4xty>
+/// dst = reinterpret_cast base to 0, [3, 4], strides
+/// ```
+/// Because we know (thanks to the type of the input memref) that variable
+/// `offset` and `sizes` will respectively hold 0 and [3, 4].
+///
+/// Similarly, the following sequence would be replaced:
+/// ```
+/// c0 = arith.constant 0
+/// c4 = arith.constant 4
+/// base, offset, sizes, strides =
+///   extract_strided_metadata memref : memref<3x4xty>
+/// dst = reinterpret_cast base to c0, [3, c4], strides
+/// ```
+/// Because we know that `offset` and `c0` will hold 0
+/// and `c4` will hold 4.
 struct ReinterpretCastOpExtractStridedMetadataFolder
     : public OpRewritePattern<ReinterpretCastOp> {
 public:
@@ -1798,24 +1975,39 @@
     // properties as the extract strided metadata.
 
     // First, check that the strides are the same.
-    if (extractStridedMetadata.getStrides().size() != op.getStrides().size())
+    SmallVector<OpFoldResult> extractStridesOfr =
+        extractStridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> reinterpretStridesOfr =
+        op.getConstifiedMixedStrides();
+    if (extractStridesOfr.size() != reinterpretStridesOfr.size())
       return failure();
-    for (auto [extractStride, reinterpretStride] :
-         llvm::zip(extractStridedMetadata.getStrides(), op.getStrides()))
-      if (extractStride != reinterpretStride)
+
+    unsigned rank = op.getType().getRank();
+    for (unsigned i = 0; i < rank; ++i) {
+      if (extractStridesOfr[i] != reinterpretStridesOfr[i])
         return failure();
+    }
 
     // Second, check the sizes.
-    if (extractStridedMetadata.getSizes().size() != op.getSizes().size())
-      return failure();
-    for (auto [extractSize, reinterpretSize] :
-         llvm::zip(extractStridedMetadata.getSizes(), op.getSizes()))
-      if (extractSize != reinterpretSize)
+    assert(extractStridedMetadata.getSizes().size() ==
+               op.getMixedSizes().size() &&
+           "Strides and sizes rank must match");
+    SmallVector<OpFoldResult> extractSizesOfr =
+        extractStridedMetadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> reinterpretSizesOfr =
+        op.getConstifiedMixedSizes();
+    for (unsigned i = 0; i < rank; ++i) {
+      if (extractSizesOfr[i] != reinterpretSizesOfr[i])
         return failure();
-
+    }
     // Finally, check the offset.
-    if (op.getOffsets().size() != 1 &&
-        extractStridedMetadata.getOffset() != *op.getOffsets().begin())
+    assert(op.getMixedOffsets().size() == 1 &&
+           "reinterpret_cast with more than one offset should have been "
+           "rejected by the verifier");
+    OpFoldResult extractOffsetOfr =
+        extractStridedMetadata.getConstifiedMixedOffset();
+    OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
+    if (extractOffsetOfr != reinterpretOffsetOfr)
       return failure();
 
     // At this point, we know that the back and forth between extract strided
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -758,13 +758,9 @@
 // Check that a reinterpret cast of an equivalent extract strided metadata
 // is canonicalized to a plain cast when the destination type is different
 // than the type of the original memref.
-// This pattern is currently defeated by the constant folding that happens
-// with extract_strided_metadata.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mistach
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C0:.*]] = arith.constant 0
-//   CHECK-DAG: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
-//       CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
 //       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -774,6 +770,23 @@
 
 // -----
 
+// Similar to reinterpret_of_extract_strided_metadata_w_type_mistach except that
+// we check that the match happens when the static information has been folded.
+// E.g., in this case, we know that size of dim 0 is 8 and size of dim 1 is 2.
+// So even if we don't use the values sizes#0, sizes#1, as long as they have the
+// same constant value, the match is valid.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+//       CHECK: return %[[CAST]]
+func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
+  %c8 = arith.constant 8: index
+  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+// -----
+
 // Check that a reinterpret cast of an equivalent extract strided metadata
 // is completely removed when the original memref has the same type.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type