diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -122,6 +122,42 @@
                   /*list<OpTrait> traits=*/traits,
                   /*int numResults=*/1>;
 
+class ScalableFOp<string mnemonic, string op_description,
+                  list<OpTrait> traits = []> :
+  ArmSVE_Op<mnemonic, !listconcat(traits,
+                       [AllTypesMatch<["src1", "src2", "dst"]>])> {
+  let summary = op_description # " for scalable vectors of floats";
+  let description = [{
+    The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
+    returns one scalable vector with the result of the }] # op_description # [{.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[AnyFloat]>:$src1,
+          ScalableVectorOf<[AnyFloat]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+class ScalableIOp<string mnemonic, string op_description,
+                  list<OpTrait> traits = []> :
+  ArmSVE_Op<mnemonic, !listconcat(traits,
+                       [AllTypesMatch<["src1", "src2", "dst"]>])> {
+  let summary = op_description # " for scalable vectors of integers";
+  let description = [{
+    The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
+    returns one scalable vector with the result of the }] # op_description # [{.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
 def SdotOp : ArmSVE_Op<"sdot",
                [NoSideEffect,
                AllTypesMatch<["src1", "src2"]>,
@@ -266,6 +302,25 @@
                    "attr-dict `:` type($res)";
 }
 
+
+def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>;
+
+def ScalableAddFOp : ScalableFOp<"addf", "addition", [Commutative]>;
+
+def ScalableSubIOp : ScalableIOp<"subi", "subtraction">;
+
+def ScalableSubFOp : ScalableFOp<"subf", "subtraction">;
+
+def ScalableMulIOp : ScalableIOp<"muli", "multiplication", [Commutative]>;
+
+def ScalableMulFOp : ScalableFOp<"mulf", "multiplication", [Commutative]>;
+
+def ScalableSDivIOp : ScalableIOp<"divi_signed", "signed division">;
+
+def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">;
+
+def ScalableDivFOp : ScalableFOp<"divf", "division">;
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
   Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -84,6 +84,38 @@
 using VectorScaleOpLowering =
     OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
 
+static void
+populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
+                                         OwningRewritePatternList &patterns) {
+  // clang-format off
+  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
+               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
+               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
+               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
+               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
+               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
+               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
+               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
+               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
+              >(converter);
+  // clang-format on
+}
+
+static void
+configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
+  // clang-format off
+  target.addIllegalOp<ScalableAddIOp,
+                      ScalableAddFOp,
+                      ScalableSubIOp,
+                      ScalableSubFOp,
+                      ScalableMulIOp,
+                      ScalableMulFOp,
+                      ScalableSDivIOp,
+                      ScalableUDivIOp,
+                      ScalableDivFOp>();
+  // clang-format on
+}
+
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
@@ -106,20 +138,14 @@
                UmmlaOpLowering,
                VectorScaleOpLowering>(converter);
   // clang-format on
+  populateBasicSVEArithmeticExportPatterns(converter, patterns);
 }
 
 void mlir::configureArmSVELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp<SdotIntrOp>();
-  target.addIllegalOp<SdotOp>();
-  target.addLegalOp<SmmlaIntrOp>();
-  target.addIllegalOp<SmmlaOp>();
-  target.addLegalOp<UdotIntrOp>();
-  target.addIllegalOp<UdotOp>();
-  target.addLegalOp<UmmlaIntrOp>();
-  target.addIllegalOp<UmmlaOp>();
-  target.addLegalOp<VectorScaleIntrOp>();
-  target.addIllegalOp<VectorScaleOp>();
+  target.addLegalOp<SdotIntrOp, SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp,
+                    VectorScaleIntrOp>();
+  target.addIllegalOp<SdotOp, SmmlaOp, UdotOp, UmmlaOp, VectorScaleOp>();
   auto hasScalableVectorType = [](TypeRange types) {
     for (Type type : types)
       if (type.isa<ScalableVectorType>())
@@ -135,4 +161,5 @@
     return !hasScalableVectorType(op->getOperandTypes()) &&
            !hasScalableVectorType(op->getResultTypes());
   });
+  configureBasicSVEArithmeticLegalizations(target);
 }
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -40,6 +40,40 @@
   return %0 : !arm_sve.vector<4xi32>
 }
 
+func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
+                     %b: !arm_sve.vector<4xi32>,
+                     %c: !arm_sve.vector<4xi32>,
+                     %d: !arm_sve.vector<4xi32>,
+                     %e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm.mul {{.*}}: !llvm.vec<? x 4 x i32>
+  %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
+  // CHECK: llvm.add {{.*}}: !llvm.vec<? x 4 x i32>
+  %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
+  // CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
+  %2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32>
+  // CHECK: llvm.sdiv {{.*}}: !llvm.vec<? x 4 x i32>
+  %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32>
+  // CHECK: llvm.udiv {{.*}}: !llvm.vec<? x 4 x i32>
+  %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32>
+  return %3 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
+                     %b: !arm_sve.vector<4xf32>,
+                     %c: !arm_sve.vector<4xf32>,
+                     %d: !arm_sve.vector<4xf32>,
+                     %e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
+  // CHECK: llvm.fmul {{.*}}: !llvm.vec<? x 4 x f32>
+  %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
+  // CHECK: llvm.fadd {{.*}}: !llvm.vec<? x 4 x f32>
+  %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
+  // CHECK: llvm.fsub {{.*}}: !llvm.vec<? x 4 x f32>
+  %2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32>
+  // CHECK: llvm.fdiv {{.*}}: !llvm.vec<? x 4 x f32>
+  %3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32>
+  return %3 : !arm_sve.vector<4xf32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vscale
   %0 = arm_sve.vector_scale : index
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -36,6 +36,26 @@
   return %0 : !arm_sve.vector<4xi32>
 }
 
+func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
+                     %b: !arm_sve.vector<4xi32>,
+                     %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.muli {{.*}}: !arm_sve.vector<4xi32>
+  %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32>
+  %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
+  return %1 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
+                     %b: !arm_sve.vector<4xf32>,
+                     %c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
+  // CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32>
+  %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32>
+  %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
+  return %1 : !arm_sve.vector<4xf32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vector_scale : index
   %0 = arm_sve.vector_scale : index
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,30 @@
   llvm.return %0 : !llvm.vec<? x 4 x i32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
+llvm.func @arm_sve_arithi(%arg0: !llvm.vec<? x 4 x i32>,
+                          %arg1: !llvm.vec<? x 4 x i32>,
+                          %arg2: !llvm.vec<? x 4 x i32>)
+                          -> !llvm.vec<? x 4 x i32> {
+  // CHECK: mul <vscale x 4 x i32>
+  %0 = llvm.mul %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+  // CHECK: add <vscale x 4 x i32>
+  %1 = llvm.add %0, %arg2 : !llvm.vec<? x 4 x i32>
+  llvm.return %1 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf
+llvm.func @arm_sve_arithf(%arg0: !llvm.vec<? x 4 x f32>,
+                          %arg1: !llvm.vec<? x 4 x f32>,
+                          %arg2: !llvm.vec<? x 4 x f32>)
+                          -> !llvm.vec<? x 4 x f32> {
+  // CHECK: fmul <vscale x 4 x float>
+  %0 = llvm.fmul %arg0, %arg1 : !llvm.vec<? x 4 x f32>
+  // CHECK: fadd <vscale x 4 x float>
+  %1 = llvm.fadd %0, %arg2 : !llvm.vec<? x 4 x f32>
+  llvm.return %1 : !llvm.vec<? x 4 x f32>
+}
+
 // CHECK-LABEL: define i64 @get_vector_scale()
 llvm.func @get_vector_scale() -> i64 {
   // CHECK: call i64 @llvm.vscale.i64()
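
How the pieces fit together: each `def` instantiates one of the two new TableGen
classes, populateBasicSVEArithmeticExportPatterns maps every new op one-to-one
onto its LLVM dialect counterpart, and configureBasicSVEArithmeticLegalizations
marks the ArmSVE-level ops illegal so the conversion must apply. As a sketch of
how the set would grow later (a hypothetical signed-remainder op, not part of
this patch), the three touch points would be:

  // ArmSVE.td: a one-line definition reusing the ScalableIOp class.
  def ScalableRemIOp : ScalableIOp<"remi_signed", "signed remainder">;

  // LegalizeForLLVMExport.cpp: a one-to-one lowering to llvm.srem ...
  OneToOneConvertToLLVMPattern<ScalableRemIOp, LLVM::SRemOp>

  // ... plus a ScalableRemIOp entry in the addIllegalOp<...> list of
  // configureBasicSVEArithmeticLegalizations.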