diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -33,10 +33,12 @@
   let arguments = (ins AnyType:$lhs, AnyType:$rhs);
   let results = (outs AnyType);
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
-
-  let hasVerifier = 1;
 }
 
+// Types only used in binary arithmetic operations.
+def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
+def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
+
 def EmitC_AddOp : EmitC_BinaryArithOp<"add", []> {
   let summary = "Addition operation";
   let description = [{
@@ -56,6 +58,8 @@
     float* v6 = v3 + v4;
     ```
   }];
+
+  let hasVerifier = 1;
 }
 
 def EmitC_ApplyOp : EmitC_Op<"apply", []> {
@@ -176,6 +180,31 @@
   let hasVerifier = 1;
 }
 
+def EmitC_DivOp : EmitC_BinaryArithOp<"div", []> {
+  let summary = "Division operation";
+  let description = [{
+    With the `div` operation the arithmetic operator / (division) can
+    be applied.
+
+    Example:
+
+    ```mlir
+    // Custom form of the division operation.
+    %0 = emitc.div %arg0, %arg1 : (i32, i32) -> i32
+    %1 = emitc.div %arg2, %arg3 : (f32, f32) -> f32
+    ```
+    ```c++
+    // Code emitted for the operations above.
+    int32_t v5 = v1 / v2;
+    float v6 = v3 / v4;
+    ```
+  }];
+
+  let arguments = (ins FloatIntegerIndexOrOpaqueType:$lhs,
+                       FloatIntegerIndexOrOpaqueType:$rhs);
+  let results = (outs FloatIntegerIndexOrOpaqueType);
+}
+
 def EmitC_IncludeOp : EmitC_Op<"include", [HasParent<"ModuleOp">]> {
   let summary = "Include operation";
 
@@ -206,6 +235,54 @@
   let hasCustomAssemblyFormat = 1;
 }
 
+def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
+  let summary = "Multiplication operation";
+  let description = [{
+    With the `mul` operation the arithmetic operator * (multiplication) can
+    be applied.
+
+    Example:
+
+    ```mlir
+    // Custom form of the multiplication operation.
+    %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+    %1 = emitc.mul %arg2, %arg3 : (f32, f32) -> f32
+    ```
+    ```c++
+    // Code emitted for the operations above.
+    int32_t v5 = v1 * v2;
+    float v6 = v3 * v4;
+    ```
+  }];
+
+  let arguments = (ins FloatIntegerIndexOrOpaqueType:$lhs,
+                       FloatIntegerIndexOrOpaqueType:$rhs);
+  let results = (outs FloatIntegerIndexOrOpaqueType);
+}
+
+def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
+  let summary = "Remainder operation";
+  let description = [{
+    With the `rem` operation the arithmetic operator % (remainder) can
+    be applied.
+
+    Example:
+
+    ```mlir
+    // Custom form of the remainder operation.
+    %0 = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
+    ```
+    ```c++
+    // Code emitted for the operation above.
+    int32_t v5 = v1 % v2;
+    ```
+  }];
+
+  let arguments = (ins IntegerIndexOrOpaqueType:$lhs,
+                       IntegerIndexOrOpaqueType:$rhs);
+  let results = (outs IntegerIndexOrOpaqueType);
+}
+
 def EmitC_SubOp : EmitC_BinaryArithOp<"sub", []> {
   let summary = "Subtraction operation";
   let description = [{
@@ -228,6 +305,8 @@
     ptrdiff_t v9 = v5 - v6;
     ```
   }];
+
+  let hasVerifier = 1;
 }
 
 def EmitC_VariableOp : EmitC_Op<"variable", []> {
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -266,6 +266,24 @@
   return printBinaryArithOperation(emitter, operation, "+");
 }
 
+static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
+  Operation *operation = divOp.getOperation();
+
+  return printBinaryArithOperation(emitter, operation, "/");
+}
+
+static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
+  Operation *operation = mulOp.getOperation();
+
+  return printBinaryArithOperation(emitter, operation, "*");
+}
+
+static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
+  Operation *operation = remOp.getOperation();
+
+  return printBinaryArithOperation(emitter, operation, "%");
+}
+
static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) { Operation *operation = subOp.getOperation(); @@ -957,8 +975,8 @@ [&](auto op) { return printOperation(*this, op); }) // EmitC ops. .Case( + emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp, emitc::MulOp, + emitc::RemOp, emitc::SubOp, emitc::VariableOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -145,6 +145,38 @@ // ----- +func.func @div_tensor(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{'emitc.div' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor'}} + %1 = "emitc.div" (%arg0, %arg1) : (tensor, tensor) -> tensor + return +} + +// ----- + +func.func @mul_tensor(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor'}} + %1 = "emitc.mul" (%arg0, %arg1) : (tensor, tensor) -> tensor + return +} + +// ----- + +func.func @rem_tensor(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'tensor'}} + %1 = "emitc.rem" (%arg0, %arg1) : (tensor, tensor) -> tensor + return +} + +// ----- + +func.func @rem_float(%arg0: f32, %arg1: f32) { + // expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'f32'}} + %1 = "emitc.rem" (%arg0, %arg1) : (f32, f32) -> f32 + return +} + +// ----- + func.func @sub_int_pointer(%arg0: i32, %arg1: !emitc.ptr) { // expected-error @+1 {{'emitc.sub' op rhs can only be a pointer if lhs is a pointer}} %1 = "emitc.sub" (%arg0, %arg1) : (i32, !emitc.ptr) -> !emitc.ptr diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir --- 
a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -42,6 +42,31 @@ return } +func.func @div_int(%arg0: i32, %arg1: i32) { + %1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32 + return +} + +func.func @div_float(%arg0: f32, %arg1: f32) { + %1 = "emitc.div" (%arg0, %arg1) : (f32, f32) -> f32 + return +} + +func.func @mul_int(%arg0: i32, %arg1: i32) { + %1 = "emitc.mul" (%arg0, %arg1) : (i32, i32) -> i32 + return +} + +func.func @mul_float(%arg0: f32, %arg1: f32) { + %1 = "emitc.mul" (%arg0, %arg1) : (f32, f32) -> f32 + return +} + +func.func @rem(%arg0: i32, %arg1: i32) { + %1 = "emitc.rem" (%arg0, %arg1) : (i32, i32) -> i32 + return +} + func.func @sub_int(%arg0: i32, %arg1: i32) { %1 = "emitc.sub" (%arg0, %arg1) : (i32, i32) -> i32 return diff --git a/mlir/test/Target/Cpp/arithmetic_operators.mlir b/mlir/test/Target/Cpp/arithmetic_operators.mlir --- a/mlir/test/Target/Cpp/arithmetic_operators.mlir +++ b/mlir/test/Target/Cpp/arithmetic_operators.mlir @@ -14,6 +14,27 @@ // CHECK-LABEL: void add_pointer // CHECK-NEXT: float* [[V2:[^ ]*]] = [[V0:[^ ]*]] + [[V1:[^ ]*]] +func.func @div_int(%arg0: i32, %arg1: i32) { + %1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32 + return +} +// CHECK-LABEL: void div_int +// CHECK-NEXT: int32_t [[V2:[^ ]*]] = [[V0:[^ ]*]] / [[V1:[^ ]*]] + +func.func @mul_int(%arg0: i32, %arg1: i32) { + %1 = "emitc.mul" (%arg0, %arg1) : (i32, i32) -> i32 + return +} +// CHECK-LABEL: void mul_int +// CHECK-NEXT: int32_t [[V2:[^ ]*]] = [[V0:[^ ]*]] * [[V1:[^ ]*]] + +func.func @rem(%arg0: i32, %arg1: i32) { + %1 = "emitc.rem" (%arg0, %arg1) : (i32, i32) -> i32 + return +} +// CHECK-LABEL: void rem +// CHECK-NEXT: int32_t [[V2:[^ ]*]] = [[V0:[^ ]*]] % [[V1:[^ ]*]] + func.func @sub_int(%arg0: i32, %arg1: i32) { %1 = "emitc.sub" (%arg0, %arg1) : (i32, i32) -> i32 return