diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1090,6 +1090,8 @@
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1350,6 +1350,47 @@
   return success();
 }
 
+/// Fold `reinterpret_cast(producer(x))` into `reinterpret_cast(x)` by
+/// updating this op's source operand in place, when `producer` is one of:
+///   - memref.reinterpret_cast,
+///   - memref.cast,
+///   - memref.subview whose offsets are all the constant 0.
+/// The result's offset, sizes and strides are given by this op itself, so the
+/// layout produced by such an intermediate op is not needed.
+OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
+  Value src = source();
+  auto getPrevSrc = [&]() -> Value {
+    // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
+    if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
+      return prev.source();
+
+    // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
+    if (auto prev = src.getDefiningOp<CastOp>())
+      return prev.source();
+
+    // reinterpret_cast(subview(x)) -> reinterpret_cast(x), but only when
+    // every subview offset is the constant 0.
+    if (auto prev = src.getDefiningOp<SubViewOp>()) {
+      auto isZero = [](OpFoldResult val) {
+        auto cst = getConstantIntValue(val);
+        return cst && *cst == 0;
+      };
+      if (llvm::all_of(prev.getMixedOffsets(), isZero))
+        return prev.source();
+    }
+
+    return nullptr;
+  };
+
+  if (Value prevSrc = getPrevSrc()) {
+    sourceMutable().assign(prevSrc);
+    // An in-place fold is reported by returning the op's own result.
+    return getResult();
+  }
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -552,3 +552,53 @@
 
 // CHECK-LABEL: func @self_copy
 // CHECK-NEXT: return
+
+// -----
+
+// CHECK-LABEL: func @reinterpret_of_reinterpret
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
+// CHECK: return %[[RES]]
+func @reinterpret_of_reinterpret(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
+  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size1], strides: [1] : memref<?xi8> to memref<?xi8>
+  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
+  return %1 : memref<?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @reinterpret_of_cast
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE:.*]]: index)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE]]], strides: [1]
+// CHECK: return %[[RES]]
+func @reinterpret_of_cast(%arg : memref<?xi8>, %size: index) -> memref<?xi8> {
+  %0 = memref.cast %arg : memref<?xi8> to memref<5xi8>
+  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size], strides: [1] : memref<5xi8> to memref<?xi8>
+  return %1 : memref<?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @reinterpret_of_subview
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
+// CHECK: return %[[RES]]
+func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
+  %0 = memref.subview %arg[0] [%size1] [1] : memref<?xi8> to memref<?xi8>
+  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
+  return %1 : memref<?xi8>
+}
+
+// -----
+
+// A subview with a non-zero offset must not be bypassed by the folder.
+// CHECK-LABEL: func @reinterpret_of_subview_nonzero_offset
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
+// CHECK: %[[SUB:.*]] = memref.subview %[[ARG]][1] [%[[SIZE1]]] [1]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[SUB]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
+// CHECK: return %[[RES]]
+func @reinterpret_of_subview_nonzero_offset(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
+  %0 = memref.subview %arg[1] [%size1] [1] : memref<?xi8> to memref<?xi8, affine_map<(d0) -> (d0 + 1)>>
+  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8, affine_map<(d0) -> (d0 + 1)>> to memref<?xi8>
+  return %1 : memref<?xi8>
+}