diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -913,6 +913,8 @@ let assemblyFormat = [{ $source `:` type($source) `->` type(results) attr-dict }]; + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1266,6 +1266,58 @@ } } +/// Helper function to perform the replacement of all constant uses of `values` +/// by a materialized constant extracted from `maybeConstants`. +/// `values` and `maybeConstants` are expected to have the same size. +template +static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, + Container values, + ArrayRef maybeConstants, + llvm::function_ref isDynamic) { + assert(values.size() == maybeConstants.size() && + " expected values and maybeConstants of the same size"); + bool atLeastOneReplacement = false; + for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) { + // Don't materialize a constant if there are no uses: this would indice + // infinite loops in the driver. + if (isDynamic(maybeConstant) || result.use_empty()) + continue; + Value constantVal = + rewriter.create(loc, maybeConstant); + for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { + // updateRootInplace: lambda cannot capture structured bindings in C++17 + // yet. + op->replaceUsesOfWith(result, constantVal); + atLeastOneReplacement = true; + } + } + return atLeastOneReplacement; +} + +LogicalResult +ExtractStridedMetadataOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { + OpBuilder builder(*this); + auto memrefType = getSource().getType().cast(); + SmallVector strides; + int64_t offset; + LogicalResult res = getStridesAndOffset(memrefType, strides, offset); + (void)res; + assert(succeeded(res) && "must be a strided memref type"); + + bool atLeastOneReplacement = replaceConstantUsesOf( + builder, getLoc(), ArrayRef>(getOffset()), + ArrayRef(offset), ShapedType::isDynamicStrideOrOffset); + atLeastOneReplacement |= + replaceConstantUsesOf(builder, getLoc(), getSizes(), + memrefType.getShape(), ShapedType::isDynamic); + atLeastOneReplacement |= + replaceConstantUsesOf(builder, getLoc(), getStrides(), strides, + ShapedType::isDynamicStrideOrOffset); + + return success(atLeastOneReplacement); +} + //===----------------------------------------------------------------------===// // GenericAtomicRMWOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp @@ -553,64 +553,6 @@ } }; -/// Helper function to perform the replacement of all constant uses of `values` -/// by a materialized constant extracted from `maybeConstants`. -/// `values` and `maybeConstants` are expected to have the same size. -template -bool replaceConstantUsesOf(PatternRewriter &rewriter, Location loc, - Container values, ArrayRef maybeConstants, - llvm::function_ref isDynamic) { - assert(values.size() == maybeConstants.size() && - " expected values and maybeConstants of the same size"); - bool atLeastOneReplacement = false; - for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) { - // Don't materialize a constant if there are no uses: this would indice - // infinite loops in the driver. - if (isDynamic(maybeConstant) || result.use_empty()) - continue; - Value constantVal = - rewriter.create(loc, maybeConstant); - for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { - rewriter.startRootUpdate(op); - // updateRootInplace: lambda cannot capture structured bindings in C++17 - // yet. - op->replaceUsesOfWith(result, constantVal); - rewriter.finalizeRootUpdate(op); - atLeastOneReplacement = true; - } - } - return atLeastOneReplacement; -} - -// Forward propagate all constants information from an ExtractStridedMetadataOp. -struct ForwardStaticMetadata - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp, - PatternRewriter &rewriter) const override { - auto memrefType = metadataOp.getSource().getType().cast(); - SmallVector strides; - int64_t offset; - LogicalResult res = getStridesAndOffset(memrefType, strides, offset); - (void)res; - assert(succeeded(res) && "must be a strided memref type"); - - bool atLeastOneReplacement = replaceConstantUsesOf( - rewriter, metadataOp.getLoc(), - ArrayRef>(metadataOp.getOffset()), - ArrayRef(offset), ShapedType::isDynamicStrideOrOffset); - atLeastOneReplacement |= replaceConstantUsesOf( - rewriter, metadataOp.getLoc(), metadataOp.getSizes(), - memrefType.getShape(), ShapedType::isDynamic); - atLeastOneReplacement |= replaceConstantUsesOf( - rewriter, metadataOp.getLoc(), metadataOp.getStrides(), strides, - ShapedType::isDynamicStrideOrOffset); - - return success(atLeastOneReplacement); - } -}; - /// Replace `base, offset, sizes, strides = /// extract_strided_metadata(allocLikeOp)` /// @@ -728,7 +670,6 @@ memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>, ExtractStridedMetadataOpReshapeFolder< memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>, - ForwardStaticMetadata, ExtractStridedMetadataOpAllocFolder, ExtractStridedMetadataOpAllocFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp>( diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -804,8 +804,10 @@ // when the strides don't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [4, 2, 2], strides: [1, 1, %[[STRIDES]]#1] +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]] // CHECK: return %[[RES]] func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index @@ -818,8 +820,11 @@ // when the offset doesn't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[SIZES]]#0, %[[SIZES]]#1], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1] +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]] // CHECK: return %[[RES]] func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir --- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir @@ -173,12 +173,16 @@ // Sub offsets: [3, 4, 2] // => Final offset: 3 * 64 + 4 * 4 * %stride + 2 * 1 + 0 == 16 * %stride + 194 // +// Then, the folding pattern of extract_strided_metadata uses the statically +// known information to replace the use of some values, in that case the final +// offset: 210. +// // CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)> -// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0] -> (s0 * 16 + 194)> // CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides // CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>, // CHECK-SAME: %[[DYN_STRIDE:.*]]: index) // +// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -186,18 +190,17 @@ // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] // // CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]] -// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[DYN_STRIDE]]] // -// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]] +// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]] func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides( %base: memref<8x16x4xf32>, %stride: index) -> (memref, index, index, index, index, index) { %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] : - memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> + memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>> %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : - memref<6x3xf32, strided<[4, 1], offset: 210>> + memref<6x3xf32, strided<[?, 1], offset: 210>> -> memref, index, index, index, index, index return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :