diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -27,6 +27,37 @@
 class EmitC_Op<string mnemonic, list<Trait> traits = []>
     : Op<EmitC_Dialect, mnemonic, traits>;
 
+// Base class for binary arithmetic operations.
+class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
+    EmitC_Op<mnemonic, traits> {
+  let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+  let results = (outs AnyType);
+  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+
+  let hasVerifier = 1;
+}
+
+def EmitC_AddOp : EmitC_BinaryArithOp<"add", []> {
+  let summary = "Addition operation";
+  let description = [{
+    With the `add` operation the arithmetic operator + (addition) can
+    be applied.
+
+    Example:
+
+    ```mlir
+    // Custom form of the addition operation.
+    %0 = emitc.add %arg0, %arg1 : (i32, i32) -> i32
+    %1 = emitc.add %arg2, %arg3 : (!emitc.ptr<f32>, i32) -> !emitc.ptr<f32>
+    ```
+    ```c++
+    // Code emitted for the operations above.
+    int32_t v5 = v1 + v2;
+    float* v6 = v3 + v4;
+    ```
+  }];
+}
+
 def EmitC_ApplyOp : EmitC_Op<"apply", []> {
   let summary = "Apply operation";
   let description = [{
@@ -175,6 +206,30 @@
   let hasCustomAssemblyFormat = 1;
 }
 
+def EmitC_SubOp : EmitC_BinaryArithOp<"sub", []> {
+  let summary = "Subtraction operation";
+  let description = [{
+    With the `sub` operation the arithmetic operator - (subtraction) can
+    be applied.
+
+    Example:
+
+    ```mlir
+    // Custom form of the subtraction operation.
+    %0 = emitc.sub %arg0, %arg1 : (i32, i32) -> i32
+    %1 = emitc.sub %arg2, %arg3 : (!emitc.ptr<f32>, i32) -> !emitc.ptr<f32>
+    %2 = emitc.sub %arg4, %arg5 : (!emitc.ptr<f32>, !emitc.ptr<f32>)
+        -> !emitc.opaque<"ptrdiff_t">
+    ```
+    ```c++
+    // Code emitted for the operations above.
+    int32_t v7 = v1 - v2;
+    float* v8 = v3 - v4;
+    ptrdiff_t v9 = v5 - v6;
+    ```
+  }];
+}
+
 def EmitC_VariableOp : EmitC_Op<"variable", []> {
   let summary = "Variable operation";
   let description = [{
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -44,6 +44,29 @@
   return builder.create<emitc::ConstantOp>(loc, type, value);
 }
 
+//===----------------------------------------------------------------------===//
+// AddOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AddOp::verify() {
+  Type lhsType = getLhs().getType();
+  Type rhsType = getRhs().getType();
+
+  // Mirror C pointer arithmetic: a pointer may only be offset by an integer
+  // (opaque types are accepted unchecked); adding two pointers is rejected.
+  if (lhsType.isa<emitc::PointerType>() && rhsType.isa<emitc::PointerType>())
+    return emitOpError("requires that at most one operand is a pointer");
+
+  if ((lhsType.isa<emitc::PointerType>() &&
+       !rhsType.isa<IntegerType, emitc::OpaqueType>()) ||
+      (rhsType.isa<emitc::PointerType>() &&
+       !lhsType.isa<IntegerType, emitc::OpaqueType>()))
+    return emitOpError("requires that one operand is an integer or of opaque "
+                       "type if the other is a pointer");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ApplyOp
 //===----------------------------------------------------------------------===//
@@ -178,6 +201,33 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SubOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SubOp::verify() {
+  Type lhsType = getLhs().getType();
+  Type rhsType = getRhs().getType();
+  Type resultType = getResult().getType();
+
+  // Mirror C pointer arithmetic: a pointer can only be subtracted from another
+  // pointer, and the difference of two pointers must be an integral value.
+  if (rhsType.isa<emitc::PointerType>() && !lhsType.isa<emitc::PointerType>())
+    return emitOpError("rhs can only be a pointer if lhs is a pointer");
+
+  if (lhsType.isa<emitc::PointerType>() &&
+      !rhsType.isa<IntegerType, emitc::OpaqueType, emitc::PointerType>())
+    return emitOpError("requires that rhs is an integer, pointer or of opaque "
+                       "type if lhs is a pointer");
+
+  if (lhsType.isa<emitc::PointerType>() && rhsType.isa<emitc::PointerType>() &&
+      !resultType.isa<IntegerType, emitc::OpaqueType>())
+    return emitOpError("requires that the result is an integer or of opaque "
+                       "type if lhs and rhs are pointers");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // VariableOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -246,6 +246,34 @@
   return printConstantOp(emitter, operation, value);
 }
 
+/// Prints `operation` as an assignment of `<lhs> <binaryArithOperator> <rhs>`,
+/// where lhs and rhs are the operation's first and second operands.
+static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
+                                               Operation *operation,
+                                               StringRef binaryArithOperator) {
+  raw_ostream &os = emitter.ostream();
+
+  if (failed(emitter.emitAssignPrefix(*operation)))
+    return failure();
+  os << emitter.getOrCreateName(operation->getOperand(0));
+  os << " " << binaryArithOperator;
+  os << " " << emitter.getOrCreateName(operation->getOperand(1));
+
+  return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
+  Operation *operation = addOp.getOperation();
+
+  return printBinaryArithOperation(emitter, operation, "+");
+}
+
+static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
+  Operation *operation = subOp.getOperation();
+
+  return printBinaryArithOperation(emitter, operation, "-");
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     cf::BranchOp branchOp) {
   raw_ostream &os = emitter.ostream();
@@ -930,8 +958,9 @@
           .Case<cf::BranchOp, cf::CondBranchOp>(
               [&](auto op) { return printOperation(*this, op); })
           // EmitC ops.
-          .Case<emitc::ApplyOp, emitc::CallOp, emitc::CastOp,
-                emitc::ConstantOp, emitc::IncludeOp, emitc::VariableOp>(
+          .Case<emitc::AddOp, emitc::ApplyOp, emitc::CallOp, emitc::CastOp,
+                emitc::ConstantOp, emitc::IncludeOp, emitc::SubOp,
+                emitc::VariableOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::ConstantOp, func::ReturnOp>(
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -118,3 +118,51 @@
     %1 = emitc.cast %arg: tensor<f32> to tensor<f32>
     return
 }
+
+// -----
+
+func.func @add_two_pointers(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
+    // expected-error @+1 {{'emitc.add' op requires that at most one operand is a pointer}}
+    %1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.ptr<f32>
+    return
+}
+
+// -----
+
+func.func @add_pointer_float(%arg0: !emitc.ptr<f32>, %arg1: f32) {
+    // expected-error @+1 {{'emitc.add' op requires that one operand is an integer or of opaque type if the other is a pointer}}
+    %1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, f32) -> !emitc.ptr<f32>
+    return
+}
+
+// -----
+
+func.func @add_float_pointer(%arg0: f32, %arg1: !emitc.ptr<f32>) {
+    // expected-error @+1 {{'emitc.add' op requires that one operand is an integer or of opaque type if the other is a pointer}}
+    %1 = "emitc.add" (%arg0, %arg1) : (f32, !emitc.ptr<f32>) -> !emitc.ptr<f32>
+    return
+}
+
+// -----
+
+func.func @sub_int_pointer(%arg0: i32, %arg1: !emitc.ptr<f32>) {
+    // expected-error @+1 {{'emitc.sub' op rhs can only be a pointer if lhs is a pointer}}
+    %1 = "emitc.sub" (%arg0, %arg1) : (i32, !emitc.ptr<f32>) -> !emitc.ptr<f32>
+    return
+}
+
+// -----
+
+func.func @sub_pointer_float(%arg0: !emitc.ptr<f32>, %arg1: f32) {
+    // expected-error @+1 {{'emitc.sub' op requires that rhs is an integer, pointer or of opaque type if lhs is a pointer}}
+    %1 = "emitc.sub" (%arg0, %arg1) : (!emitc.ptr<f32>, f32) -> !emitc.ptr<f32>
+    return
+}
+
+// -----
+
+func.func @sub_pointer_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
+    // expected-error @+1 {{'emitc.sub' op requires that the result is an integer or of opaque type if lhs and rhs are pointers}}
+    %1 = "emitc.sub" (%arg0, %arg1) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.ptr<f32>
+    return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -30,3 +30,27 @@
   %2 = emitc.apply "&"(%arg1) : (i32) -> !emitc.ptr<i32>
   return
 }
+
+func.func @add_int(%arg0: i32, %arg1: i32) {
+  %1 = "emitc.add" (%arg0, %arg1) : (i32, i32) -> i32
+  return
+}
+
+func.func @add_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32, %arg2: !emitc.opaque<"unsigned int">) {
+  %1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, i32) -> !emitc.ptr<f32>
+  %2 = "emitc.add" (%arg0, %arg2) : (!emitc.ptr<f32>, !emitc.opaque<"unsigned int">) -> !emitc.ptr<f32>
+  return
+}
+
+func.func @sub_int(%arg0: i32, %arg1: i32) {
+  %1 = "emitc.sub" (%arg0, %arg1) : (i32, i32) -> i32
+  return
+}
+
+func.func @sub_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32, %arg2: !emitc.opaque<"unsigned int">, %arg3: !emitc.ptr<f32>) {
+  %1 = "emitc.sub" (%arg0, %arg1) : (!emitc.ptr<f32>, i32) -> !emitc.ptr<f32>
+  %2 = "emitc.sub" (%arg0, %arg2) : (!emitc.ptr<f32>, !emitc.opaque<"unsigned int">) -> !emitc.ptr<f32>
+  %3 = "emitc.sub" (%arg0, %arg3) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.opaque<"ptrdiff_t">
+  %4 = "emitc.sub" (%arg0, %arg3) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> i32
+  return
+}
diff --git a/mlir/test/Target/Cpp/arithmetic_operators.mlir b/mlir/test/Target/Cpp/arithmetic_operators.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/Cpp/arithmetic_operators.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+func.func @add_int(%arg0: i32, %arg1: i32) {
+  %1 = "emitc.add" (%arg0, %arg1) : (i32, i32) -> i32
+  return
+}
+// CHECK-LABEL: void add_int
+// CHECK-NEXT:  int32_t [[V2:[^ ]*]] = [[V0:[^ ]*]] + [[V1:[^ ]*]]
+
+func.func @add_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32) {
+  %1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, i32) -> !emitc.ptr<f32>
+  return
+}
+// CHECK-LABEL: void add_pointer
+// CHECK-NEXT:  float* [[V2:[^ ]*]] = [[V0:[^ ]*]] + [[V1:[^ ]*]]
+
+func.func @sub_int(%arg0: i32, %arg1: i32) {
+  %1 = "emitc.sub" (%arg0, %arg1) : (i32, i32) -> i32
+  return
+}
+// CHECK-LABEL: void sub_int
+// CHECK-NEXT:  int32_t [[V2:[^ ]*]] = [[V0:[^ ]*]] - [[V1:[^ ]*]]
+
+func.func @sub_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
+  %1 = "emitc.sub" (%arg0, %arg1) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.opaque<"ptrdiff_t">
+  return
+}
+// CHECK-LABEL: void sub_pointer
+// CHECK-NEXT:  ptrdiff_t [[V2:[^ ]*]] = [[V0:[^ ]*]] - [[V1:[^ ]*]]