diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -234,6 +234,30 @@
   let hasCustomAssemblyFormat = 1;
 }
 
+def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
+  let summary = "Literal operation";
+  let description = [{
+    The `literal` operation produces an SSA value equal to some constant
+    specified by an attribute.
+
+    The `value` string must be non-empty. When translated to C++, no
+    statement is emitted for this operation; instead, `value` is printed
+    verbatim wherever the result is used.
+
+    Example:
+
+    ```mlir
+    %0 = emitc.literal "M_PI" : f32
+    ```
+  }];
+
+  let arguments = (ins StrAttr:$value);
+  let results = (outs AnyType:$result);
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$value attr-dict `:` type($result)";
+}
+
 def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
   let summary = "Multiplication operation";
   let description = [{
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -199,6 +199,17 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LiteralOp
+//===----------------------------------------------------------------------===//
+
+/// The literal op requires a non-empty value.
+LogicalResult emitc::LiteralOp::verify() {
+  if (getValue().empty())
+    return emitOpError() << "value must not be empty";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SubOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -714,7 +714,9 @@
       // When generating code for an scf.for op, printing a trailing semicolon
       // is handled within the printOperation function.
+      // An emitc.literal op emits no statement of its own; its value is
+      // printed in place at each use of its result.
       bool trailingSemicolon =
-          !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
+          !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp, emitc::LiteralOp>(op);
 
       if (failed(emitter.emitOperation(
               op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -733,6 +735,9 @@
 
 /// Return the existing or a new name for a Value.
 StringRef CppEmitter::getOrCreateName(Value val) {
+  // Literal results never get a variable name; use the literal text itself.
+  if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
+    return literal.getValue();
   if (!valueMapper.count(val))
     valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
   return *valueMapper.begin(val);
@@ -987,12 +992,18 @@
           // Arithmetic ops.
           .Case<arith::ConstantOp>(
               [&](auto op) { return printOperation(*this, op); })
+          // Literals are printed at their uses, not as standalone statements.
+          .Case<emitc::LiteralOp>([&](auto) { return success(); })
           .Default([&](Operation *) {
             return op.emitOpError("unable to find printer for op");
           });
 
   if (failed(status))
     return failure();
+
+  if (isa<emitc::LiteralOp>(op))
+    return success();
+
   os << (trailingSemicolon ? ";\n" : "\n");
   return success();
 }
diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir
--- a/mlir/test/Target/Cpp/for.mlir
+++ b/mlir/test/Target/Cpp/for.mlir
@@ -82,3 +82,23 @@
 // CPP-DECLTOP-NEXT: [[SE]] = [[SI]];
 // CPP-DECLTOP-NEXT: [[PE]] = [[PI]];
 // CPP-DECLTOP-NEXT: return;
+
+func.func @test_for_yield_2() {
+  %start = emitc.literal "0" : index
+  %stop = emitc.literal "10" : index
+  %step = emitc.literal "1" : index
+
+  %s0 = emitc.literal "0" : i32
+  %p0 = emitc.literal "M_PI" : f32
+
+  %result:2 = scf.for %iter = %start to %stop step %step iter_args(%si = %s0, %pi = %p0) -> (i32, f32) {
+    %sn = emitc.call "add"(%si, %iter) : (i32, index) -> i32
+    %pn = emitc.call "mul"(%pi, %iter) : (f32, index) -> f32
+    scf.yield %sn, %pn : i32, f32
+  }
+
+  return
+}
+// CPP-DEFAULT: void test_for_yield_2() {
+// CPP-DEFAULT: float{{.*}}= M_PI
+// CPP-DEFAULT: for (size_t [[IN:.*]] = 0; [[IN]] < 10; [[IN]] += 1) {