// Patch: lower memref.reinterpret_cast to SPIR-V.
//
// MemRefToSPIRV.cpp:
//  - declares ReinterpretCastPattern, an OpConversionPattern over
//    memref::ReinterpretCastOp;
//  - implements ReinterpretCastPattern::matchAndRewrite;
//  - appends ReinterpretCastPattern to the pattern list registered by
//    populateMemRefToSPIRVPatterns.
//
// What matchAndRewrite does, as written below:
//  - Match-failure "invalid src type ..." when the converted source value's
//    type fails the dyn_cast. The cast target is not visible here; presumably
//    a SPIR-V pointer type, since the tests show !spirv.ptr.
//  - Match-failure "invalid dst type ..." when the type converter's result for
//    op.getType() is not equal to that source type.
//  - Reads only the first offset, static or dynamic, via
//    getMixedValues(getStaticOffsets, getOffsets).front(). The sizes and
//    strides operands/attributes are never read.
//  - Constant-0 offset: the op is replaced by the source value directly
//    (test @reinterpret_cast_0).
//  - Otherwise it converts the index type, with match-failure
//    "failed to convert index type" if that yields null.
//    - An attribute offset is materialized as a constant of that type
//      (test @reinterpret_cast_5 expects `spirv.Constant 5 : i32`).
//    - A Value offset is used as-is (test @reinterpret_cast expects the
//      index-to-i32 cast value).
//    - The op is then replaced with a new op built from (src, offsetValue,
//      std::nullopt). The tests expect this to be spirv.InBoundsPtrAccessChain.
//
// NOTE(review): every `<...>` span appears to have been stripped from this
// text. Examples: `OpConversionPattern` with no template argument,
// `dyn_cast(...)`, `cast(...)`, `offset.get()`, `rewriter.create(...)` and
// `replaceOpWithNewOp(...)` with no target type, the empty `patterns.add(`,
// and `memref>` / `!spirv.ptr` / `#spirv.vce` with no parameters in the test.
// As captured here the C++ does not compile and the CHECK lines cannot match.
// Restore from the original patch (TODO confirm).
// NOTE(review): the patch is collapsed onto a few physical lines. The
// multi-line stripped spans (e.g. the old/new pattern lists) are also gone, so
// the hunk headers (@@ -723,9 +779,10 @@, etc.) no longer describe the visible
// text, and git apply will reject it. Re-extract with newlines intact.
// NOTE(review): sizes/strides of the cast are ignored, and only the offset is
// applied to the pointer. Confirm this is safe for non-unit strides, and that
// a static layout offset in the result type is not applied a second time by
// downstream load/store lowering. Neither case is covered by the tests below.
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -257,6 +257,16 @@ ConversionPatternRewriter &rewriter) const override; }; +class ReinterpretCastPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -716,6 +726,52 @@ return success(); } +LogicalResult ReinterpretCastPattern::matchAndRewrite( + memref::ReinterpretCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value src = adaptor.getSource(); + auto srcType = dyn_cast(src.getType()); + + if (!srcType) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "invalid src type " << src.getType(); + }); + + TypeConverter *converter = getTypeConverter(); + + auto dstType = converter->convertType(op.getType()); + if (dstType != srcType) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "invalid dst type " << op.getType(); + }); + + OpFoldResult offset = + getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter) + .front(); + if (isConstantIntValue(offset, 0)) { + rewriter.replaceOp(op, src); + return success(); + } + + Type intType = converter->convertType(rewriter.getIndexType()); + if (!intType) + return rewriter.notifyMatchFailure(op, "failed to convert index type"); + + Location loc = op.getLoc(); + auto offsetValue = [&]() -> Value { + if (auto val = dyn_cast(offset)) + return val; + + int64_t attrVal = cast(offset.get()).getInt(); + Attribute attr = rewriter.getIntegerAttr(intType, attrVal); + return rewriter.create(loc, intType, attr); + }(); + 
rewriter.replaceOpWithNewOp( + op, src, offsetValue, std::nullopt); + return success(); +} + //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -723,9 +779,10 @@ namespace mlir { void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add( - typeConverter, patterns.getContext()); + patterns + .add( + typeConverter, patterns.getContext()); } } // namespace mlir diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -520,3 +520,49 @@ } } // end module + +// ----- + +// Check reinterpret_casts + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: func.func @reinterpret_cast +// CHECK-SAME: (%[[MEM:.*]]: memref>, %[[OFF:.*]]: index) +func.func @reinterpret_cast(%arg: memref>, %arg1: index) -> memref, #spirv.storage_class> { +// CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref> to !spirv.ptr +// CHECK: %[[OFF1:.*]] = builtin.unrealized_conversion_cast %[[OFF]] : index to i32 +// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF1]]] : !spirv.ptr, i32 +// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> +// CHECK: return %[[RET1]] + %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> + return %ret : memref, #spirv.storage_class> +} + +// CHECK-LABEL: func.func @reinterpret_cast_0 +// CHECK-SAME: (%[[MEM:.*]]: memref>) +func.func @reinterpret_cast_0(%arg: memref>) -> memref, #spirv.storage_class> { +// CHECK: %[[MEM1:.*]] = 
builtin.unrealized_conversion_cast %[[MEM]] : memref> to !spirv.ptr +// CHECK: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr to memref, #spirv.storage_class> +// CHECK: return %[[RET]] + %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> + return %ret : memref, #spirv.storage_class> +} + +// CHECK-LABEL: func.func @reinterpret_cast_5 +// CHECK-SAME: (%[[MEM:.*]]: memref>) +func.func @reinterpret_cast_5(%arg: memref>) -> memref, #spirv.storage_class> { +// CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref> to !spirv.ptr +// CHECK: %[[OFF:.*]] = spirv.Constant 5 : i32 +// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr, i32 +// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> +// CHECK: return %[[RET1]] + %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> + return %ret : memref, #spirv.storage_class> +} + +} // end module