diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -268,7 +268,7 @@
   let hasCanonicalizer = 1;
 
   let extraClassDeclaration = [{
-    ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+    Optional<SmallVector<int64_t, 4>> getShapeForUnroll();
   }];
 }
 
@@ -291,6 +291,49 @@
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MulSIExtendedOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
+    AllTypesMatch<["lhs", "rhs", "low", "high"]>]> {
+  let summary = [{
+    extended signed integer multiplication operation
+  }];
+
+  let description = [{
+    Performs (2*N)-bit multiplication on sign-extended operands. Returns two
+    N-bit results: the low and the high halves of the product. The low half has
+    the same value as the result of regular multiplication `arith.muli` with
+    the same operands.
+
+    Example:
+
+    ```mlir
+    // Scalar multiplication.
+    %low, %high = arith.mulsi_extended %a, %b : i32
+
+    // Vector element-wise multiplication.
+    %c:2 = arith.mulsi_extended %d, %e : vector<4xi32>
+
+    // Tensor element-wise multiplication.
+    %x:2 = arith.mulsi_extended %y, %z : tensor<4x?xi8>
+    ```
+  }];
+
+  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
+  let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
+
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+
+  let extraClassDeclaration = [{
+    Optional<SmallVector<int64_t, 4>> getShapeForUnroll();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MulUIExtendedOp
 //===----------------------------------------------------------------------===//
@@ -330,7 +373,7 @@
   let hasCanonicalizer = 1;
 
   let extraClassDeclaration = [{
-    ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+    Optional<SmallVector<int64_t, 4>> getShapeForUnroll();
   }];
 }
 
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
+#include <type_traits>
 
 namespace mlir {
 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
@@ -142,15 +143,21 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-struct MulUIExtendedOpLowering
-    : public ConvertOpToLLVMPattern<arith::MulUIExtendedOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+/// Converts arith.mul*i_extended to llvm.mul on sign/zero-extended operands.
+template <typename ArithMulOp, bool IsSigned>
+struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
+  using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor,
+  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+using MulSIExtendedOpLowering =
+    MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
+using MulUIExtendedOpLowering =
+    MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
+
 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
   using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;
 
@@ -271,11 +278,12 @@
 }
 
 //===----------------------------------------------------------------------===//
-// MulUIExtendedOpLowering
+// MulIExtendedOpLowering
 //===----------------------------------------------------------------------===//
 
-LogicalResult MulUIExtendedOpLowering::matchAndRewrite(
-    arith::MulUIExtendedOp op, OpAdaptor adaptor,
+template <typename ArithMulOp, bool IsSigned>
+LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
+    ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Type resultType = adaptor.getLhs().getType();
 
@@ -308,10 +316,11 @@
     assert(LLVM::isCompatibleType(wideType) &&
            "LLVM dialect should support all signless integer types");
 
-    Value lhsExt =
-        rewriter.create<LLVM::ZExtOp>(loc, wideType, adaptor.getLhs());
-    Value rhsExt =
-        rewriter.create<LLVM::ZExtOp>(loc, wideType, adaptor.getRhs());
+    // The extension kind decides whether the upper N bits of the 2*N-bit
+    // product are the signed (SExt) or the unsigned (ZExt) high half.
+    using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
+    Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
+    Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
     Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
 
     // Split the 2*N-bit wide result into two N-bit values.
@@ -467,6 +476,7 @@
     MinUIOpLowering,
     MulFOpLowering,
     MulIOpLowering,
+    MulSIExtendedOpLowering,
     MulUIExtendedOpLowering,
     NegFOpLowering,
     OrIOpLowering,
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -223,13 +223,13 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts arith.mului_extended to spirv.UMulExtended.
-class MulUIExtendedOpPattern final
-    : public OpConversionPattern<arith::MulUIExtendedOp> {
+/// Converts arith.mul*i_extended to spirv.*MulExtended.
+template <typename ArithMulOp, typename SPIRVMulOp>
+class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<ArithMulOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor,
+  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -955,15 +955,16 @@
 }
 
 //===----------------------------------------------------------------------===//
-// MulUIExtendedOpPattern
+// MulIExtendedOpPattern
 //===----------------------------------------------------------------------===//
 
-LogicalResult MulUIExtendedOpPattern::matchAndRewrite(
-    arith::MulUIExtendedOp op, OpAdaptor adaptor,
+template <typename ArithMulOp, typename SPIRVMulOp>
+LogicalResult MulIExtendedOpPattern<ArithMulOp, SPIRVMulOp>::matchAndRewrite(
+    ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
-  Value result = rewriter.create<spirv::UMulExtendedOp>(loc, adaptor.getLhs(),
-                                                         adaptor.getRhs());
+  Value result =
+      rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
 
   Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
                                                          llvm::makeArrayRef(0));
@@ -1070,7 +1071,10 @@
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern,
-    AddUIExtendedOpPattern, MulUIExtendedOpPattern, SelectOpPattern,
+    AddUIExtendedOpPattern,
+    MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
+    MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
+    SelectOpPattern,
 
     MinMaxFOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
     MinMaxFOpPattern<arith::MinFOp, spirv::GLFMinOp>,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -111,6 +111,17 @@
     Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
         (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
 
+//===----------------------------------------------------------------------===//
+// MulSIExtendedOp
+//===----------------------------------------------------------------------===//
+
+// mulsi_extended(x, y) -> [muli(x, y), x], when the `high` result is unused.
+// Since the `high` result is not used, any replacement value will do.
+def MulSIExtendedToMulI :
+    Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
+        [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+        [(Constraint<CPred<"$0.use_empty()">> $res__1)]>;
+
 //===----------------------------------------------------------------------===//
 // MulUIExtendedOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -350,6 +350,52 @@
       operands, [](const APInt &a, const APInt &b) { return a * b; });
 }
 
+//===----------------------------------------------------------------------===//
+// MulSIExtendedOp
+//===----------------------------------------------------------------------===//
+
+Optional<SmallVector<int64_t, 4>> arith::MulSIExtendedOp::getShapeForUnroll() {
+  if (auto vt = getType(0).dyn_cast<VectorType>())
+    return llvm::to_vector<4>(vt.getShape());
+  return std::nullopt;
+}
+
+LogicalResult
+arith::MulSIExtendedOp::fold(ArrayRef<Attribute> operands,
+                             SmallVectorImpl<OpFoldResult> &results) {
+  // mulsi_extended(x, 0) -> 0, 0
+  if (matchPattern(getRhs(), m_Zero())) {
+    Attribute zero = operands[1];
+    results.push_back(zero);
+    results.push_back(zero);
+    return success();
+  }
+
+  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
+  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
+          operands, [](const APInt &a, const APInt &b) { return a * b; })) {
+    // Invoke the constant fold helper again to calculate the 'high' result.
+    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
+        operands, [](const APInt &a, const APInt &b) {
+          unsigned bitWidth = a.getBitWidth();
+          APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+          return fullProduct.extractBits(bitWidth, bitWidth);
+        });
+    assert(highAttr && "Unexpected constant-folding failure");
+
+    results.push_back(lowAttr);
+    results.push_back(highAttr);
+    return success();
+  }
+
+  return failure();
+}
+
+void arith::MulSIExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<MulSIExtendedToMulI>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MulUIExtendedOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -379,6 +379,38 @@
 
 // -----
 
+// CHECK-LABEL: @mulsi_extended_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32)
+func.func @mulsi_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i32) {
+  // CHECK-NEXT: [[LHS:%.+]] = llvm.sext [[ARG0]] : i32 to i64
+  // CHECK-NEXT: [[RHS:%.+]] = llvm.sext [[ARG1]] : i32 to i64
+  // CHECK-NEXT: [[MUL:%.+]] = llvm.mul [[LHS]], [[RHS]] : i64
+  // CHECK-NEXT: [[LOW:%.+]] = llvm.trunc [[MUL]] : i64 to i32
+  // CHECK-NEXT: [[C32:%.+]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-NEXT: [[SHR:%.+]] = llvm.lshr [[MUL]], [[C32]] : i64
+  // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHR]] : i64 to i32
+  %low, %high = arith.mulsi_extended %arg0, %arg1 : i32
+  // CHECK-NEXT: return [[LOW]], [[HIGH]] : i32, i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mulsi_extended_vector1d
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi64>, [[ARG1:%.+]]: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>)
+func.func @mulsi_extended_vector1d(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>) {
+  // CHECK-NEXT: [[LHS:%.+]] = llvm.sext [[ARG0]] : vector<3xi64> to vector<3xi128>
+  // CHECK-NEXT: [[RHS:%.+]] = llvm.sext [[ARG1]] : vector<3xi64> to vector<3xi128>
+  // CHECK-NEXT: [[MUL:%.+]] = llvm.mul [[LHS]], [[RHS]] : vector<3xi128>
+  // CHECK-NEXT: [[LOW:%.+]] = llvm.trunc [[MUL]] : vector<3xi128> to vector<3xi64>
+  // CHECK-NEXT: [[C64:%.+]] = llvm.mlir.constant(dense<64> : vector<3xi128>) : vector<3xi128>
+  // CHECK-NEXT: [[SHR:%.+]] = llvm.lshr [[MUL]], [[C64]] : vector<3xi128>
+  // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHR]] : vector<3xi128> to vector<3xi64>
+  %low, %high = arith.mulsi_extended %arg0, %arg1 : vector<3xi64>
+  // CHECK-NEXT: return [[LOW]], [[HIGH]] : vector<3xi64>, vector<3xi64>
+  return %low, %high : vector<3xi64>, vector<3xi64>
+}
+
+// -----
+
 // CHECK-LABEL: @mului_extended_scalar
 // CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32)
 func.func @mului_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i32) {
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -99,6 +99,29 @@
   return %sum, %overflow : vector<4xi32>, vector<4xi1>
 }
 
+// Check extended signed integer multiplication conversions.
+// CHECK-LABEL: @int32_scalar_mulsi_extended
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_mulsi_extended(%lhs: i32, %rhs: i32) -> (i32, i32) {
+  // CHECK-NEXT: %[[MUL:.+]] = spirv.SMulExtended %[[LHS]], %[[RHS]] : !spirv.struct<(i32, i32)>
+  // CHECK-DAG: %[[LOW:.+]] = spirv.CompositeExtract %[[MUL]][0 : i32] : !spirv.struct<(i32, i32)>
+  // CHECK-DAG: %[[HIGH:.+]] = spirv.CompositeExtract %[[MUL]][1 : i32] : !spirv.struct<(i32, i32)>
+  // CHECK-NEXT: return %[[LOW]], %[[HIGH]] : i32, i32
+  %low, %high = arith.mulsi_extended %lhs, %rhs: i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @int32_vector_mulsi_extended
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
+func.func @int32_vector_mulsi_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+  // CHECK-NEXT: %[[MUL:.+]] = spirv.SMulExtended %[[LHS]], %[[RHS]] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG: %[[LOW:.+]] = spirv.CompositeExtract %[[MUL]][0 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG: %[[HIGH:.+]] = spirv.CompositeExtract %[[MUL]][1 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-NEXT: return %[[LOW]], %[[HIGH]] : vector<4xi32>, vector<4xi32>
+  %low, %high = arith.mulsi_extended %lhs, %rhs: vector<4xi32>
+  return %low, %high : vector<4xi32>, vector<4xi32>
+}
+
 // Check extended unsigned integer multiplication conversions.
 // CHECK-LABEL: @int32_scalar_mului_extended
 // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -734,6 +734,64 @@
   return %sum, %overflow : vector<4xi32>, vector<4xi1>
 }
 
+// CHECK-LABEL: @mulsiExtendedZeroRhs
+// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[zero]], %[[zero]]
+func.func @mulsiExtendedZeroRhs(%arg0: i32) -> (i32, i32) {
+  %zero = arith.constant 0 : i32
+  %low, %high = arith.mulsi_extended %arg0, %zero: i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mulsiExtendedZeroRhsSplat
+// CHECK-NEXT: %[[zero:.+]] = arith.constant dense<0> : vector<3xi32>
+// CHECK-NEXT: return %[[zero]], %[[zero]]
+func.func @mulsiExtendedZeroRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+  %zero = arith.constant dense<0> : vector<3xi32>
+  %low, %high = arith.mulsi_extended %arg0, %zero: vector<3xi32>
+  return %low, %high : vector<3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: @mulsiExtendedZeroLhs
+// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[zero]], %[[zero]]
+func.func @mulsiExtendedZeroLhs(%arg0: i32) -> (i32, i32) {
+  %zero = arith.constant 0 : i32
+  %low, %high = arith.mulsi_extended %zero, %arg0: i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mulsiExtendedUnusedHigh
+// CHECK-SAME: (%[[ARG:.+]]: i32) -> i32
+// CHECK-NEXT: %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32
+// CHECK-NEXT: return %[[RES]]
+func.func @mulsiExtendedUnusedHigh(%arg0: i32) -> i32 {
+  %low, %high = arith.mulsi_extended %arg0, %arg0: i32
+  return %low : i32
+}
+
+// CHECK-LABEL: @mulsiExtendedScalarConstants
+// CHECK-DAG: %[[c27:.+]] = arith.constant 27 : i8
+// CHECK-DAG: %[[c_n3:.+]] = arith.constant -3 : i8
+// CHECK-NEXT: return %[[c27]], %[[c_n3]]
+func.func @mulsiExtendedScalarConstants() -> (i8, i8) {
+  %c57 = arith.constant 57 : i8
+  %c_n13 = arith.constant -13 : i8
+  %low, %high = arith.mulsi_extended %c57, %c_n13: i8
+  return %low, %high : i8, i8
+}
+
+// CHECK-LABEL: @mulsiExtendedVectorConstants
+// CHECK-DAG: %[[cstLo:.+]] = arith.constant dense<[65, 79, 34]> : vector<3xi8>
+// CHECK-DAG: %[[cstHi:.+]] = arith.constant dense<[0, 14, 0]> : vector<3xi8>
+// CHECK-NEXT: return %[[cstLo]], %[[cstHi]]
+func.func @mulsiExtendedVectorConstants() -> (vector<3xi8>, vector<3xi8>) {
+  %cstA = arith.constant dense<[5, 37, -17]> : vector<3xi8>
+  %cstB = arith.constant dense<[13, 99, -2]> : vector<3xi8>
+  %low, %high = arith.mulsi_extended %cstA, %cstB: vector<3xi8>
+  return %low, %high : vector<3xi8>, vector<3xi8>
+}
+
 // CHECK-LABEL: @muluiExtendedZeroRhs
 // CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32
 // CHECK-NEXT: return %[[zero]], %[[zero]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -97,6 +97,30 @@
   return %0 : vector<[8]xi64>
 }
 
+// CHECK-LABEL: test_mulsi_extended
+func.func @test_mulsi_extended(%arg0 : i32, %arg1 : i32) -> i32 {
+  %low, %high = arith.mulsi_extended %arg0, %arg1 : i32
+  return %high : i32
+}
+
+// CHECK-LABEL: test_mulsi_extended_tensor
+func.func @test_mulsi_extended_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+  %low, %high = arith.mulsi_extended %arg0, %arg1 : tensor<8x8xi64>
+  return %high : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_mulsi_extended_vector
+func.func @test_mulsi_extended_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+  %0:2 = arith.mulsi_extended %arg0, %arg1 : vector<8xi64>
+  return %0#0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_mulsi_extended_scalable_vector
+func.func @test_mulsi_extended_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0:2 = arith.mulsi_extended %arg0, %arg1 : vector<[8]xi64>
+  return %0#1 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_mului_extended
 func.func @test_mului_extended(%arg0 : i32, %arg1 : i32) -> i32 {
   %low, %high = arith.mului_extended %arg0, %arg1 : i32