diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
@@ -428,12 +429,78 @@
     return success();
   }
 };
+
+/// Replace the uses of each value in `results` by a constant holding the
+/// corresponding entry of `maybeConstants`. Entries for which `isDynamic`
+/// returns true, and results without uses, are skipped. Returns true if at
+/// least one use was replaced.
+template
+bool replaceConstantUsesOf(PatternRewriter &rewriter,
+                           memref::ExtractStridedMetadataOp metadataOp,
+                           ArrayRef maybeConstants,
+                           ResultsContainer results,
+                           llvm::function_ref isDynamic) {
+  bool atLeastOneReplacement = false;
+  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, results)) {
+    // Don't materialize a constant if there are no uses: this would induce
+    // infinite loops in the driver.
+    if (isDynamic(maybeConstant) || result.use_empty())
+      continue;
+    Value constantVal = rewriter.create(
+        metadataOp.getLoc(), maybeConstant);
+    // Rewrite one operand at a time: `Operation::replaceUsesOfWith` would also
+    // move other uses of `result` in the same op, possibly including the one
+    // the early-increment iterator already advanced to, making the loop walk
+    // `constantVal`'s use list and skip the remaining uses of `result`.
+    for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
+      // `use` is an ordinary reference (not a structured binding), so it can
+      // be captured by the lambda in C++17.
+      rewriter.updateRootInPlace(use.getOwner(),
+                                 [&]() { use.set(constantVal); });
+      atLeastOneReplacement = true;
+    }
+  }
+  return atLeastOneReplacement;
+}
+
+// Forward the static offset, sizes and strides of the source memref type to
+// the users of the matching ExtractStridedMetadataOp results, as constants.
+// Fails to match when no use was replaced.
+struct ForwardStaticMetadata
+    : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp,
+                                PatternRewriter &rewriter) const override {
+    auto memrefType = metadataOp.getSource().getType().cast();
+    SmallVector strides;
+    int64_t offset = 0;
+    LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
+    assert(succeeded(res) && "must be a strided memref type");
+    // Silence unused-variable warnings when assertions are disabled.
+    (void)res;
+
+    bool atLeastOneReplacement = replaceConstantUsesOf(
+        rewriter, metadataOp, ArrayRef(offset),
+        ArrayRef>(metadataOp.getOffset()),
+        ShapedType::isDynamicStrideOrOffset);
+    atLeastOneReplacement |=
+        replaceConstantUsesOf(rewriter, metadataOp, memrefType.getShape(),
+                              metadataOp.getSizes(), ShapedType::isDynamic);
+    atLeastOneReplacement |= replaceConstantUsesOf(
+        rewriter, metadataOp, strides, metadataOp.getStrides(),
+        ShapedType::isDynamicStrideOrOffset);
+
+    return success(atLeastOneReplacement);
+  }
+};
 } // namespace
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add(
-      patterns.getContext());
+  patterns
+      .add(
+          patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -1,5 +1,24 @@
 // RUN: mlir-opt --simplify-extract-strided-metadata -split-input-file %s -o - | FileCheck %s
 
+// CHECK-LABEL: func @extract_strided_metadata_constants
+// CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32, strided<[4, 1], offset: 2>>)
+func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4, 1], offset: 2>>)
+  -> (memref, index, index, index, index, index) {
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+
+  // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %base :
+    memref<5x4xf32, strided<[4,1], offset:2>>
+    -> memref, index, index, index, index, index
+
+  // CHECK: %[[BASE]], %[[C2]], %[[C5]], %[[C4]], %[[C4]], %[[C1]]
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref, index, index, index, index, index
+}
+
 // -----
 
 // Check that we simplify extract_strided_metadata of subview to