diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2126,4 +2126,41 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// EltBitcastOp
+//===----------------------------------------------------------------------===//
+
+def EltBitcastOp : MemRef_Op<"elt_bitcast", [
+    Pure, ViewLikeOpInterface, SameOperandsAndResultShape, MemRefsNormalizable
+  ]> {
+  let summary = "memref elementwise bitcast op";
+  let description = [{
+    `elt_bitcast` op reinterprets an existing memref as a memref with a
+    different element type. Source and destination memrefs must have the same
+    shape, layout, memory space and element size (element sizes are only
+    verified when both element types are integers or floats). The source
+    memref should be allocated with the proper alignment and memory type for
+    accessing destination memref elements. Memref data is not modified or
+    copied by this op.
+
+    Example:
+
+    ```mlir
+    %dst = memref.elt_bitcast %src : memref<?xi32> to memref<?xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyMemRef:$source);
+  let results = (outs AnyMemRef:$result);
+
+  let extraClassDeclaration = [{
+    ::mlir::Value getViewSource() { return getSource(); }
+  }];
+
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+}
+
 #endif // MEMREF_OPS
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3159,6 +3159,54 @@
   return OpFoldResult();
 }
 
+//===----------------------------------------------------------------------===//
+// EltBitcastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult EltBitcastOp::verify() {
+  auto srcType = getSource().getType().cast<MemRefType>();
+  auto dstType = getResult().getType().cast<MemRefType>();
+  if (srcType.getShape() != dstType.getShape())
+    return emitOpError("memref bitcast shape mismatch");
+  if (srcType.getLayout() != dstType.getLayout())
+    return emitOpError("memref bitcast layout mismatch");
+  if (srcType.getMemorySpace() != dstType.getMemorySpace())
+    return emitOpError("memref bitcast memory space mismatch");
+
+  // Bit widths can only be compared here when both element types are scalar
+  // integers or floats; any other element type (e.g. index or a vector) is
+  // accepted without a size check.
+  Type srcElem = srcType.getElementType();
+  Type dstElem = dstType.getElementType();
+  if (srcElem.isIntOrFloat() && dstElem.isIntOrFloat() &&
+      srcElem.getIntOrFloatBitWidth() != dstElem.getIntOrFloatBitWidth())
+    return emitOpError("memref bitcast element size mismatch");
+  return success();
+}
+
+OpFoldResult EltBitcastOp::fold(llvm::ArrayRef<Attribute> /*operands*/) {
+  Value src = getSource();
+  Type dstType = getResult().getType();
+  // Casting to the source's own type is a no-op.
+  if (src.getType() == dstType)
+    return src;
+
+  // Walk up a chain of elt_bitcast ops looking for a value that already has
+  // the result type, e.g. A -> B -> A folds to the original A value. Other
+  // chains are not folded: collapsing them into one cast is not always valid
+  // because element sizes are only verified for int/float elements (i32 ->
+  // vector<4xi8> -> i64 verifies, while a direct i32 -> i64 cast does not).
+  while (auto parent = src.getDefiningOp<EltBitcastOp>()) {
+    Value parentSource = parent.getSource();
+    if (parentSource.getType() == dstType)
+      return parentSource;
+
+    src = parentSource;
+  }
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -846,3 +846,25 @@
   memref.store %v, %0[%i2] : memref<4xf32>
   return %src : memref<2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_bitcast_same_type
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi32>)
+// CHECK: return %[[ARG]]
+func.func @memref_bitcast_same_type(%src : memref<?xi32>) -> memref<?xi32> {
+  %0 = memref.elt_bitcast %src : memref<?xi32> to memref<?xi32>
+  return %0 : memref<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_bitcast_chain
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi32>)
+// CHECK: return %[[ARG]]
+func.func @memref_bitcast_chain(%src : memref<?xi32>) -> memref<?xi32> {
+  %0 = memref.elt_bitcast %src : memref<?xi32> to memref<?xf32>
+  %1 = memref.elt_bitcast %0 : memref<?xf32> to memref<?xvector<4xi8>>
+  %2 = memref.elt_bitcast %1 : memref<?xvector<4xi8>> to memref<?xi32>
+  return %2 : memref<?xi32>
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -380,3 +380,14 @@
   %0 = memref.extract_aligned_pointer_as_index %src : memref<f32> -> index
   return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_bitcast
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi32>)
+// CHECK: %[[RES:.*]] = memref.elt_bitcast %[[ARG]] : memref<?xi32> to memref<?xvector<4xi8>>
+// CHECK: return %[[RES]]
+func.func @memref_bitcast(%src : memref<?xi32>) -> memref<?xvector<4xi8>> {
+  %0 = memref.elt_bitcast %src : memref<?xi32> to memref<?xvector<4xi8>>
+  return %0 : memref<?xvector<4xi8>>
+}