diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -290,6 +290,49 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// MulUIExtendedOp +//===----------------------------------------------------------------------===// + +def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative, + AllTypesMatch<["lhs", "rhs", "low", "high"]>]> { + let summary = [{ + extended unsigned integer multiplication operation + }]; + + let description = [{ + Performs (2*N)-bit multiplication on zero-extended operands. Returns two + N-bit results: the low and the high halves of the product. The low half has + the same value as the result of regular multiplication `arith.muli` with + the same operands. + + Example: + + ```mlir + // Scalar multiplication. + %low, %high = arith.mului_extended %a, %b : i32 + + // Vector element-wise multiplication. + %c:2 = arith.mului_extended %d, %e : vector<4xi32> + + // Tensor element-wise multiplication. 
+ %x:2 = arith.mului_extended %y, %z : tensor<4x?xi8> + ``` + }]; + + let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs); + let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)"; + + let hasFolder = 1; + let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + ::llvm::Optional<::llvm::SmallVector> getShapeForUnroll(); + }]; +} + //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -142,6 +142,15 @@ ConversionPatternRewriter &rewriter) const override; }; +struct MulUIExtendedOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + struct CmpIOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -261,6 +270,67 @@ "ND vector types are not supported yet"); } +//===----------------------------------------------------------------------===// +// MulUIExtendedOpLowering +//===----------------------------------------------------------------------===// + +LogicalResult MulUIExtendedOpLowering::matchAndRewrite( + arith::MulUIExtendedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Type resultType = adaptor.getLhs().getType(); + + if (!LLVM::isCompatibleType(resultType)) + return failure(); + + Location loc = op.getLoc(); + + // Handle the scalar and 1D vector cases. 
Because LLVM does not have a + // matching extended multiplication intrinsic, perform regular multiplication + // on operands zero-extended to i(2*N) bits, and truncate the results back to + // iN types. + if (!resultType.isa()) { + Type wideType; + // Shift amount necessary to extract the high bits from the widened result. + Attribute shiftValAttr; + + if (auto intTy = resultType.dyn_cast()) { + unsigned resultBitwidth = intTy.getWidth(); + wideType = rewriter.getIntegerType(resultBitwidth * 2); + shiftValAttr = rewriter.getIntegerAttr(wideType, resultBitwidth); + } else { + auto vecTy = resultType.cast(); + unsigned resultBitwidth = vecTy.getElementTypeBitWidth(); + wideType = VectorType::get(vecTy.getShape(), + rewriter.getIntegerType(resultBitwidth * 2)); + shiftValAttr = SplatElementsAttr::get( + wideType, APInt(resultBitwidth * 2, resultBitwidth)); + } + assert(LLVM::isCompatibleType(wideType) && + "LLVM dialect should support all signless integer types"); + + Value lhsExt = + rewriter.create(loc, wideType, adaptor.getLhs()); + Value rhsExt = + rewriter.create(loc, wideType, adaptor.getRhs()); + Value mulExt = rewriter.create(loc, wideType, lhsExt, rhsExt); + + // Split the 2*N-bit wide result into two N-bit values. 
+ Value low = rewriter.create(loc, resultType, mulExt); + Value shiftVal = rewriter.create(loc, shiftValAttr); + Value highExt = rewriter.create(loc, mulExt, shiftVal); + Value high = rewriter.create(loc, resultType, highExt); + + rewriter.replaceOp(op, {low, high}); + return success(); + } + + if (!resultType.isa()) + return rewriter.notifyMatchFailure(op, "expected vector result type"); + + return rewriter.notifyMatchFailure(op, + "ND vector types are not supported yet"); +} + //===----------------------------------------------------------------------===// // CmpIOpLowering //===----------------------------------------------------------------------===// @@ -397,6 +467,7 @@ MinUIOpLowering, MulFOpLowering, MulIOpLowering, + MulUIExtendedOpLowering, NegFOpLowering, OrIOpLowering, RemFOpLowering, diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -223,6 +223,16 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Converts arith.mului_extended to spirv.UMulExtended. +class MulUIExtendedOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts arith.select to spirv.Select. 
class SelectOpPattern final : public OpConversionPattern { public: @@ -944,6 +954,26 @@ return success(); } +//===----------------------------------------------------------------------===// +// MulUIExtendedOpPattern +//===----------------------------------------------------------------------===// + +LogicalResult MulUIExtendedOpPattern::matchAndRewrite( + arith::MulUIExtendedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value result = rewriter.create(loc, adaptor.getLhs(), + adaptor.getRhs()); + + Value low = rewriter.create(loc, result, + llvm::makeArrayRef(0)); + Value high = rewriter.create( + loc, result, llvm::makeArrayRef(1)); + + rewriter.replaceOp(op, {low, high}); + return success(); +} + //===----------------------------------------------------------------------===// // SelectOpPattern //===----------------------------------------------------------------------===// @@ -1040,7 +1070,7 @@ TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, CmpFOpNanNonePattern, CmpFOpPattern, - AddUIExtendedOpPattern, SelectOpPattern, + AddUIExtendedOpPattern, MulUIExtendedOpPattern, SelectOpPattern, MinMaxFOpPattern, MinMaxFOpPattern, diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -100,6 +100,17 @@ Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x), (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>; +//===----------------------------------------------------------------------===// +// MulUIExtendedOp +//===----------------------------------------------------------------------===// + +// mului_extended(x, y) -> [muli(x, y), x], when the `high` result is unused. +// Since the `high` result is not used, any replacement value will do. 
+def MulUIExtendedToMulI : + Pattern<(Arith_MulUIExtendedOp:$res $x, $y), + [(Arith_MulIOp $x, $y), (replaceWithValue $x)], + [(Constraint> $res__1)]>; + //===----------------------------------------------------------------------===// // XOrIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -237,7 +237,7 @@ LogicalResult arith::AddUIExtendedOp::fold(ArrayRef operands, SmallVectorImpl &results) { - auto overflowTy = getOverflow().getType(); + Type overflowTy = getOverflow().getType(); // addui_extended(x, 0) -> x, false if (matchPattern(getRhs(), m_Zero())) { auto overflowZero = APInt::getZero(1); @@ -345,6 +345,60 @@ operands, [](const APInt &a, const APInt &b) { return a * b; }); } +//===----------------------------------------------------------------------===// +// MulUIExtendedOp +//===----------------------------------------------------------------------===// + +Optional> arith::MulUIExtendedOp::getShapeForUnroll() { + if (auto vt = getType(0).dyn_cast()) + return llvm::to_vector<4>(vt.getShape()); + return std::nullopt; +} + +LogicalResult +arith::MulUIExtendedOp::fold(ArrayRef operands, + SmallVectorImpl &results) { + // mului_extended(x, 0) -> 0, 0 + if (matchPattern(getRhs(), m_Zero())) { + results.push_back(getRhs()); + results.push_back(getRhs()); + return success(); + } + + // mului_extended(x, 1) -> x, 0 + if (matchPattern(getRhs(), m_One())) { + Builder builder(getContext()); + Attribute zero = builder.getZeroAttr(getLhs().getType()); + results.push_back(getLhs()); + results.push_back(zero); + return success(); + } + + // mului_extended(cst_a, cst_b) -> cst_low, cst_high + if (Attribute lowAttr = constFoldBinaryOp( + operands, [](const APInt &a, const APInt &b) { return a * b; })) { + // Invoke the constant fold helper again to 
calculate the 'high' result. + Attribute highAttr = constFoldBinaryOp( + operands, [](const APInt &a, const APInt &b) { + unsigned bitWidth = a.getBitWidth(); + APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); + return fullProduct.extractBits(bitWidth, bitWidth); + }); + assert(highAttr && "Unexpected constant-folding failure"); + + results.push_back(lowAttr); + results.push_back(highAttr); + return success(); + } + + return failure(); +} + +void arith::MulUIExtendedOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(context); +} + //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -379,6 +379,38 @@ // ----- +// CHECK-LABEL: @mului_extended_scalar +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32) +func.func @mului_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i32) { + // CHECK-NEXT: [[LHS:%.+]] = llvm.zext [[ARG0]] : i32 to i64 + // CHECK-NEXT: [[RHS:%.+]] = llvm.zext [[ARG1]] : i32 to i64 + // CHECK-NEXT: [[MUL:%.+]] = llvm.mul [[LHS]], [[RHS]] : i64 + // CHECK-NEXT: [[LOW:%.+]] = llvm.trunc [[MUL]] : i64 to i32 + // CHECK-NEXT: [[C32:%.+]] = llvm.mlir.constant(32 : i64) : i64 + // CHECK-NEXT: [[SHR:%.+]] = llvm.lshr [[MUL]], [[C32]] : i64 + // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHR]] : i64 to i32 + %low, %high = arith.mului_extended %arg0, %arg1 : i32 + // CHECK-NEXT: return [[LOW]], [[HIGH]] : i32, i32 + return %low, %high : i32, i32 +} + +// CHECK-LABEL: @mului_extended_vector1d +// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi64>, [[ARG1:%.+]]: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>) +func.func 
@mului_extended_vector1d(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>) { + // CHECK-NEXT: [[LHS:%.+]] = llvm.zext [[ARG0]] : vector<3xi64> to vector<3xi128> + // CHECK-NEXT: [[RHS:%.+]] = llvm.zext [[ARG1]] : vector<3xi64> to vector<3xi128> + // CHECK-NEXT: [[MUL:%.+]] = llvm.mul [[LHS]], [[RHS]] : vector<3xi128> + // CHECK-NEXT: [[LOW:%.+]] = llvm.trunc [[MUL]] : vector<3xi128> to vector<3xi64> + // CHECK-NEXT: [[C64:%.+]] = llvm.mlir.constant(dense<64> : vector<3xi128>) : vector<3xi128> + // CHECK-NEXT: [[SHR:%.+]] = llvm.lshr [[MUL]], [[C64]] : vector<3xi128> + // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHR]] : vector<3xi128> to vector<3xi64> + %low, %high = arith.mului_extended %arg0, %arg1 : vector<3xi64> + // CHECK-NEXT: return [[LOW]], [[HIGH]] : vector<3xi64>, vector<3xi64> + return %low, %high : vector<3xi64>, vector<3xi64> +} + +// ----- + // CHECK-LABEL: func @cmpf_2dvector( func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) { // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -99,6 +99,29 @@ return %sum, %overflow : vector<4xi32>, vector<4xi1> } +// Check extended unsigned integer multiplication conversions. 
+// CHECK-LABEL: @int32_scalar_mului_extended +// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) +func.func @int32_scalar_mului_extended(%lhs: i32, %rhs: i32) -> (i32, i32) { + // CHECK-NEXT: %[[MUL:.+]] = spirv.UMulExtended %[[LHS]], %[[RHS]] : !spirv.struct<(i32, i32)> + // CHECK-DAG: %[[LOW:.+]] = spirv.CompositeExtract %[[MUL]][0 : i32] : !spirv.struct<(i32, i32)> + // CHECK-DAG: %[[HIGH:.+]] = spirv.CompositeExtract %[[MUL]][1 : i32] : !spirv.struct<(i32, i32)> + // CHECK-NEXT: return %[[LOW]], %[[HIGH]] : i32, i32 + %low, %high = arith.mului_extended %lhs, %rhs: i32 // `low` is struct member 0, `high` is struct member 1. + return %low, %high : i32, i32 +} + +// CHECK-LABEL: @int32_vector_mului_extended +// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>) +func.func @int32_vector_mului_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { + // CHECK-NEXT: %[[MUL:.+]] = spirv.UMulExtended %[[LHS]], %[[RHS]] : !spirv.struct<(vector<4xi32>, vector<4xi32>)> + // CHECK-DAG: %[[LOW:.+]] = spirv.CompositeExtract %[[MUL]][0 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)> + // CHECK-DAG: %[[HIGH:.+]] = spirv.CompositeExtract %[[MUL]][1 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)> + // CHECK-NEXT: return %[[LOW]], %[[HIGH]] : vector<4xi32>, vector<4xi32> + %low, %high = arith.mului_extended %lhs, %rhs: vector<4xi32> + return %low, %high : vector<4xi32>, vector<4xi32> +} + // Check float unary operation conversions. 
// CHECK-LABEL: @float32_unary_scalar func.func @float32_unary_scalar(%arg0: f32) { diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -716,6 +716,104 @@ return %sum, %carry : vector<4xi32>, vector<4xi1> } +// CHECK-LABEL: @muluiExtendedZeroRhs +// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32 +// CHECK-NEXT: return %[[zero]], %[[zero]] +func.func @muluiExtendedZeroRhs(%arg0: i32) -> (i32, i32) { + %zero = arith.constant 0 : i32 + %low, %high = arith.mului_extended %arg0, %zero: i32 + return %low, %high : i32, i32 +} + +// CHECK-LABEL: @muluiExtendedZeroRhsSplat +// CHECK-NEXT: %[[zero:.+]] = arith.constant dense<0> : vector<3xi32> +// CHECK-NEXT: return %[[zero]], %[[zero]] +func.func @muluiExtendedZeroRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) { + %zero = arith.constant dense<0> : vector<3xi32> + %low, %high = arith.mului_extended %arg0, %zero: vector<3xi32> + return %low, %high : vector<3xi32>, vector<3xi32> +} + +// CHECK-LABEL: @muluiExtendedZeroLhs +// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32 +// CHECK-NEXT: return %[[zero]], %[[zero]] +func.func @muluiExtendedZeroLhs(%arg0: i32) -> (i32, i32) { + %zero = arith.constant 0 : i32 + %low, %high = arith.mului_extended %zero, %arg0: i32 + return %low, %high : i32, i32 +} + +// CHECK-LABEL: @muluiExtendedOneRhs +// CHECK-SAME: (%[[ARG:.+]]: i32) -> (i32, i32) +// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32 +// CHECK-NEXT: return %[[ARG]], %[[zero]] +func.func @muluiExtendedOneRhs(%arg0: i32) -> (i32, i32) { + %one = arith.constant 1 : i32 + %low, %high = arith.mului_extended %arg0, %one: i32 + return %low, %high : i32, i32 +} + +// CHECK-LABEL: @muluiExtendedOneRhsSplat +// CHECK-SAME: (%[[ARG:.+]]: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) +// CHECK-NEXT: %[[zero:.+]] = arith.constant dense<0> : vector<3xi32> +// 
CHECK-NEXT: return %[[ARG]], %[[zero]] +func.func @muluiExtendedOneRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) { + %one = arith.constant dense<1> : vector<3xi32> + %low, %high = arith.mului_extended %arg0, %one: vector<3xi32> + return %low, %high : vector<3xi32>, vector<3xi32> +} + +// CHECK-LABEL: @muluiExtendedOneLhs +// CHECK-SAME: (%[[ARG:.+]]: i32) -> (i32, i32) +// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32 +// CHECK-NEXT: return %[[ARG]], %[[zero]] +func.func @muluiExtendedOneLhs(%arg0: i32) -> (i32, i32) { + %one = arith.constant 1 : i32 + %low, %high = arith.mului_extended %one, %arg0: i32 + return %low, %high : i32, i32 +} + +// CHECK-LABEL: @muluiExtendedUnusedHigh +// CHECK-SAME: (%[[ARG:.+]]: i32) -> i32 +// CHECK-NEXT: %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32 +// CHECK-NEXT: return %[[RES]] +func.func @muluiExtendedUnusedHigh(%arg0: i32) -> i32 { + %low, %high = arith.mului_extended %arg0, %arg0: i32 + return %low : i32 +} + +// This shouldn't be folded because the `high` result is used. 
+// CHECK-LABEL: @muluiExtendedUnusedLow +// CHECK-SAME: (%[[ARG:.+]]: i32) -> i32 +// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[ARG]], %[[ARG]] : i32 +// CHECK-NEXT: return %[[HIGH]] +func.func @muluiExtendedUnusedLow(%arg0: i32) -> i32 { + %low, %high = arith.mului_extended %arg0, %arg0: i32 + return %high : i32 +} + +// CHECK-LABEL: @muluiExtendedScalarConstants +// CHECK-DAG: %[[c157:.+]] = arith.constant -99 : i8 +// CHECK-DAG: %[[c29:.+]] = arith.constant 29 : i8 +// CHECK-NEXT: return %[[c157]], %[[c29]] +func.func @muluiExtendedScalarConstants() -> (i8, i8) { + %c57 = arith.constant 57 : i8 + %c133 = arith.constant 133 : i8 + %low, %high = arith.mului_extended %c57, %c133: i8 // = 7581 = 0x1D9D; low = 0x9D (157, printed as -99 : i8), high = 0x1D (29) + return %low, %high : i8, i8 +} + +// CHECK-LABEL: @muluiExtendedVectorConstants +// CHECK-DAG: %[[cstLo:.+]] = arith.constant dense<[65, 79, 1]> : vector<3xi8> +// CHECK-DAG: %[[cstHi:.+]] = arith.constant dense<[0, 14, -2]> : vector<3xi8> +// CHECK-NEXT: return %[[cstLo]], %[[cstHi]] +func.func @muluiExtendedVectorConstants() -> (vector<3xi8>, vector<3xi8>) { + %cstA = arith.constant dense<[5, 37, 255]> : vector<3xi8> + %cstB = arith.constant dense<[13, 99, 255]> : vector<3xi8> + %low, %high = arith.mului_extended %cstA, %cstB: vector<3xi8> + return %low, %high : vector<3xi8>, vector<3xi8> +} + // CHECK-LABEL: @notCmpEQ // CHECK: %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8 // CHECK: return %[[cres]] diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -97,6 +97,30 @@ return %0 : vector<[8]xi64> } +// CHECK-LABEL: test_mului_extended +func.func @test_mului_extended(%arg0 : i32, %arg1 : i32) -> i32 { + %low, %high = arith.mului_extended %arg0, %arg1 : i32 + return %high : i32 +} + +// CHECK-LABEL: test_mului_extended_tensor +func.func @test_mului_extended_tensor(%arg0 : 
tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { + %low, 
%high = arith.mului_extended %arg0, %arg1 : tensor<8x8xi64> + return %high : tensor<8x8xi64> +} + +// CHECK-LABEL: test_mului_extended_vector +func.func @test_mului_extended_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> { + %0:2 = arith.mului_extended %arg0, %arg1 : vector<8xi64> + return %0#0 : vector<8xi64> +} + +// CHECK-LABEL: test_mului_extended_scalable_vector +func.func @test_mului_extended_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0:2 = arith.mului_extended %arg0, %arg1 : vector<[8]xi64> + return %0#1 : vector<[8]xi64> +} + // CHECK-LABEL: test_divui func.func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.divui %arg0, %arg1 : i64 diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1044,7 +1044,8 @@ } if (offsets.front() > 0) { - const char error[] = "no enough values generated to replace the matched op"; + const char error[] = + "not enough values generated to replace the matched op"; PrintFatalError(loc, error); }