diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -273,7 +273,7 @@ arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto srcType = constOp.getType().dyn_cast(); - if (!srcType) + if (!srcType || srcType.getNumElements() == 1) return failure(); // arith.constant should only have vector or tenor types. @@ -358,16 +358,25 @@ arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type srcType = constOp.getType(); + if (auto shapedType = srcType.dyn_cast()) { + if (shapedType.getNumElements() != 1) + return failure(); + srcType = shapedType.getElementType(); + } if (!srcType.isIntOrIndexOrFloat()) return failure(); + Attribute cstAttr = constOp.getValue(); + if (cstAttr.getType().isa()) + cstAttr = cstAttr.cast().getSplatValue(); + Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) return failure(); // Floating-point types. if (srcType.isa()) { - auto srcAttr = constOp.getValue().cast(); + auto srcAttr = cstAttr.cast(); auto dstAttr = srcAttr; // Floating-point types not supported in the target environment are all @@ -386,7 +395,7 @@ if (srcType.isInteger(1)) { // arith.constant can use 0/1 instead of true/false for i1 values. We need // to handle that here. - auto dstAttr = convertBoolAttr(constOp.getValue(), rewriter); + auto dstAttr = convertBoolAttr(cstAttr, rewriter); if (!dstAttr) return failure(); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); @@ -395,7 +404,7 @@ // IndexType or IntegerType. Index values are converted to 32-bit integer // values when converting to SPIR-V. - auto srcAttr = constOp.getValue().cast(); + auto srcAttr = cstAttr.cast(); auto dstAttr = convertIntegerAttr(srcAttr, dstType.cast(), rewriter); if (!dstAttr) diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -446,6 +446,17 @@ return } +// CHECK-LABEL: @constant_size1 +func @constant_size1() { + // CHECK: spv.Constant true + %0 = arith.constant dense : tensor<1xi1> + // CHECK: spv.Constant 4 : i64 + %1 = arith.constant dense<4> : vector<1xi64> + // CHECK: spv.Constant 5.000000e+00 : f64 + %2 = arith.constant dense<5.0> : tensor<1xf64> + return +} + } // end module // ----- @@ -485,6 +496,15 @@ return } +// CHECK-LABEL: @constant_size1 +func @constant_size1() { + // CHECK: spv.Constant 4 : i32 + %0 = arith.constant dense<4> : vector<1xi64> + // CHECK: spv.Constant 5.000000e+00 : f32 + %1 = arith.constant dense<5.0> : tensor<1xf64> + return +} + // CHECK-LABEL: @corner_cases func @corner_cases() { // CHECK: %{{.*}} = spv.Constant -1 : i32