diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2234,6 +2234,7 @@
     operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 
   let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2293,6 +2293,36 @@
   return OpFoldResult();
 }
 
+namespace {
+/// Fold a load on a tensor_to_memref operation into an extract_element on the
+/// corresponding tensor.
+///
+/// NOTE(review): the load is rewritten to read the source tensor rather than
+/// the memref, which is only sound if that memref is not written to between
+/// the tensor_to_memref and the load; confirm against the op's contract.
+struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
+  using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LoadOp load,
+                                PatternRewriter &rewriter) const override {
+    // Only match loads whose memref is produced directly by tensor_to_memref.
+    auto tensorToMemref = load.memref().getDefiningOp<TensorToMemrefOp>();
+    if (!tensorToMemref)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ExtractElementOp>(load, tensorToMemref.tensor(),
+                                                  load.indices());
+    return success();
+  }
+};
+} // end anonymous namespace.
+
+/// Registers the LoadOp canonicalization patterns.
+void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<LoadOfTensorToMemref>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -45,6 +45,20 @@
   return %1 : index
 }
 
+// Test case: Folding of load(tensor_to_memref(%v), %idxs)
+//            -> extract_element(%v, %idxs)
+// CHECK-LABEL: func @load_from_tensor_to_memref(
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK: %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
+// CHECK-NOT: load
+// CHECK: return %[[RES]] : f32
+func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
+  %0 = tensor_to_memref %arg2 : memref<?x?xf32>
+  %1 = load %0[%arg0, %arg1] : memref<?x?xf32>
+  return %1 : f32
+}
+
 // Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
 // CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
 // CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index