diff --git a/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp b/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp
--- a/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp
+++ b/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp
@@ -28,6 +28,36 @@
 
 namespace {
 
+/// Converts complex.constant into a spirv.Constant of the converted type,
+/// e.g. `complex<f32>` becomes a `vector<2xf32>` dense constant.
+struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto spirvType = dyn_cast_or_null<ShapedType>(
+        getTypeConverter()->convertType(constOp.getType()));
+    if (!spirvType)
+      return rewriter.notifyMatchFailure(constOp,
+                                         "unable to convert result type");
+
+    // The constant's real/imaginary attributes keep the original element
+    // type. The type converter may change it (e.g. emulating f16 as f32 when
+    // the target lacks Float16), and DenseElementsAttr::get asserts on such a
+    // mismatch, so bail out instead of crashing.
+    Type elementType = cast<ComplexType>(constOp.getType()).getElementType();
+    if (spirvType.getElementType() != elementType)
+      return rewriter.notifyMatchFailure(
+          constOp, "converted element type differs from original");
+
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        constOp, spirvType,
+        DenseElementsAttr::get(spirvType, constOp.getValue().getValue()));
+    return success();
+  }
+};
+
 struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -87,6 +117,6 @@
                                           RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
 
-  patterns.add<CreateOpPattern, ReOpPattern, ImOpPattern>(typeConverter,
-                                                          context);
+  patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(
+      typeConverter, context);
 }
diff --git a/mlir/test/Conversion/ComplexToSPIRV/complex-to-spirv.mlir b/mlir/test/Conversion/ComplexToSPIRV/complex-to-spirv.mlir
--- a/mlir/test/Conversion/ComplexToSPIRV/complex-to-spirv.mlir
+++ b/mlir/test/Conversion/ComplexToSPIRV/complex-to-spirv.mlir
@@ -38,3 +38,12 @@
 // CHECK: %[[IM:.+]] = spirv.CompositeExtract %[[CAST]][1 : i32] : vector<2xf32>
 // CHECK: return %[[IM]] : f32
 
+// -----
+
+func.func @complex_const() -> complex<f32> {
+  %cst = complex.constant [0x7FC00000 : f32, 0.000000e+00 : f32] : complex<f32>
+  return %cst : complex<f32>
+}
+
+// CHECK-LABEL: func.func @complex_const()
+// CHECK: spirv.Constant dense<[0x7FC00000, 0.000000e+00]> : vector<2xf32>