diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3838,11 +3838,34 @@
     return success();
   }
 };
+
+/// Canonicalize tensor_load + tensor_to_memref to memref_cast when type
+/// mismatches prevent `TensorToMemrefOp::fold` from kicking in.
+struct TensorLoadToMemref : public OpRewritePattern<TensorToMemrefOp> {
+  using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
+                                PatternRewriter &rewriter) const final {
+    auto tensorLoad = tensorToMemRef.tensor().getDefiningOp<TensorLoadOp>();
+    // Bail unless we have a tensor_load + tensor_to_memref with different
+    // types. `TensorToMemrefOp::fold` handles the same-type case.
+    if (!tensorLoad ||
+        tensorLoad.memref().getType() == tensorToMemRef.getType())
+      return failure();
+    // If the types are not cast-compatible, bail.
+    if (!MemRefCastOp::areCastCompatible(tensorLoad.memref().getType(),
+                                         tensorToMemRef.getType()))
+      return failure();
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(
+        tensorToMemRef, tensorToMemRef.getType(), tensorLoad.memref());
+    return success();
+  }
+};
 } // namespace

 void TensorToMemrefOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<TensorCastToMemref>(context);
+  results.insert<TensorCastToMemref, TensorLoadToMemref>(context);
 }

 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize --split-input-file | FileCheck %s
+
+// -----

 // Test case: Basic folding of tensor_load(tensor_to_memref(t)) -> t
 // CHECK-LABEL: func @tensor_load_of_tensor_to_memref(
@@ -10,6 +12,8 @@
   return %1 : tensor<?xf32>
 }

+// -----
+
 // Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m
 // CHECK-LABEL: func @tensor_to_memref_of_tensor_load(
 // CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
@@ -20,7 +24,11 @@
   return %1 : memref<?xf32>
 }

+// -----
+
 // Test case: If the memrefs are not the same type, don't fold them.
+// Test case: If the memrefs are not cast-compatible (e.g. different address space),
+// don't canonicalize them either.
 // CHECK-LABEL: func @no_fold_tensor_to_memref_of_tensor_load(
 // CHECK-SAME: %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>) -> memref<?xf32, 7> {
 // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2>
@@ -32,6 +40,28 @@
   return %1 : memref<?xf32, 7>
 }

+// -----
+
+// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
+// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+
+// Test case: If the memrefs are cast-compatible, canonicalize.
+// CHECK-LABEL: func @canonicalize_tensor_to_memref_of_tensor_load(
+// CHECK-SAME: %[[M:.*]]: memref<?xf32, #[[$OFF_3]]>) -> memref<?xf32, #[[$OFF_UNK]]> {
+// CHECK-NOT: tensor_load
+// CHECK-NOT: tensor_to_memref
+// CHECK: %[[R:.*]] = memref_cast %[[M]] : memref<?xf32, #[[$OFF_3]]> to memref<?xf32, #[[$OFF_UNK]]>
+// CHECK: return %[[R]]
+func @canonicalize_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, offset: 3, strides: [1]>)
+    -> memref<?xf32, offset: ?, strides: [1]>
+{
+  %0 = tensor_load %arg0 : memref<?xf32, offset: 3, strides: [1]>
+  %1 = tensor_to_memref %0 : memref<?xf32, offset: ?, strides: [1]>
+  return %1 : memref<?xf32, offset: ?, strides: [1]>
+}
+
+// -----
+
 // Test case: Basic folding of dim(tensor_load(m)) -> dim(m).
 // CHECK-LABEL: func @dim_of_tensor_load(
 // CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
@@ -45,6 +75,8 @@
   return %1 : index
 }

+// -----
+
 // Test case: Folding of load(tensor_to_memref(%v, %idxs))
 //            -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_tensor_to_memref(
@@ -59,6 +91,8 @@
   return %1 : f32
 }

+// -----
+
 // Test case: Folding of dim(tensor.generate %idx) -> %idx
 // CHECK-LABEL: func @dim_of_tensor.generate(
 // CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
@@ -74,6 +108,8 @@
   return %1 : index
 }

+// -----
+
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 // CHECK-DAG: %[[T:.*]] = constant true
@@ -96,6 +132,8 @@
       : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
 }

+// -----
+
 // Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
 // CHECK-LABEL: func @dim_of_memref_reshape(
 // CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
@@ -116,6 +154,8 @@
   return %1 : index
 }

+// -----
+
 // Test case: Folding dim(tensor.cast %0, %idx) -> dim %0, %idx
 // CHECK-LABEL: func @fold_dim_of_tensor.cast
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
@@ -132,6 +172,8 @@
   return %1, %2: index, index
 }

+// -----
+
 // CHECK-LABEL: func @tensor_cast_to_memref
 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
 // CHECK: %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
@@ -144,6 +186,8 @@
   return %1 : memref<?x?x16x32xi8>
 }

+// -----
+
 // CHECK-LABEL: func @subview_of_memcast
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 // CHECK: %[[S:.+]] = subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
@@ -158,6 +202,8 @@
   return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
 }

+// -----
+
 // CHECK-LABEL: func @trivial_subtensor
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 // CHECK-NOT: subtensor
@@ -167,6 +213,8 @@
   return %0 : tensor<4x6x16x32xi8>
 }

+// -----
+
 // CHECK-LABEL: func @trivial_subtensor_insert
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 // CHECK-NOT: subtensor
@@ -176,6 +224,8 @@
   return %0 : tensor<4x6x16x32xi8>
 }

+// -----
+
 // CHECK-LABEL: func @rank_reducing_tensor_of_cast
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 // CHECK: %[[S:.+]] = subtensor %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
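
Illustration (not part of the patch): a minimal sketch of the rewrite the new TensorLoadToMemref pattern performs, assuming the same offset-only layout mismatch that the new test exercises; the function and value names below are hypothetical.

  // Before canonicalization: the value round-trips through a tensor, and the
  // producer/consumer memref types differ only in their layout offset
  // (static 3 vs. unknown), so they are memref_cast-compatible.
  func @example(%m: memref<?xf32, offset: 3, strides: [1]>)
      -> memref<?xf32, offset: ?, strides: [1]> {
    %t = tensor_load %m : memref<?xf32, offset: 3, strides: [1]>
    %r = tensor_to_memref %t : memref<?xf32, offset: ?, strides: [1]>
    return %r : memref<?xf32, offset: ?, strides: [1]>
  }

  // After -canonicalize, the tensor_to_memref is replaced by a memref_cast of
  // the original memref, and the now-dead tensor_load is removed:
  //   %r = memref_cast %m : memref<?xf32, offset: 3, strides: [1]>
  //                         to memref<?xf32, offset: ?, strides: [1]>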