diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -257,6 +257,70 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts `memref.reinterpret_cast` when the source and the converted result
+/// type both map to the same `!spirv.ptr` type. Only the offset is lowered: a
+/// constant zero offset forwards the source pointer unchanged, and any other
+/// offset is applied with `spirv.InBoundsPtrAccessChain`. The sizes and
+/// strides operands are not read by this pattern.
+class ReinterpretCastPattern final
+    : public OpConversionPattern<memref::ReinterpretCastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value src = adaptor.getSource();
+    auto srcType = dyn_cast<spirv::PointerType>(src.getType());
+    if (!srcType)
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "invalid src type " << src.getType();
+      });
+
+    auto converter = getTypeConverter();
+    assert(converter && "Invalid type converter");
+
+    // The result must convert to exactly the source pointer type.
+    auto dstType = dyn_cast_or_null<spirv::PointerType>(
+        converter->convertType(op.getType()));
+    if (dstType != srcType)
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "invalid dst type " << op.getType();
+      });
+
+    // Merge static and dynamic offsets; only the first entry is used.
+    OpFoldResult offset = getMixedValues(adaptor.getStaticOffsets(),
+                                         adaptor.getOffsets(), rewriter)
+                              .front();
+
+    // A zero offset needs no pointer arithmetic: reuse the source pointer.
+    if (isConstantIntValue(offset, 0)) {
+      rewriter.replaceOp(op, src);
+      return success();
+    }
+
+    Type intType = converter->convertType(rewriter.getIndexType());
+    if (!intType)
+      return rewriter.notifyMatchFailure(op, "invalid index type");
+
+    Location loc = op.getLoc();
+    // Dynamic offsets are used directly; static offsets are materialized as a
+    // spirv.Constant of the converted index type.
+    Value offsetValue = [&]() -> Value {
+      if (auto val = dyn_cast<Value>(offset))
+        return val;
+
+      int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
+      IntegerAttr attr = rewriter.getIntegerAttr(intType, attrVal);
+      return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
+    }();
+
+    rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
+        op, src, offsetValue, std::nullopt);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -731,9 +795,10
@@ namespace mlir { void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add( - typeConverter, patterns.getContext()); + patterns + .add( + typeConverter, patterns.getContext()); } } // namespace mlir diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -479,3 +479,49 @@ } } // end module + +// ----- + +// Check memref.reinterpret_cast lowering: dynamic offset, zero offset (source pointer forwarded as-is), and non-zero constant offset. + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: func.func @reinterpret_cast +// CHECK-SAME: (%[[MEM:.*]]: memref>, %[[OFF:.*]]: index) +func.func @reinterpret_cast(%arg: memref>, %arg1: index) -> memref, #spirv.storage_class> { +// CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref> to !spirv.ptr +// CHECK: %[[OFF1:.*]] = builtin.unrealized_conversion_cast %[[OFF]] : index to i32 +// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF1]]] : !spirv.ptr, i32 +// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> +// CHECK: return %[[RET1]] + %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> + return %ret : memref, #spirv.storage_class> +} + +// CHECK-LABEL: func.func @reinterpret_cast_0 +// CHECK-SAME: (%[[MEM:.*]]: memref>) +func.func @reinterpret_cast_0(%arg: memref>) -> memref, #spirv.storage_class> { +// CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref> to !spirv.ptr +// CHECK: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr to memref, #spirv.storage_class> +// CHECK: return %[[RET]] + %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref> to memref, 
#spirv.storage_class> + return %ret : memref, #spirv.storage_class> +} + +// CHECK-LABEL: func.func @reinterpret_cast_5 +// CHECK-SAME: (%[[MEM:.*]]: memref>) +func.func @reinterpret_cast_5(%arg: memref>) -> memref, #spirv.storage_class> { +// CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref> to !spirv.ptr +// CHECK: %[[OFF:.*]] = spirv.Constant 5 : i32 +// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr, i32 +// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> +// CHECK: return %[[RET1]] + %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> + return %ret : memref, #spirv.storage_class> +} + +} // end module