diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -65,6 +65,14 @@
     "Type":$elementType
   );
 
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape, "Type":$elementType
+    ), [{
+      return $_get(elementType.getContext(), shape, elementType);
+    }]>
+  ];
+
   let printer = [{
     $_printer << "vector<";
     for (int64_t dim : getShape())
@@ -76,7 +84,7 @@
     VectorType vector;
     if ($_parser.parseType(vector))
       return Type();
-    return get($_ctxt, vector.getShape(), vector.getElementType());
+    return get(vector.getShape(), vector.getElementType());
   }];
 
   let extraClassDeclaration = [{
@@ -93,6 +101,9 @@
       return num;
     }
   }];
+
+  let skipDefaultBuilders = 1;
+  let genVerifyDecl = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -266,6 +277,222 @@
     "attr-dict `:` type($res)";
 }
 
+def ScalableAddIOp : ArmSVE_Op<"addi",
+               [Commutative,
+                AllTypesMatch<["src1", "src2", "dst"]>
+               ]> {
+  let summary = "addition for scalable vectors of integers";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arm_sve.addi` ssa-use `,` ssa-use `:` !arm_sve.vector
+    ```
+
+    The `arm_sve.addi` operation takes two scalable vectors and returns one
+    scalable vector with the result of the element-wise addition.
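+
+    Example, mirroring the roundtrip test added in this patch:
+
+    ```mlir
+    %0 = arm_sve.addi %a, %b : !arm_sve.vector<4xi32>
+    ```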
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+def ScalableAddFOp : ArmSVE_Op<"addf",
+               [Commutative,
+                AllTypesMatch<["src1", "src2", "dst"]>
+               ]> {
+  let summary = "addition for scalable vectors of floats";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arm_sve.addf` ssa-use `,` ssa-use `:` !arm_sve.vector
+    ```
+
+    The `arm_sve.addf` operation takes two scalable vectors and returns one
+    scalable vector with the result of the element-wise addition.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[AnyFloat]>:$src1,
+          ScalableVectorOf<[AnyFloat]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+def ScalableSubIOp : ArmSVE_Op<"subi",
+               [AllTypesMatch<["src1", "src2", "dst"]>
+               ]> {
+  let summary = "subtraction for scalable vectors of integers";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arm_sve.subi` ssa-use `,` ssa-use `:` !arm_sve.vector
+    ```
+
+    The `arm_sve.subi` operation takes two scalable vectors and returns one
+    scalable vector with the result of the element-wise subtraction of `src2`
+    from `src1`.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+def ScalableSubFOp : ArmSVE_Op<"subf",
+               [AllTypesMatch<["src1", "src2", "dst"]>
+               ]> {
+  let summary = "subtraction for scalable vectors of floats";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arm_sve.subf` ssa-use `,` ssa-use `:` !arm_sve.vector
+    ```
+
+    The `arm_sve.subf` operation takes two scalable vectors and returns one
+    scalable vector with the result of the element-wise subtraction of `src2`
+    from `src1`.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[AnyFloat]>:$src1,
+          ScalableVectorOf<[AnyFloat]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+def ScalableMulIOp : ArmSVE_Op<"muli",
+               [Commutative,
+                AllTypesMatch<["src1", "src2", "dst"]>
+               ]> {
+  let summary = "multiplication for scalable vectors of integers";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arm_sve.muli` ssa-use `,` ssa-use `:` !arm_sve.vector
+    ```
+
+    The `arm_sve.muli` operation takes two scalable vectors and returns one
+    scalable vector with the result of the element-wise multiplication.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+def ScalableMulFOp : ArmSVE_Op<"mulf",
+               [Commutative,
+                AllTypesMatch<["src1", "src2", "dst"]>
+               ]> {
+  let summary = "multiplication for scalable vectors of floats";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arm_sve.mulf` ssa-use `,` ssa-use `:` !arm_sve.vector
+    ```
+
+    The `arm_sve.mulf` operation takes two scalable vectors and returns one
+    scalable vector with the result of the element-wise multiplication.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[AnyFloat]>:$src1,
+          ScalableVectorOf<[AnyFloat]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+def ScalableSDivIOp : ArmSVE_Op<"divi_signed",
+               [AllTypesMatch<["src1", "src2", "dst"]>
+               ]> {
+  let summary = "signed division for scalable vectors of integers";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arm_sve.divi_signed` ssa-use `,` ssa-use `:` !arm_sve.vector
+    ```
+
+    The `arm_sve.divi_signed` operation takes two scalable vectors and returns
+    one scalable vector with the result of the element-wise signed division of
+    `src1` by `src2`.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+def ScalableUDivIOp : ArmSVE_Op<"divi_unsigned",
+               [AllTypesMatch<["src1", "src2", "dst"]>
+               ]> {
+  let summary = "unsigned division for scalable vectors of integers";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arm_sve.divi_unsigned` ssa-use `,` ssa-use `:` !arm_sve.vector
+    ```
+
+    The `arm_sve.divi_unsigned` operation takes two scalable vectors and
+    returns one scalable vector with the result of the element-wise unsigned
+    division of `src1` by `src2`.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
+def ScalableDivFOp : ArmSVE_Op<"divf",
+               [AllTypesMatch<["src1", "src2", "dst"]>
+               ]> {
+  let summary = "division for scalable vectors of floats";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arm_sve.divf` ssa-use `,` ssa-use `:` !arm_sve.vector
+    ```
+
+    The `arm_sve.divf` operation takes two scalable vectors and returns one
+    scalable vector with the result of the element-wise division of `src1` by
+    `src2`.
+  }];
+  let arguments = (ins
+          ScalableVectorOf<[AnyFloat]>:$src1,
+          ScalableVectorOf<[AnyFloat]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
+  let assemblyFormat =
+    "$src1 `,` $src2 attr-dict `:` type($src1)";
+}
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
   Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
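With `genVerifyDecl` set on the type, the verifier added below restricts `!arm_sve.vector` to exactly one dimension with a positive constant size. A minimal sketch of what that check accepts and rejects (the invalid types are illustrative, not taken from this patch's tests):

```mlir
// Accepted: one dimension, positive constant size.
!arm_sve.vector<4xi32>

// Rejected: "scalable vector types must have exactly one dimension".
!arm_sve.vector<2x4xi32>

// Rejected: "scalable vector types must have a positive constant size".
!arm_sve.vector<0xi32>
```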
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -42,6 +42,20 @@
 // ScalableVectorType
 //===----------------------------------------------------------------------===//
 
+LogicalResult arm_sve::ScalableVectorType::verify(
+    function_ref<InFlightDiagnostic()> emitError, ArrayRef<int64_t> shape,
+    Type elementType) {
+  if (shape.empty() || shape.size() > 1)
+    return emitError()
+           << "scalable vector types must have exactly one dimension";
+
+  if (shape[0] <= 0)
+    return emitError()
+           << "scalable vector types must have a positive constant size";
+
+  return success();
+}
+
 Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const {
   llvm::SMLoc typeLoc = parser.getCurrentLocation();
   {
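The export patterns that follow map each new op one-to-one onto its LLVM dialect counterpart (`addi` to `llvm.add`, `addf` to `llvm.fadd`, and so on). A sketch of the effect on IR, with operand names assumed and types taken from the tests in this patch:

```mlir
// Before legalization, ArmSVE dialect:
%0 = arm_sve.addi %a, %b : !arm_sve.vector<4xi32>

// After legalization for LLVM export, LLVM dialect:
%0 = llvm.add %a, %b : !llvm.vec<? x 4 x i32>
```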
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -84,6 +84,38 @@
 using VectorScaleOpLowering =
     OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
 
+static void
+populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
+                                         OwningRewritePatternList &patterns) {
+  // clang-format off
+  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
+               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
+               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
+               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
+               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
+               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
+               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
+               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
+               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
+              >(converter);
+  // clang-format on
+}
+
+static void
+configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
+  // clang-format off
+  target.addIllegalOp<ScalableAddIOp,
+                      ScalableAddFOp,
+                      ScalableSubIOp,
+                      ScalableSubFOp,
+                      ScalableMulIOp,
+                      ScalableMulFOp,
+                      ScalableSDivIOp,
+                      ScalableUDivIOp,
+                      ScalableDivFOp>();
+  // clang-format on
+}
+
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
@@ -106,20 +138,14 @@
                UmmlaOpLowering,
                VectorScaleOpLowering>(converter);
   // clang-format on
+  populateBasicSVEArithmeticExportPatterns(converter, patterns);
 }
 
 void mlir::configureArmSVELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp<SdotIntrOp>();
-  target.addIllegalOp<SdotOp>();
-  target.addLegalOp<SmmlaIntrOp>();
-  target.addIllegalOp<SmmlaOp>();
-  target.addLegalOp<UdotIntrOp>();
-  target.addIllegalOp<UdotOp>();
-  target.addLegalOp<UmmlaIntrOp>();
-  target.addIllegalOp<UmmlaOp>();
-  target.addLegalOp<VectorScaleIntrOp>();
-  target.addIllegalOp<VectorScaleOp>();
+  target.addLegalOp<SdotIntrOp, SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp,
+                    VectorScaleIntrOp>();
+  target.addIllegalOp<SdotOp, SmmlaOp, UdotOp, UmmlaOp, VectorScaleOp>();
   auto hasScalableVectorType = [](TypeRange types) {
     for (Type type : types)
       if (type.isa<arm_sve::ScalableVectorType>())
         return true;
@@ -135,4 +161,5 @@
     return !hasScalableVectorType(op->getOperandTypes()) &&
            !hasScalableVectorType(op->getResultTypes());
   });
+  configureBasicSVEArithmeticLegalizations(target);
 }
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -40,6 +40,40 @@
   return %0 : !arm_sve.vector<4xi32>
 }
 
+func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
+                     %b: !arm_sve.vector<4xi32>,
+                     %c: !arm_sve.vector<4xi32>,
+                     %d: !arm_sve.vector<4xi32>,
+                     %e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm.mul
+  %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
+  // CHECK: llvm.add
+  %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
+  // CHECK: llvm.sub
+  %2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32>
+  // CHECK: llvm.sdiv
+  %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32>
+  // CHECK: llvm.udiv
+  %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32>
+  return %3 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
+                     %b: !arm_sve.vector<4xf32>,
+                     %c: !arm_sve.vector<4xf32>,
+                     %d: !arm_sve.vector<4xf32>,
+                     %e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
+  // CHECK: llvm.fmul
+  %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
+  // CHECK: llvm.fadd
+  %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
+  // CHECK: llvm.fsub
+  %2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32>
+  // CHECK: llvm.fdiv
+  %3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32>
+  return %3 : !arm_sve.vector<4xf32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vscale
   %0 = arm_sve.vector_scale : index
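The roundtrip test below only exercises `addi`/`addf` and `muli`/`mulf`; the remaining ops declare the identical assembly format, so their round trips would look the same, for example (a sketch, not part of the test):

```mlir
%0 = arm_sve.subi %a, %b : !arm_sve.vector<4xi32>
%1 = arm_sve.divf %c, %d : !arm_sve.vector<4xf32>
```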
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -36,6 +36,26 @@
   return %0 : !arm_sve.vector<4xi32>
 }
 
+func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
+                     %b: !arm_sve.vector<4xi32>,
+                     %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.muli {{.*}}: !arm_sve.vector<4xi32>
+  %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32>
+  %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
+  return %1 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
+                     %b: !arm_sve.vector<4xf32>,
+                     %c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
+  // CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32>
+  %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32>
+  %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
+  return %1 : !arm_sve.vector<4xf32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vector_scale : index
   %0 = arm_sve.vector_scale : index
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,30 @@
   llvm.return %0 : !llvm.vec<? x 4 x i32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
+llvm.func @arm_sve_arithi(%arg0: !llvm.vec<? x 4 x i32>,
+                          %arg1: !llvm.vec<? x 4 x i32>,
+                          %arg2: !llvm.vec<? x 4 x i32>)
+                          -> !llvm.vec<? x 4 x i32> {
+  // CHECK: mul <vscale x 4 x i32>
+  %0 = llvm.mul %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+  // CHECK: add <vscale x 4 x i32>
+  %1 = llvm.add %0, %arg2 : !llvm.vec<? x 4 x i32>
+  llvm.return %1 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf
+llvm.func @arm_sve_arithf(%arg0: !llvm.vec<? x 4 x f32>,
+                          %arg1: !llvm.vec<? x 4 x f32>,
+                          %arg2: !llvm.vec<? x 4 x f32>)
+                          -> !llvm.vec<? x 4 x f32> {
+  // CHECK: fmul <vscale x 4 x float>
+  %0 = llvm.fmul %arg0, %arg1 : !llvm.vec<? x 4 x f32>
+  // CHECK: fadd <vscale x 4 x float>
+  %1 = llvm.fadd %0, %arg2 : !llvm.vec<? x 4 x f32>
+  llvm.return %1 : !llvm.vec<? x 4 x f32>
+}
+
 // CHECK-LABEL: define i64 @get_vector_scale()
 llvm.func @get_vector_scale() -> i64 {
   // CHECK: call i64 @llvm.vscale.i64()
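For reference, the same scalable vector is spelled three ways across these tests, one per level of the lowering (a summary of the tests above, not new syntax):

```mlir
// ArmSVE dialect:     !arm_sve.vector<4xi32>
// LLVM dialect:       !llvm.vec<? x 4 x i32>
// Translated LLVM IR: <vscale x 4 x i32>
```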