diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
@@ -428,13 +429,74 @@
     return success();
   }
 };
+
+/// Helper function to perform the replacement of all constant uses of `values`
+/// by a materialized constant extracted from `maybeConstants`.
+/// `values` and `maybeConstants` are expected to have the same size.
+/// Entries for which `isDynamic` returns true, and values without uses, are
+/// skipped. Returns true if at least one use was rewritten.
+template <typename Container>
+bool replaceConstantUsesOf(PatternRewriter &rewriter, Location loc,
+                           Container values, ArrayRef<int64_t> maybeConstants,
+                           llvm::function_ref<bool(int64_t)> isDynamic) {
+  assert(values.size() == maybeConstants.size() &&
+         " expected values and maybeConstants of the same size");
+  bool atLeastOneReplacement = false;
+  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
+    // Don't materialize a constant if there are no uses: this would induce
+    // infinite loops in the driver.
+    if (isDynamic(maybeConstant) || result.use_empty())
+      continue;
+    Value constantVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, maybeConstant);
+    // Iterate over uses rather than users: rewriting every operand of a user
+    // at once (replaceUsesOfWith) can relink the use the early-increment
+    // iterator already advanced to and skip later uses of `result`.
+    for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
+      rewriter.updateRootInPlace(use.getOwner(),
+                                 [&]() { use.set(constantVal); });
+      atLeastOneReplacement = true;
+    }
+  }
+  return atLeastOneReplacement;
+}
+
+// Forward propagate all constant information from an ExtractStridedMetadataOp.
+struct ForwardStaticMetadata + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp, + PatternRewriter &rewriter) const override { + auto memrefType = metadataOp.getSource().getType().cast(); + SmallVector strides; + int64_t offset; + LogicalResult res = getStridesAndOffset(memrefType, strides, offset); + (void)res; // Used only by the assert below. + assert(succeeded(res) && "must be a strided memref type"); + + bool atLeastOneReplacement = replaceConstantUsesOf( // Offset. + rewriter, metadataOp.getLoc(), + ArrayRef>(metadataOp.getOffset()), + ArrayRef(offset), ShapedType::isDynamicStrideOrOffset); + atLeastOneReplacement |= replaceConstantUsesOf( // Sizes. + rewriter, metadataOp.getLoc(), metadataOp.getSizes(), + memrefType.getShape(), ShapedType::isDynamic); + atLeastOneReplacement |= replaceConstantUsesOf( // Strides. + rewriter, metadataOp.getLoc(), metadataOp.getStrides(), strides, + ShapedType::isDynamicStrideOrOffset); + + return success(atLeastOneReplacement); // Fail if nothing was rewritten. + } +}; } // namespace void memref::populateSimplifyExtractStridedMetadataOpPatterns( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns + .add( + patterns.getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir --- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir @@ -1,5 +1,24 @@ // RUN: mlir-opt --simplify-extract-strided-metadata -split-input-file %s -o - | FileCheck %s +// CHECK-LABEL: func @extract_strided_metadata_constants +// CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32, strided<[4, 1], offset: 2>>) +func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4, 1], offset: 2>>) + -> (memref, index, index, index, index, index) { + // CHECK-DAG: 
%[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index + + // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] + %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %base : + memref<5x4xf32, strided<[4,1], offset:2>> + -> memref, index, index, index, index, index + + // CHECK: %[[BASE]], %[[C2]], %[[C5]], %[[C4]], %[[C4]], %[[C1]] + return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : + memref, index, index, index, index, index +} + // ----- // Check that we simplify extract_strided_metadata of subview to