Review notes (free text placed before the first "diff --git" header; git apply / patch skip leading text like this).

Purpose: this patch lowers memref.cast in MemRefToSPIRV with a new CastPattern.
The pattern converts the cast's result type with the type converter. If that equals
the type of the already-converted source operand (adaptor.getSource()), it
replaces the cast with that source value. Otherwise it reports a match failure.
The patch also reflows the patterns.add(...) call in
populateMemRefToSPIRVPatterns. Presumably this appends CastPattern to the
pattern list; the template argument list is not visible in this copy, so confirm.
It adds a FileCheck test, @cast, which expects the function to return the result
of the back-conversion unrealized_conversion_cast directly. In other words, all
uses of the memref.cast result have been replaced.

NOTE(review): this copy of the patch appears damaged and will not apply as-is:
  - Newlines have been collapsed, so each file's hunks sit on one or two lines,
    and the hunk headers (@@ -267,6 +267,28 @@ etc.) no longer describe the text.
  - Text inside angle brackets has been stripped. Affected spans:
    - "OpConversionPattern {" is missing its template argument, presumably
      <memref::CastOp> given the matchAndRewrite(memref::CastOp ...) signature.
    - The template argument lists of both "patterns .add(" and "patterns.add(" are gone.
    - The parameters of "#spirv.vce" are gone.
    - The "#spirv.storage_class" and "!spirv.ptr" types have lost their parameters.
  Regenerate the patch from the source commit rather than hand-repairing it.
NOTE(review): per the MLIR TypeConverter documentation, convertType returns a
  null Type when conversion fails. CastPattern does not check for that. A failed
  conversion therefore lands in the "types doesn't match" branch and the
  diagnostic prints a null type. Consider a dedicated "failed to convert result
  type" match failure before the comparison.
NOTE(review): the diagnostic text "types doesn't match" is ungrammatical and
  should be "types don't match". Any test that matches the message text would
  need the same update.
NOTE(review): the "TypeConverter *converter = getTypeConverter();" line may need
  "const TypeConverter *" on newer MLIR, where getTypeConverter() returns a
  pointer to const. Confirm against the targeted MLIR revision.

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -267,6 +267,28 @@ ConversionPatternRewriter &rewriter) const override; }; +class CastPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = adaptor.getSource(); + Type srcType = src.getType(); + + TypeConverter *converter = getTypeConverter(); + Type dstType = converter->convertType(op.getType()); + if (srcType != dstType) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "types doesn't match: " << srcType << " and " << dstType; + }); + + rewriter.replaceOp(op, src); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -779,10 +801,10 @@ namespace mlir { void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns - .add( - typeConverter, patterns.getContext()); + patterns.add(typeConverter, + patterns.getContext()); } } // namespace mlir diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -566,3 +566,26 @@ } } // end module + + +// ----- + +// Check casts + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: func.func @cast +// CHECK-SAME: (%[[MEM:.*]]: memref<4x?xf32, #spirv.storage_class>) +func.func @cast(%arg: memref<4x?xf32, #spirv.storage_class>) -> memref> { +// CHECK: %[[MEM1:.*]] = 
builtin.unrealized_conversion_cast %[[MEM]] : memref<4x?xf32, #spirv.storage_class> to !spirv.ptr +// CHECK: %[[MEM2:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr to memref> +// CHECK: return %[[MEM2]] + %ret = memref.cast %arg : memref<4x?xf32, #spirv.storage_class> to memref> + return %ret : memref> +} + +}