diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -912,6 +912,8 @@ let assemblyFormat = [{ $source `:` type($source) `->` type(results) attr-dict }]; + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1286,6 +1286,58 @@ } } +/// Replace all uses of each value in `values` whose matching entry in +/// `maybeConstants` is static (per `isDynamic`) by a materialized constant. +/// Both ranges must have the same size. Returns true if any use was replaced. +template +static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, + Container values, + ArrayRef maybeConstants, + llvm::function_ref isDynamic) { + assert(values.size() == maybeConstants.size() && + " expected values and maybeConstants of the same size"); + bool atLeastOneReplacement = false; + for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) { + // Don't materialize a constant if there are no uses: this would induce + // infinite loops in the driver. + if (isDynamic(maybeConstant) || result.use_empty()) + continue; + Value constantVal = + rewriter.create(loc, maybeConstant); + for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { + // Folders have no PatternRewriter: the user op is updated directly, + // without the start/finalizeRootUpdate notifications a pattern uses. 
+ op->replaceUsesOfWith(result, constantVal); + atLeastOneReplacement = true; + } + } + return atLeastOneReplacement; +} + +LogicalResult +ExtractStridedMetadataOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { + OpBuilder builder(*this); + auto memrefType = getSource().getType().cast(); + SmallVector strides; + int64_t offset; + LogicalResult res = getStridesAndOffset(memrefType, strides, offset); + (void)res; + assert(succeeded(res) && "must be a strided memref type"); + + bool atLeastOneReplacement = replaceConstantUsesOf( + builder, getLoc(), ArrayRef>(getOffset()), + ArrayRef(offset), ShapedType::isDynamicStrideOrOffset); + atLeastOneReplacement |= + replaceConstantUsesOf(builder, getLoc(), getSizes(), + memrefType.getShape(), ShapedType::isDynamic); + atLeastOneReplacement |= + replaceConstantUsesOf(builder, getLoc(), getStrides(), strides, + ShapedType::isDynamicStrideOrOffset); + + return success(atLeastOneReplacement); +} + //===----------------------------------------------------------------------===// // GenericAtomicRMWOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp @@ -550,64 +550,6 @@ } }; -/// Helper function to perform the replacement of all constant uses of `values` -/// by a materialized constant extracted from `maybeConstants`. -/// `values` and `maybeConstants` are expected to have the same size. 
-template -bool replaceConstantUsesOf(PatternRewriter &rewriter, Location loc, - Container values, ArrayRef maybeConstants, - llvm::function_ref isDynamic) { - assert(values.size() == maybeConstants.size() && - " expected values and maybeConstants of the same size"); - bool atLeastOneReplacement = false; - for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) { - // Don't materialize a constant if there are no uses: this would indice - // infinite loops in the driver. - if (isDynamic(maybeConstant) || result.use_empty()) - continue; - Value constantVal = - rewriter.create(loc, maybeConstant); - for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { - rewriter.startRootUpdate(op); - // updateRootInplace: lambda cannot capture structured bindings in C++17 - // yet. - op->replaceUsesOfWith(result, constantVal); - rewriter.finalizeRootUpdate(op); - atLeastOneReplacement = true; - } - } - return atLeastOneReplacement; -} - -// Forward propagate all constants information from an ExtractStridedMetadataOp. 
-struct ForwardStaticMetadata - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp, - PatternRewriter &rewriter) const override { - auto memrefType = metadataOp.getSource().getType().cast(); - SmallVector strides; - int64_t offset; - LogicalResult res = getStridesAndOffset(memrefType, strides, offset); - (void)res; - assert(succeeded(res) && "must be a strided memref type"); - - bool atLeastOneReplacement = replaceConstantUsesOf( - rewriter, metadataOp.getLoc(), - ArrayRef>(metadataOp.getOffset()), - ArrayRef(offset), ShapedType::isDynamicStrideOrOffset); - atLeastOneReplacement |= replaceConstantUsesOf( - rewriter, metadataOp.getLoc(), metadataOp.getSizes(), - memrefType.getShape(), ShapedType::isDynamic); - atLeastOneReplacement |= replaceConstantUsesOf( - rewriter, metadataOp.getLoc(), metadataOp.getStrides(), strides, - ShapedType::isDynamicStrideOrOffset); - - return success(atLeastOneReplacement); - } -}; - /// Replace `base, offset, sizes, strides = /// extract_strided_metadata(allocLikeOp)` /// @@ -753,7 +695,6 @@ memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>, ExtractStridedMetadataOpReshapeFolder< memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>, - ForwardStaticMetadata, ExtractStridedMetadataOpAllocFolder, ExtractStridedMetadataOpAllocFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -758,9 +758,13 @@ // Check that a reinterpret cast of an equivalent extract strided metadata // is canonicalized to a plain cast when the destination type is different // than the type of the original memref. +// This pattern is currently defeated by the extract_strided_metadata folder, +// which replaces uses of its static offset/sizes/strides with constants.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mistach // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index @@ -773,12 +777,12 @@ // Check that a reinterpret cast of an equivalent extract strided metadata // is completely removed when the original memref has the same type. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type -// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) +// CHECK-SAME: (%[[ARG:.*]]: memref) -> memref<8x2xf32> { - %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index - %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref<8x2xf32> - return %m2 : memref<8x2xf32> +func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref>) -> memref> { + %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref> -> memref, index, index, index, index, index + %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref> + return %m2 : memref> } // ----- @@ -787,8 +791,10 @@ // when the strides don't match. 
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [4, 2, 2], strides: [1, 1, %[[STRIDES]]#1] +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]] // CHECK: return %[[RES]] func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index @@ -801,8 +807,11 @@ // when the offset doesn't match. 
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[SIZES]]#0, %[[SIZES]]#1], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1] +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]] // CHECK: return %[[RES]] func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir --- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir @@ -193,10 +193,10 @@ -> (memref, index, index, index, index, index) { %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] : - memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> + memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>> %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : - memref<6x3xf32, strided<[4, 1], offset: 210>> + memref<6x3xf32, strided<[?, 1], offset: 210>> -> memref, index, index, index, index, index return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :