diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -354,6 +354,14 @@ SPV_AnyStruct:$result ); + let builders = [ + OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{ + build($_builder, $_state, + ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}), + operand1, operand2); + }]> + ]; + let hasVerifier = 1; } @@ -485,6 +493,14 @@ SPV_AnyStruct:$result ); + let builders = [ + OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{ + build($_builder, $_state, + ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}), + operand1, operand2); + }]> + ]; + let hasVerifier = 1; } diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -854,11 +854,9 @@ AddICarryOpPattern::matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type dstElemTy = adaptor.getLhs().getType(); - auto resultTy = spirv::StructType::get({dstElemTy, dstElemTy}); - Location loc = op->getLoc(); Value result = rewriter.create( - loc, resultTy, adaptor.getLhs(), adaptor.getRhs()); + loc, adaptor.getLhs(), adaptor.getRhs()); Value sumResult = rewriter.create( loc, result, llvm::makeArrayRef(0));