[mlir][EmitC] Add emitc.cast and emit it as a C-style cast

EmitC.td: define EmitC_CastOp with a single AnyType `$source` operand, a
single AnyType `$dest` result, the SameOperandsAndResultShape trait and the
assembly format `$source attr-dict : type($source) to type($dest)`; include
CastInterfaces.td so the op can declare the cast op interface methods.

CMakeLists.txt: add MLIRCastInterfaces to the EmitC IR library's LINK_LIBS.

EmitC.cpp: CastOp::areCastCompatible rejects anything other than exactly one
input type and exactly one output type, then requires each of them to pass
an isa type check. NOTE(review): keep that accepted-type list in sync with
the op description ("integer, float, index and EmitC types") - confirm.

TranslateToCpp.cpp: print the op as its assignment prefix followed by
"(" result-type ") " operand-name, e.g. `uint32_t v2 = (uint32_t) v1`, and
update the "EmitC ops." TypeSwitch case list. NOTE(review): confirm that
emitc::CastOp is in that list, otherwise the new printOperation overload is
never dispatched.

Tests: ops.mlir adds a valid i32 -> f32 cast, invalid_ops.mlir expects a
tensor -> tensor cast to be reported as cast incompatible, and
Target/Cpp/cast.mlir checks the emitted C++ for integer, float, bool,
opaque and pointer casts.

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -16,6 +16,7 @@ include "mlir/Dialect/EmitC/IR/EmitCAttributes.td" include "mlir/Dialect/EmitC/IR/EmitCTypes.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// @@ -88,6 +89,33 @@ let hasVerifier = 1; } +def EmitC_CastOp : EmitC_Op<"cast", [ + DeclareOpInterfaceMethods, + SameOperandsAndResultShape + ]> { + let summary = "Cast operation"; + let description = [{ + The `cast` operation performs an explicit type conversion and is emitted + as a C-style cast expression. It can be applied to integer, float, index + and EmitC types. + + Example: + + ```mlir + // Cast from `int32_t` to `float` + %0 = emitc.cast %arg0: i32 to f32 + + // Cast from `void` to `int32_t` pointer + %1 = emitc.cast %arg1 : + !emitc.ptr> to !emitc.ptr + ``` + }]; + + let arguments = (ins AnyType:$source); + let results = (outs AnyType:$dest); + let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; +} + def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { let summary = "Constant operation"; let description = [{ diff --git a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt --- a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt @@ -9,6 +9,7 @@ MLIREmitCAttributesIncGen LINK_LIBS PUBLIC + MLIRCastInterfaces MLIRIR MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -62,6 +62,25 @@ return success(); } +//===----------------------------------------------------------------------===// +// CastOp 
+//===----------------------------------------------------------------------===// + +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + + Type input = inputs.front(), output = outputs.front(); + + if ((!input.isa()) || + (!output.isa())) + return false; + + return true; +} + //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -382,6 +382,21 @@ return success(); } +static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { + raw_ostream &os = emitter.ostream(); + Operation &op = *castOp.getOperation(); + + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + os << "("; + if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType()))) + return failure(); + os << ") "; + os << emitter.getOrCreateName(castOp.getOperand()); + + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::IncludeOp includeOp) { raw_ostream &os = emitter.ostream(); @@ -917,7 +932,7 @@ .Case( [&](auto op) { return printOperation(*this, op); }) // EmitC ops. - .Case( [&](auto op) { return printOperation(*this, op); }) // Func ops. 
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -93,3 +93,11 @@ %c0 = "emitc.variable"(){value = "nullptr" : !emitc.ptr} : () -> !emitc.ptr return } + +// ----- + +func @cast_tensor(%arg : tensor) { + // expected-error @+1 {{'emitc.cast' op operand type 'tensor' and result type 'tensor' are cast incompatible}} + %1 = emitc.cast %arg: tensor to tensor + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -12,6 +12,11 @@ return } +func @cast(%arg0: i32) { + %1 = emitc.cast %arg0: i32 to f32 + return +} + func @c() { %1 = "emitc.constant"(){value = 42 : i32} : () -> i32 return diff --git a/mlir/test/Target/Cpp/cast.mlir b/mlir/test/Target/Cpp/cast.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/Cpp/cast.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s +// Every emitc.cast below must be printed as a C-style cast of the function argument. +// CHECK-LABEL: void cast +func @cast(%arg0 : i32) { + // CHECK-NEXT: uint32_t [[V1:[^ ]*]] = (uint32_t) [[V0:[^ ]*]] + %1 = emitc.cast %arg0: i32 to ui32 + + // CHECK-NEXT: int64_t [[V2:[^ ]*]] = (int64_t) [[V0]] + %2 = emitc.cast %arg0: i32 to i64 + // CHECK-NEXT: uint64_t [[V3:[^ ]*]] = (uint64_t) [[V0]] + %3 = emitc.cast %arg0: i32 to ui64 + + // CHECK-NEXT: float [[V4:[^ ]*]] = (float) [[V0]] + %4 = emitc.cast %arg0: i32 to f32 + // CHECK-NEXT: double [[V5:[^ ]*]] = (double) [[V0]] + %5 = emitc.cast %arg0: i32 to f64 + + // CHECK-NEXT: bool [[V6:[^ ]*]] = (bool) [[V0]] + %6 = emitc.cast %arg0: i32 to i1 + + // CHECK-NEXT: mytype [[V7:[^ ]*]] = (mytype) [[V0]] + %7 = emitc.cast %arg0: i32 to !emitc.opaque<"mytype"> + return +} + +// CHECK-LABEL: void cast_ptr +func @cast_ptr(%arg0 : !emitc.ptr>) { + // CHECK-NEXT: int32_t* [[V1:[^ ]*]] = (int32_t*) [[V0:[^ ]*]] + %1 = emitc.cast %arg0 : !emitc.ptr> to !emitc.ptr + return +}