diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3078,6 +3078,7 @@
   let assemblyFormat = "$tensor attr-dict `:` type($memref)";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3558,6 +3558,60 @@
   return {};
 }
 
+namespace {
+/// Replace tensor.cast + tensor_to_memref by tensor_to_memref + memref_cast:
+///
+///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
+///   %1 = tensor_to_memref %0 : memref<?xf32>
+///
+/// becomes
+///
+///   %0 = tensor_to_memref %t : memref<4xf32>
+///   %1 = memref_cast %0 : memref<4xf32> to memref<?xf32>
+///
+/// Only applies when the cast source is a ranked tensor. The new memref uses
+/// the source's shape and element type, the default (identity) layout, and the
+/// memory space of the original result.
+struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
+  using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCastOperand =
+        tensorToMemRef.getOperand().getDefiningOp<tensor::CastOp>();
+    if (!tensorCastOperand)
+      return failure();
+    auto srcTensorType =
+        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
+    if (!srcTensorType)
+      return failure();
+
+    // memref_cast requires source and result to share a memory space, so the
+    // intermediate memref must live in the same space as the original result
+    // (which may be ranked or unranked).
+    Type resultType = tensorToMemRef.getType();
+    auto memorySpace =
+        resultType.isa<MemRefType>()
+            ? resultType.cast<MemRefType>().getMemorySpace()
+            : resultType.cast<UnrankedMemRefType>().getMemorySpace();
+    auto memrefType =
+        MemRefType::get(srcTensorType.getShape(),
+                        srcTensorType.getElementType(), {}, memorySpace);
+    Value memref = rewriter.create<TensorToMemrefOp>(
+        tensorToMemRef.getLoc(), memrefType, tensorCastOperand.getOperand());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(tensorToMemRef,
+                                              tensorToMemRef.getType(), memref);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the tensor_to_memref canonicalization patterns.
+void TensorToMemrefOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<TensorCastToMemref>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -131,3 +131,29 @@
   %2 = dim %0, %c1 : tensor
   return %1, %2: index, index
 }
+
+// CHECK-LABEL: func @tensor_cast_to_memref
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+//       CHECK:   %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
+//       CHECK:   %[[M1:.+]] = memref_cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+//       CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
+func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
+  return %1 : memref<?x?x16x32xi8>
+}
+
+// The rewritten tensor_to_memref must keep the memory space of the original
+// result, otherwise the inserted memref_cast would be invalid.
+// CHECK-LABEL: func @cast_to_memref_keeps_memory_space
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+//       CHECK:   %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8, 3>
+//       CHECK:   %[[M1:.+]] = memref_cast %[[M]] : memref<4x6x16x32xi8, 3> to memref<?x?x16x32xi8, 3>
+//       CHECK:   return %[[M1]] : memref<?x?x16x32xi8, 3>
+func @cast_to_memref_keeps_memory_space(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8, 3> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = tensor_to_memref %0 : memref<?x?x16x32xi8, 3>
+  return %1 : memref<?x?x16x32xi8, 3>
+}