diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1090,6 +1090,8 @@ /// and `strides` operands. static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; } }]; + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -53,6 +53,9 @@ /// If ofr is a constant integer or an IntegerAttr, return the integer. Optional getConstantIntValue(OpFoldResult ofr); +/// Return true if `ofr` is a constant integer equal to `value`. +bool isConstantIntValue(OpFoldResult ofr, int64_t value); + /// Return true if ofr1 and ofr2 are the same integer constant attribute values /// or the same SSA value. /// Ignore integer bitwitdh and type mismatch that come from the fact there is diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1508,6 +1508,36 @@ return success(); } +OpFoldResult ReinterpretCastOp::fold(ArrayRef /*operands*/) { + Value src = source(); + auto getPrevSrc = [&]() -> Value { + // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x). + if (auto prev = src.getDefiningOp()) + return prev.source(); + + // reinterpret_cast(cast(x)) -> reinterpret_cast(x). + if (auto prev = src.getDefiningOp()) + return prev.source(); + + // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if all subview + // offsets are the constant 0. 
+ if (auto prev = src.getDefiningOp()) + if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) { + return isConstantIntValue(val, 0); + })) + return prev.source(); + + return nullptr; + }; + + if (auto prevSrc = getPrevSrc()) { + sourceMutable().assign(prevSrc); + return getResult(); + } + + return nullptr; +} + //===----------------------------------------------------------------------===// // Reassociative reshape ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -81,6 +81,12 @@ return llvm::None; } +/// Return true if `ofr` is a constant integer equal to `value`. +bool isConstantIntValue(OpFoldResult ofr, int64_t value) { + auto val = getConstantIntValue(ofr); + return val && *val == value; +} + /// Return true if ofr1 and ofr2 are the same integer constant attribute values /// or the same SSA value. 
/// Ignore integer bitwidth and type mismatch that come from the fact there is diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -657,3 +657,39 @@ // CHECK: func @scopeInline // CHECK-NOT: memref.alloca_scope + +// ----- + +// CHECK-LABEL: func @reinterpret_of_reinterpret +// CHECK-SAME: (%[[ARG:.*]]: memref, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index) +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1] +// CHECK: return %[[RES]] +func @reinterpret_of_reinterpret(%arg : memref, %size1: index, %size2: index) -> memref { + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size1], strides: [1] : memref to memref + %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref to memref + return %1 : memref +} + +// ----- + +// CHECK-LABEL: func @reinterpret_of_cast +// CHECK-SAME: (%[[ARG:.*]]: memref, %[[SIZE:.*]]: index) +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE]]], strides: [1] +// CHECK: return %[[RES]] +func @reinterpret_of_cast(%arg : memref, %size: index) -> memref { + %0 = memref.cast %arg : memref to memref<5xi8> + %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size], strides: [1] : memref<5xi8> to memref + return %1 : memref +} + +// ----- + +// CHECK-LABEL: func @reinterpret_of_subview +// CHECK-SAME: (%[[ARG:.*]]: memref, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index) +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1] +// CHECK: return %[[RES]] +func @reinterpret_of_subview(%arg : memref, %size1: index, %size2: index) -> memref { + %0 = memref.subview %arg[0] [%size1] [1] : memref to memref + %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref to memref + return %1 : memref +}