diff --git a/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt @@ -2,6 +2,8 @@ add_mlir_doc(EmitC EmitC Dialects/ -gen-dialect-doc) set(LLVM_TARGET_DEFINITIONS EmitCAttributes.td) +mlir_tablegen(EmitCEnums.h.inc -gen-enum-decls) +mlir_tablegen(EmitCEnums.cpp.inc -gen-enum-defs) mlir_tablegen(EmitCAttributes.h.inc -gen-attrdef-decls) mlir_tablegen(EmitCAttributes.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(MLIREmitCAttributesIncGen) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -21,6 +21,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc" +#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc" #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/EmitC/IR/EmitCAttributes.h.inc" diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -27,8 +27,8 @@ class EmitC_Op traits = []> : Op; -// Base class for binary arithmetic operations. -class EmitC_BinaryArithOp traits = []> : +// Base class for binary operations.
+class EmitC_BinaryOp traits = []> : EmitC_Op { let arguments = (ins AnyType:$lhs, AnyType:$rhs); let results = (outs AnyType); @@ -39,7 +39,7 @@ def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>; def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>; -def EmitC_AddOp : EmitC_BinaryArithOp<"add", []> { +def EmitC_AddOp : EmitC_BinaryOp<"add", []> { let summary = "Addition operation"; let description = [{ With the `add` operation the arithmetic operator + (addition) can @@ -150,6 +150,37 @@ let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; } +def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> { + let summary = "Comparison operation"; + let description = [{ + With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=> + can be applied. + + Example: + ```mlir + // Custom form of the cmp operation. + %0 = emitc.cmp eq, %arg0, %arg1 : (i32, i32) -> i1 + %1 = emitc.cmp lt, %arg2, %arg3 : + ( + !emitc.opaque<"std::valarray">, + !emitc.opaque<"std::valarray"> + ) -> !emitc.opaque<"std::valarray"> + ``` + ```c++ + // Code emitted for the operations above.
+ bool v5 = v1 == v2; + std::valarray v6 = v3 < v4; + ``` + }]; + + let arguments = (ins EmitC_CmpPredicateAttr:$predicate, + AnyType:$lhs, + AnyType:$rhs); + let results = (outs AnyType); + + let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)"; +} + def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { let summary = "Constant operation"; let description = [{ @@ -180,7 +211,7 @@ let hasVerifier = 1; } -def EmitC_DivOp : EmitC_BinaryArithOp<"div", []> { +def EmitC_DivOp : EmitC_BinaryOp<"div", []> { let summary = "Division operation"; let description = [{ With the `div` operation the arithmetic operator / (division) can @@ -248,7 +279,7 @@ let assemblyFormat = "$value attr-dict `:` type($result)"; } -def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> { +def EmitC_MulOp : EmitC_BinaryOp<"mul", []> { let summary = "Multiplication operation"; let description = [{ With the `mul` operation the arithmetic operator * (multiplication) can @@ -272,7 +303,7 @@ let results = (outs FloatIntegerIndexOrOpaqueType); } -def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> { +def EmitC_RemOp : EmitC_BinaryOp<"rem", []> { let summary = "Remainder operation"; let description = [{ With the `rem` operation the arithmetic operator % (remainder) can @@ -294,7 +325,7 @@ let results = (outs IntegerIndexOrOpaqueType); } -def EmitC_SubOp : EmitC_BinaryArithOp<"sub", []> { +def EmitC_SubOp : EmitC_BinaryOp<"sub", []> { let summary = "Subtraction operation"; let description = [{ With the `sub` operation the arithmetic operator - (subtraction) can diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td @@ -15,6 +15,7 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/EnumAttr.td" include
"mlir/Dialect/EmitC/IR/EmitCBase.td" //===----------------------------------------------------------------------===// @@ -26,6 +27,20 @@ let mnemonic = attrMnemonic; } +def EmitC_CmpPredicateAttr : I64EnumAttr< + "CmpPredicate", "", + [ + I64EnumAttrCase<"eq", 0>, + I64EnumAttrCase<"ne", 1>, + I64EnumAttrCase<"lt", 2>, + I64EnumAttrCase<"le", 3>, + I64EnumAttrCase<"gt", 4>, + I64EnumAttrCase<"ge", 5>, + I64EnumAttrCase<"three_way", 6>, + ]> { + let cppNamespace = "::mlir::emitc"; +} + def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> { let summary = "An opaque attribute"; diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -257,6 +257,13 @@ #define GET_OP_CLASSES #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc" + +//===----------------------------------------------------------------------===// +// EmitC Enums +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc" + //===----------------------------------------------------------------------===// // EmitC Attributes //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -246,15 +246,15 @@ return printConstantOp(emitter, operation, value); } -static LogicalResult printBinaryArithOperation(CppEmitter &emitter, - Operation *operation, - StringRef binaryArithOperator) { +static LogicalResult printBinaryOperation(CppEmitter &emitter, + Operation *operation, + StringRef binaryOperator) { raw_ostream &os = emitter.ostream(); if (failed(emitter.emitAssignPrefix(*operation))) return failure(); os << emitter.getOrCreateName(operation->getOperand(0)); - os << " " << binaryArithOperator; + os << " " << binaryOperator;
os << " " << emitter.getOrCreateName(operation->getOperand(1)); return success(); @@ -263,31 +263,63 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) { Operation *operation = addOp.getOperation(); - return printBinaryArithOperation(emitter, operation, "+"); + return printBinaryOperation(emitter, operation, "+"); } static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) { Operation *operation = divOp.getOperation(); - return printBinaryArithOperation(emitter, operation, "/"); + return printBinaryOperation(emitter, operation, "/"); } static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) { Operation *operation = mulOp.getOperation(); - return printBinaryArithOperation(emitter, operation, "*"); + return printBinaryOperation(emitter, operation, "*"); } static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) { Operation *operation = remOp.getOperation(); - return printBinaryArithOperation(emitter, operation, "%"); + return printBinaryOperation(emitter, operation, "%"); } static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) { Operation *operation = subOp.getOperation(); - return printBinaryArithOperation(emitter, operation, "-"); + return printBinaryOperation(emitter, operation, "-"); +} + +static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) { + Operation *operation = cmpOp.getOperation(); + + StringRef binaryOperator; + + switch (cmpOp.getPredicate()) { + case emitc::CmpPredicate::eq: + binaryOperator = "=="; + break; + case emitc::CmpPredicate::ne: + binaryOperator = "!="; + break; + case emitc::CmpPredicate::lt: + binaryOperator = "<"; + break; + case emitc::CmpPredicate::le: + binaryOperator = "<="; + break; + case emitc::CmpPredicate::gt: + binaryOperator = ">"; + break; + case emitc::CmpPredicate::ge: + binaryOperator = ">="; + break; + case emitc::CmpPredicate::three_way: + binaryOperator = "<=>"; + break;
+ } + + return printBinaryOperation(emitter, operation, binaryOperator); } static LogicalResult printOperation(CppEmitter &emitter, @@ -977,8 +1009,8 @@ [&](auto op) { return printOperation(*this, op); }) // EmitC ops. .Case( + emitc::CmpOp, emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp, + emitc::MulOp, emitc::RemOp, emitc::SubOp, emitc::VariableOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -79,3 +79,21 @@ %4 = "emitc.sub" (%arg0, %arg3) : (!emitc.ptr, !emitc.ptr) -> i32 return } + +func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray">, %arg6 : !emitc.opaque<"custom">) { + %1 = "emitc.cmp" (%arg0, %arg0) {predicate = 0} : (i32, i32) -> i1 + %2 = emitc.cmp eq, %arg0, %arg0 : (i32, i32) -> i1 + %3 = "emitc.cmp" (%arg1, %arg1) {predicate = 1} : (f32, f32) -> i1 + %4 = emitc.cmp ne, %arg1, %arg1 : (f32, f32) -> i1 + %5 = "emitc.cmp" (%arg2, %arg2) {predicate = 2} : (i64, i64) -> i1 + %6 = emitc.cmp lt, %arg2, %arg2 : (i64, i64) -> i1 + %7 = "emitc.cmp" (%arg3, %arg3) {predicate = 3} : (f64, f64) -> i1 + %8 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1 + %9 = "emitc.cmp" (%arg4, %arg4) {predicate = 4} : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1 + %10 = emitc.cmp gt, %arg4, %arg4 : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1 + %11 = "emitc.cmp" (%arg5, %arg5) {predicate = 5} : (!emitc.opaque<"std::valarray">, !emitc.opaque<"std::valarray">) -> !emitc.opaque<"std::valarray"> + %12 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray">, !emitc.opaque<"std::valarray">) -> !emitc.opaque<"std::valarray"> + %13 = "emitc.cmp" (%arg6, %arg6) {predicate = 6} : (!emitc.opaque<"custom">,
!emitc.opaque<"custom">) -> !emitc.opaque<"custom"> + %14 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom"> + return +} diff --git a/mlir/test/Target/Cpp/comparison_operators.mlir b/mlir/test/Target/Cpp/comparison_operators.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/Cpp/comparison_operators.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray">, %arg6 : !emitc.opaque<"custom">) { + %1 = emitc.cmp eq, %arg0, %arg0 : (i32, i32) -> i1 + %2 = emitc.cmp ne, %arg1, %arg1 : (f32, f32) -> i1 + %3 = emitc.cmp lt, %arg2, %arg2 : (i64, i64) -> i1 + %4 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1 + %5 = emitc.cmp gt, %arg4, %arg4 : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1 + %6 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray">, !emitc.opaque<"std::valarray">) -> !emitc.opaque<"std::valarray"> + %7 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom"> + + return +} +// CHECK-LABEL: void cmp +// CHECK-NEXT: bool [[V7:[^ ]*]] = [[V0:[^ ]*]] == [[V0]] +// CHECK-NEXT: bool [[V8:[^ ]*]] = [[V1:[^ ]*]] != [[V1]] +// CHECK-NEXT: bool [[V9:[^ ]*]] = [[V2:[^ ]*]] < [[V2]] +// CHECK-NEXT: bool [[V10:[^ ]*]] = [[V3:[^ ]*]] <= [[V3]] +// CHECK-NEXT: bool [[V11:[^ ]*]] = [[V4:[^ ]*]] > [[V4]] +// CHECK-NEXT: std::valarray [[V12:[^ ]*]] = [[V5:[^ ]*]] >= [[V5]] +// CHECK-NEXT: custom [[V13:[^ ]*]] = [[V6:[^ ]*]] <=> [[V6]]