diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -321,8 +321,11 @@ type with a compatible shape. The source and destination types are compatible if: - a. Both are ranked memref types with the same element type, address space, - and rank and: + a. Both memrefs have compatible element types: element type is the same or + differs only in the signed/signless/unsigned bit. + + b. Both are ranked memref types with compatible element types and the + same address space and rank and: 1. Both have the same layout or both have compatible strided layouts. 2. The individual sizes (resp. offset and strides in the case of strided memrefs) may convert constant dimensions to dynamic dimensions and @@ -351,8 +354,8 @@ memref<12x4xf32, strided<[?, ?], offset: ?>> ``` - b. Either or both memref types are unranked with the same element type, and - address space. + c. Either or both memref types are unranked with compatible element types, + and the same address space. Example: @@ -710,18 +713,18 @@ This operation is also useful for completeness to the existing memref.dim op. While accessing strides, offsets and the base pointer independently is not - available, this is useful for composing with its natural complement op: + available, this is useful for composing with its natural complement op: `memref.reinterpret_cast`. Intended Use Cases: The main use case is to expose the logic for manipulate memref metadata at a - higher level than the LLVM dialect. + higher level than the LLVM dialect. This makes lowering more progressive and brings the following benefits: - not all users of MLIR want to lower to LLVM and the information to e.g. lower to library calls---like libxsmm---or to SPIR-V was not available. 
- - foldings and canonicalizations can happen at a higher level in MLIR: - before this op existed, lowering to LLVM would create large amounts of + - foldings and canonicalizations can happen at a higher level in MLIR: + before this op existed, lowering to LLVM would create large amounts of LLVMIR. Even when LLVM does a good job at folding the low-level IR from a performance perspective, it is unnecessarily opaque and inefficient to send unkempt IR to LLVM. @@ -729,11 +732,11 @@ Example: ```mlir - %base, %offset, %sizes:2, %strides:2 = - memref.extract_strided_metadata %memref : + %base, %offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %memref : memref<10x?xf32>, index, index, index, index, index - // After folding, the type of %m2 can be memref<10x?xf32> and further + // After folding, the type of %m2 can be memref<10x?xf32> and further // folded to %memref. %m2 = memref.reinterpret_cast %base to offset: [%offset], diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1244,17 +1244,27 @@ } /// This is a common class used for patterns of the form -/// "someop(memrefcast) -> someop". It folds the source of any memref.cast -/// into the root operation directly. +/// "someop(memrefcast) -> someop". It folds the source of a memref.cast +/// into the root operation directly if source and user have the same memref +/// element type. static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; + + // Returns true iff memref types have the same element type. 
+ auto sameEltType = [](Value a, Value b) -> bool { + auto aEltType = a.getType().cast().getElementType(); + auto bEltType = b.getType().cast().getElementType(); + return aEltType == bEltType; + }; + for (OpOperand &operand : op->getOpOperands()) { auto cast = operand.get().getDefiningOp(); - if (cast) { - operand.set(cast.getOperand()); - folded = true; - } + if (!cast || !sameEltType(operand.get(), cast.getOperand())) + continue; + operand.set(cast.getOperand()); + folded = true; } + return success(folded); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -572,8 +572,15 @@ auto uaT = a.dyn_cast(); auto ubT = b.dyn_cast(); + // Strips signed/unsigned bit from integer types. + auto stripSign = [](Type type) -> Type { + if (auto integer = type.dyn_cast()) + return IntegerType::get(type.getContext(), integer.getWidth()); + return type; + }; + if (aT && bT) { - if (aT.getElementType() != bT.getElementType()) + if (stripSign(aT.getElementType()) != stripSign(bT.getElementType())) return false; if (aT.getLayout() != bT.getLayout()) { int64_t aOffset, bOffset; @@ -621,7 +628,7 @@ auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); - if (aEltType != bEltType) + if (stripSign(aEltType) != stripSign(bEltType)) return false; auto aMemSpace = (aT) ? 
aT.getMemorySpace() : uaT.getMemorySpace(); diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -89,6 +89,14 @@ return } +// CHECK-LABEL: func @memref_cast +func.func @memref_cast() { + %0 = memref.alloc() : memref<2xi32> + %1 = memref.cast %0 : memref<2xi32> to memref<2xui32> + %2 = memref.cast %1 : memref<2xui32> to memref<2xsi32> + %3 = memref.cast %2 : memref<2xsi32> to memref<*xi32> + return +} // CHECK-LABEL: func @memref_alloca_scope func.func @memref_alloca_scope() { @@ -288,7 +296,7 @@ memref> %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] : - memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>> into + memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>> into memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>> %rr1 = memref.collapse_shape %r1 [[0, 1], [2], [3, 4]] : memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>> into @@ -333,7 +341,7 @@ // ----- -func.func @extract_strided_metadata(%memref : memref<10x?xf32>) +func.func @extract_strided_metadata(%memref : memref<10x?xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %memref