diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -398,6 +398,165 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">; +//===----------------------------------------------------------------------===// +// ScalableCmpFOp +//===----------------------------------------------------------------------===// + +// The predicate indicates the type of the comparison to perform: +// (un)orderedness, (in)equality and less/greater than (or equal to) as +// well as predicates that are always true or false. +def CMPF_P_FALSE : I64EnumAttrCase<"AlwaysFalse", 0, "false">; +def CMPF_P_OEQ : I64EnumAttrCase<"OEQ", 1, "oeq">; +def CMPF_P_OGT : I64EnumAttrCase<"OGT", 2, "ogt">; +def CMPF_P_OGE : I64EnumAttrCase<"OGE", 3, "oge">; +def CMPF_P_OLT : I64EnumAttrCase<"OLT", 4, "olt">; +def CMPF_P_OLE : I64EnumAttrCase<"OLE", 5, "ole">; +def CMPF_P_ONE : I64EnumAttrCase<"ONE", 6, "one">; +def CMPF_P_ORD : I64EnumAttrCase<"ORD", 7, "ord">; +def CMPF_P_UEQ : I64EnumAttrCase<"UEQ", 8, "ueq">; +def CMPF_P_UGT : I64EnumAttrCase<"UGT", 9, "ugt">; +def CMPF_P_UGE : I64EnumAttrCase<"UGE", 10, "uge">; +def CMPF_P_ULT : I64EnumAttrCase<"ULT", 11, "ult">; +def CMPF_P_ULE : I64EnumAttrCase<"ULE", 12, "ule">; +def CMPF_P_UNE : I64EnumAttrCase<"UNE", 13, "une">; +def CMPF_P_UNO : I64EnumAttrCase<"UNO", 14, "uno">; +def CMPF_P_TRUE : I64EnumAttrCase<"AlwaysTrue", 15, "true">; + +def CmpFPredicateAttr : I64EnumAttr< + "CmpFPredicate", "", + [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE, + CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT, + CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> { + let cppNamespace = "::mlir::arm_sve"; +} + +def ScalableCmpFOp : ArmSVE_Op<"cmpf", [NoSideEffect, SameTypeOperands, + TypesMatchWith<"result type has i1 element type and same shape as operands", 
"lhs", "result", "getI1SameShape($_self)">]> { + let summary = "floating-point comparison operation for scalable vectors"; + let description = [{ + The `arm_sve.cmpf` operation compares two scalable vectors of floating point + elements according to the float comparison rules and the predicate specified + by the respective attribute. The predicate defines the type of comparison: + (un)orderedness, (in)equality and less/greater than (or equal to) as well + as predicates that are always true or false. The result is a scalable + vector of i1 elements. Unlike `arm_sve.cmpi`, there are no signed/unsigned + variants: the `u` prefix indicates an *unordered* comparison, not an + unsigned one, so `une` means "unordered or not equal". + + For readability, the custom assembly form spells the predicate as a + lower-case keyword, e.g., `one` means "ordered not equal". The keyword is + merely syntactic sugar: the parser converts it to an integer attribute, + which is how the predicate is stored. 
+ + Example: + + ```mlir + %r = arm_sve.cmpf oeq, %0, %1 : !arm_sve.vector<4xf32> + ``` + }]; + let arguments = (ins + CmpFPredicateAttr:$predicate, + ScalableVectorOf<[AnyFloat]>:$lhs, + ScalableVectorOf<[AnyFloat]>:$rhs // TODO: This should support a simple scalar + ); + let results = (outs ScalableVectorOf<[I1]>:$result); + + let builders = [ + OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, + "Value":$rhs), [{ + buildScalableCmpFOp($_builder, $_state, predicate, lhs, rhs); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getPredicateAttrName() { return "predicate"; } + static CmpFPredicate getPredicateByName(StringRef name); + + CmpFPredicate getPredicate() { + return (CmpFPredicate)(*this)->getAttrOfType<IntegerAttr>( + getPredicateAttrName()).getInt(); + } + }]; + + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; +} + +//===----------------------------------------------------------------------===// +// ScalableCmpIOp +//===----------------------------------------------------------------------===// + +def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>; +def CMPI_P_NE : I64EnumAttrCase<"ne", 1>; +def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>; +def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>; +def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>; +def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>; +def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>; +def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>; +def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>; +def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>; + +def CmpIPredicateAttr : I64EnumAttr< + "CmpIPredicate", "", + [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT, + CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> { + let cppNamespace = "::mlir::arm_sve"; +} + +def ScalableCmpIOp : ArmSVE_Op<"cmpi", [NoSideEffect, SameTypeOperands, + TypesMatchWith<"result type has i1 element type and same shape as operands", + "lhs", "result", "getI1SameShape($_self)">]> { + let summary 
= "integer comparison operation for scalable vectors"; + let description = [{ + The `arm_sve.cmpi` operation compares two scalable vectors of integer + elements according to the predicate specified by the respective attribute. + + The predicate defines the type of comparison: + + - equal (mnemonic: `"eq"`; integer value: `0`) + - not equal (mnemonic: `"ne"`; integer value: `1`) + - signed less than (mnemonic: `"slt"`; integer value: `2`) + - signed less than or equal (mnemonic: `"sle"`; integer value: `3`) + - signed greater than (mnemonic: `"sgt"`; integer value: `4`) + - signed greater than or equal (mnemonic: `"sge"`; integer value: `5`) + - unsigned less than (mnemonic: `"ult"`; integer value: `6`) + - unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`) + - unsigned greater than (mnemonic: `"ugt"`; integer value: `8`) + - unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`) + + Example: + + ```mlir + %r = arm_sve.cmpi uge, %0, %1 : !arm_sve.vector<4xi32> + ``` + }]; + + let arguments = (ins + CmpIPredicateAttr:$predicate, + ScalableVectorOf<[I8, I16, I32, I64]>:$lhs, + ScalableVectorOf<[I8, I16, I32, I64]>:$rhs + ); + let results = (outs ScalableVectorOf<[I1]>:$result); + + let builders = [ + OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, + "Value":$rhs), [{ + buildScalableCmpIOp($_builder, $_state, predicate, lhs, rhs); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getPredicateAttrName() { return "predicate"; } + static CmpIPredicate getPredicateByName(StringRef name); + + CmpIPredicate getPredicate() { + return (CmpIPredicate)(*this)->getAttrOfType<IntegerAttr>( + getPredicateAttrName()).getInt(); + } + }]; + + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; +} + def UmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"ummla">, Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, diff --git
a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h @@ -19,6 +19,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc" +#include "mlir/Dialect/ArmSVE/ArmSVEOpsEnums.h.inc" #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc" diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt --- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt @@ -3,4 +3,6 @@ set(LLVM_TARGET_DEFINITIONS ArmSVE.td) mlir_tablegen(ArmSVEConversions.inc -gen-llvmir-conversions) +mlir_tablegen(ArmSVEOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(ArmSVEOpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRArmSVEConversionsIncGen) diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -19,9 +19,18 @@ #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/TypeSwitch.h" +#include "mlir/Dialect/ArmSVE/ArmSVEOpsEnums.cpp.inc" + using namespace mlir; +using namespace arm_sve; static Type getI1SameShape(Type type); +static void buildScalableCmpIOp(OpBuilder &build, OperationState &result, + arm_sve::CmpIPredicate predicate, Value lhs, + Value rhs); +static void buildScalableCmpFOp(OpBuilder &build, OperationState &result, + arm_sve::CmpFPredicate predicate, Value lhs, + Value rhs); #define GET_OP_CLASSES #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" @@ -29,7 +38,7 @@ #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc" -void arm_sve::ArmSVEDialect::initialize() { +void ArmSVEDialect::initialize() { addOperations< #define 
GET_OP_LIST #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" @@ -44,7 +53,7 @@ // ScalableVectorType //===----------------------------------------------------------------------===// -Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const { +Type ArmSVEDialect::parseType(DialectAsmParser &parser) const { llvm::SMLoc typeLoc = parser.getCurrentLocation(); { Type genType; @@ -57,7 +66,7 @@ return Type(); } -void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const { +void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const { if (failed(generatedTypePrinter(type, os))) llvm_unreachable("unexpected 'arm_sve' type kind"); } @@ -69,8 +78,34 @@ // Return the scalable vector of the same shape and containing i1. static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto sVectorType = type.dyn_cast<arm_sve::ScalableVectorType>()) - return arm_sve::ScalableVectorType::get(type.getContext(), - sVectorType.getShape(), i1Type); + if (auto sVectorType = type.dyn_cast<ScalableVectorType>()) + return ScalableVectorType::get(type.getContext(), sVectorType.getShape(), + i1Type); return nullptr; } + +//===----------------------------------------------------------------------===// +// ScalableCmpFOp and ScalableCmpIOp builders +//===----------------------------------------------------------------------===// + +/// Populates `result` for `arm_sve.cmpf` with `lhs`/`rhs`, `predicate` as an +/// i64 attribute and the i1 vector type returned by `getI1SameShape(lhs)`. +static void buildScalableCmpFOp(OpBuilder &build, OperationState &result, + arm_sve::CmpFPredicate predicate, Value lhs, + Value rhs) { + result.addOperands({lhs, rhs}); + result.types.push_back(getI1SameShape(lhs.getType())); + result.addAttribute(ScalableCmpFOp::getPredicateAttrName(), + build.getI64IntegerAttr(static_cast<int64_t>(predicate))); +} + +/// Populates `result` for `arm_sve.cmpi` with `lhs`/`rhs`, `predicate` as an +/// i64 attribute and the i1 vector type returned by `getI1SameShape(lhs)`. 
+static void buildScalableCmpIOp(OpBuilder &build, OperationState &result, + arm_sve::CmpIPredicate predicate, Value lhs, + Value rhs) { + result.addOperands({lhs, rhs}); + result.types.push_back(getI1SameShape(lhs.getType())); + result.addAttribute(ScalableCmpIOp::getPredicateAttrName(), + 
build.getI64IntegerAttr(static_cast<int64_t>(predicate))); +} diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -143,6 +143,26 @@ // clang-format on } +/// Lowers `arm_sve.cmpf`/`arm_sve.cmpi` one-to-one to `llvm.fcmp`/`llvm.icmp`. +static void +populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns) { + // clang-format off + patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>, + OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp> + >(converter); + // clang-format on +} + +/// Marks the ArmSVE comparison ops illegal for the LLVM export target. +static void +configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) { + // clang-format off + target.addIllegalOp<ScalableCmpFOp, + ScalableCmpIOp>(); + // clang-format on +} + /// Populate the given list with patterns that convert from ArmSVE to LLVM. void mlir::populateArmSVELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { @@ -175,6 +195,7 @@ ScalableMaskedDivFOpLowering>(converter); // clang-format on populateBasicSVEArithmeticExportPatterns(converter, patterns); + populateSVEMaskGenerationExportPatterns(converter, patterns); } void mlir::configureArmSVELegalizeForExportTarget( @@ -225,4 +246,5 @@ !hasScalableVectorType(op->getResultTypes()); }); configureBasicSVEArithmeticLegalizations(target); + configureSVEMaskGenerationLegalizations(target); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -134,11 +134,17 @@ if (!isCompatibleType(type)) return parser.emitError(trailingTypeLoc, "expected LLVM dialect-compatible type"); - if 
(LLVM::isCompatibleVectorType(type)) - resultType = LLVM::getFixedVectorType( - resultType, LLVM::getVectorNumElements(type).getFixedValue()); - assert(!type.isa<LLVM::LLVMScalableVectorType>() && - "unhandled scalable vector"); + // Scalable vectors keep their minimum element count; every other compatible + // vector (built-in or LLVM fixed-length) yields a fixed-length i1 vector. + if 
(LLVM::isCompatibleVectorType(type)) { + if (type.isa<LLVM::LLVMScalableVectorType>()) { + resultType = LLVM::LLVMScalableVectorType::get( + resultType, LLVM::getVectorNumElements(type).getKnownMinValue()); + } else { + resultType = LLVM::getFixedVectorType( + resultType, LLVM::getVectorNumElements(type).getFixedValue()); + } + } result.addTypes({resultType}); return success(); diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir --- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -121,6 +121,22 @@ return %3 : !arm_sve.vector<4xf32> } +func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>, + %b: !arm_sve.vector<4xf32>) + -> !arm_sve.vector<4xi1> { + // CHECK: llvm.fcmp "oeq" {{.*}}: !llvm.vec<? x 4 x f32> + %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32> + return %0 : !arm_sve.vector<4xi1> +} + +func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>, + %b: !arm_sve.vector<4xi32>) + -> !arm_sve.vector<4xi1> { + // CHECK: llvm.icmp "uge" {{.*}}: !llvm.vec<? x 4 x i32> + %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi1> +} + func @get_vector_scale() -> index { // CHECK: arm_sve.vscale %0 = arm_sve.vector_scale : index diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -103,6 +103,22 @@ return %3 : !arm_sve.vector<4xf32> } +func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>, + %b: !arm_sve.vector<4xf32>) + -> !arm_sve.vector<4xi1> { + // CHECK: arm_sve.cmpf oeq, {{.*}}: !arm_sve.vector<4xf32> + %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32> + return %0 : !arm_sve.vector<4xi1> +} + +func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>, + %b: !arm_sve.vector<4xi32>) + -> !arm_sve.vector<4xi1> { + // CHECK: arm_sve.cmpi uge, {{.*}}: !arm_sve.vector<4xi32> + %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32> + 
return %0 : !arm_sve.vector<4xi1> +} + func @get_vector_scale() -> index { // CHECK: arm_sve.vector_scale : index %0 = arm_sve.vector_scale : index diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -139,6 +139,24 @@ llvm.return %3 : !llvm.vec<? x 4 x f32> } +// CHECK-LABEL: define <vscale x 4 x i1> @arm_sve_mask_genf +llvm.func @arm_sve_mask_genf(%arg0: !llvm.vec<? x 4 x f32>, + %arg1: !llvm.vec<? x 4 x f32>) + -> !llvm.vec<? x 4 x i1> { + // CHECK: fcmp oeq + %0 = llvm.fcmp "oeq" %arg0, %arg1 : !llvm.vec<? x 4 x f32> + llvm.return %0 : !llvm.vec<? x 4 x i1> +} + +// CHECK-LABEL: define <vscale x 4 x i1> @arm_sve_mask_geni +llvm.func @arm_sve_mask_geni(%arg0: !llvm.vec<? x 4 x i32>, + %arg1: !llvm.vec<? x 4 x i32>) + -> !llvm.vec<? x 4 x i1> { + // CHECK: icmp uge + %0 = llvm.icmp "uge" %arg0, %arg1 : !llvm.vec<? x 4 x i32> + llvm.return %0 : !llvm.vec<? x 4 x i1> +} + // CHECK-LABEL: define i64 @get_vector_scale() llvm.func @get_vector_scale() -> i64 { // CHECK: call i64 @llvm.vscale.i64()