diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -653,12 +653,12 @@ In the `spv` dialect, this op is modelled as follows: ``` - spv-spec-constant-operation-op ::= `"spv.SpecConstantOperation"` - `(`ssa-id (`, ` ssa-id)`)` - `({` - ssa-id = spirv-op - `spv.mlir.yield` ssa-id - `})` `:` function-type + spv-spec-constant-operation-op ::= `spv.SpecConstantOperation` + `(` ssa-id `=` ssa-id (`, ` ssa-id `=` ssa-id)* `)` + `:` function-type `{` + ssa-id = spirv-op + `spv.mlir.yield` ssa-id + `}` ``` In particular, an `spv.SpecConstantOperation` contains exactly one @@ -713,16 +713,16 @@ ```mlir %0 = spv.constant 1: i32 - %1 = "spv.SpecConstantOperation"(%0) ({ - %ret = spv.IAdd %0, %0 : i32 + %1 = spv.SpecConstantOperation(%lhs=%0, %rhs=%0) : (i32, i32) -> i32 { + %ret = spv.IAdd %lhs, %rhs : i32 spv.mlir.yield %ret : i32 - }) : (i32) -> i32 + } ``` }]; let arguments = (ins Variadic<AnyType>:$operands); - let results = (outs AnyType:$results); + let results = (outs AnyType:$result); let regions = (region SizedRegion<1>:$body); diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -143,6 +143,25 @@ /// special or non-printable characters in it. virtual void printSymbolName(StringRef symbolRef) = 0; + /// Prints `prefix` followed by (%inner = %outer, %inner2 = %outer2, <...>), + /// where 'inner' values are assumed to be region arguments and 'outer' values + /// are regular SSA values. Both ranges must have the same length; nothing at + /// all (not even `prefix`) is printed when they are empty.
+ void printInitializationList(Block::BlockArgListType blocksArgs, + ValueRange initializers, StringRef prefix = "") { + assert(blocksArgs.size() == initializers.size() && + "expected same length of arguments and initializers"); + // Print nothing at all (not even the prefix) for an empty list. + if (!initializers.empty()) { + auto &stream = getStream(); + stream << prefix << '('; + auto printPair = [&](auto pair) { + *this << std::get<0>(pair) << " = " << std::get<1>(pair); + }; + llvm::interleaveComma(llvm::zip(blocksArgs, initializers), *this, + printPair); + stream << ')'; + } + } + private: OpAsmPrinter(const OpAsmPrinter &) = delete; void operator=(const OpAsmPrinter &) = delete; diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -141,32 +141,12 @@ return RegionBranchOpInterface::verifyTypes(op); } -/// Prints the initialization list in the form of -/// (%inner = %outer, %inner2 = %outer2, <...>) -/// where 'inner' values are assumed to be region arguments and 'outer' values -/// are regular SSA values.
-static void printInitializationList(OpAsmPrinter &p, - Block::BlockArgListType blocksArgs, - ValueRange initializers, - StringRef prefix = "") { - assert(blocksArgs.size() == initializers.size() && - "expected same length of arguments and initializers"); - if (initializers.empty()) - return; - - p << prefix << '('; - llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { - p << std::get<0>(it) << " = " << std::get<1>(it); - }); - p << ")"; -} - static void print(OpAsmPrinter &p, ForOp op) { p << op.getOperationName() << " " << op.getInductionVar() << " = " << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); - printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(), - " iter_args"); + p.printInitializationList(op.getRegionIterArgs(), op.getIterOperands(), + " iter_args"); if (!op.getIterOperands().empty()) p << " -> (" << op.getIterOperands().getTypes() << ')'; p.printRegion(op.region(), @@ -1257,8 +1237,8 @@ /// Prints a `while` op. static void print(OpAsmPrinter &p, scf::WhileOp op) { p << op.getOperationName(); - printInitializationList(p, op.before().front().getArguments(), op.inits(), - " "); + p.printInitializationList(op.before().front().getArguments(), op.inits(), + " "); p << " : "; p.printFunctionalType(op.inits().getTypes(), op.results().getTypes()); p.printRegion(op.before(), /*printEntryBlockArgs=*/false); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -3416,15 +3416,49 @@ return success(); } +//===----------------------------------------------------------------------===// +// spv.SpecConstantOperation +//===----------------------------------------------------------------------===// + +/// Parses the custom form +/// spv.SpecConstantOperation(%arg = %operand, ...) : function-type { region } +/// where each %arg becomes an entry-block argument of the region. static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser, OperationState &state) { - // TODO: For now, only generic form is supported.
- return failure(); + SmallVector regionArgs, operands; + Region *body = state.addRegion(); + + if (parser.parseAssignmentList(regionArgs, operands)) + return failure(); + + FunctionType functionType; + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + if (failed(parser.parseColonType(functionType))) + return failure(); + + state.addTypes(functionType.getResults()); + + if (functionType.getNumInputs() != operands.size()) + return parser.emitError(typeLoc) + << "expected as many input types as operands " + << "(expected " << operands.size() << " got " + << functionType.getNumInputs() << ")"; + + // Resolve input operands. + if (parser.resolveOperands(operands, functionType.getInputs(), + parser.getCurrentLocation(), state.operands)) + return failure(); + + if (parser.parseRegion(*body, regionArgs, functionType.getInputs())) + return failure(); + + return success(); } +/// Prints the custom form accepted by parseSpecConstantOperationOp. The parser +/// always requires a parenthesized assignment list, while +/// printInitializationList prints nothing for an empty list, so an explicit +/// `()` is emitted when there are no operands to keep the output parseable. static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) { - // TODO - printer.printGenericOp(op); + printer << op.getOperationName(); + if (op.operands().empty()) + printer << "()"; + else + printer.printInitializationList(op.body().front().getArguments(), + op.operands()); + printer << " : "; + printer.printFunctionalType(op); + // Region arguments are already named in the assignment list above. + printer.printRegion(op.body(), /*printEntryBlockArgs=*/false); } static LogicalResult verify(spirv::SpecConstantOperationOp constOp) { @@ -3471,7 +3505,8 @@ if (!isa( operand.getDefiningOp())) - return constOp.emitOpError("invalid operand"); + return constOp.emitOpError( + "invalid operand, must be defined by a constant operation"); return success(); } diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -757,6 +757,7 @@ // expected-error @+1 {{unsupported composite type}} spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device> } + //===----------------------------------------------------------------------===// // spv.SpecConstantOperation
//===----------------------------------------------------------------------===// @@ -766,15 +767,35 @@ spv.module Logical GLSL450 { spv.func @foo() -> i32 "None" { %0 = spv.constant 1: i32 - %2 = spv.constant 1: i32 + %1 = spv.constant 1: i32 - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): + // Custom form: each region argument is bound to an operand in the + // assignment list, and the function type is printed before the region. + // CHECK: spv.SpecConstantOperation([[LHS:%.*]] = {{%.*}}, [[RHS:%.*]] = {{%.*}}) : (i32, i32) -> i32 { + // CHECK-NEXT: [[RET:%.*]] = spv.IAdd [[LHS]], [[RHS]] : i32 + // CHECK-NEXT: spv.mlir.yield [[RET]] : i32 + // CHECK-NEXT: } + %2 = spv.SpecConstantOperation(%lhs = %0, %rhs = %1) : (i32, i32) -> i32 { %ret = spv.IAdd %lhs, %rhs : i32 spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 + } - spv.ReturnValue %1 : i32 + spv.ReturnValue %2 : i32 + } +} + +// ----- + +spv.module Logical GLSL450 { + spv.func @foo() -> i32 "None" { + %0 = spv.constant 1: i32 + %1 = spv.constant 1: i32 + + // Two operands are bound, but the function type lists only one input. + // expected-error @+1 {{expected as many input types as operands (expected 2 got 1)}} + %2 = spv.SpecConstantOperation(%lhs = %0, %rhs = %1) : (i32) -> i32 { + %ret = spv.IAdd %lhs, %rhs : i32 + spv.mlir.yield %ret : i32 + } + + spv.ReturnValue %2 : i32 } } @@ -830,12 +851,11 @@ spv.func @foo() -> i32 "None" { %0 = spv.constant 1: i32 - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): + %1 = spv.SpecConstantOperation(%lhs=%0, %rhs=%0) : (i32, i32) -> i32 { %ret = spv.ISub %lhs, %rhs : i32 // expected-error @+1 {{expected operand to be defined by preceeding op}} spv.mlir.yield %lhs : i32 - }) : (i32, i32) -> i32 + } spv.ReturnValue %1 : i32 } @@ -848,12 +868,11 @@ %0 = spv.constant 1: i32 // expected-error @+1 {{expected exactly 2 nested ops}} - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): + %1 = spv.SpecConstantOperation(%lhs=%0, %rhs=%0) : (i32, i32) -> i32 { %ret = spv.IAdd %lhs, %rhs : i32 %ret2 = spv.IAdd %lhs, %rhs : i32 spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 + } spv.ReturnValue %1 : i32 } @@
-866,11 +885,10 @@ %0 = spv.constant 1: i32 // expected-error @+1 {{expected terminator to be a yield op}} - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): + %1 = spv.SpecConstantOperation(%lhs=%0, %rhs=%0) : (i32, i32) -> i32 { %ret = spv.IAdd %lhs, %rhs : i32 spv.ReturnValue %ret : i32 - }) : (i32, i32) -> i32 + } spv.ReturnValue %1 : i32 } @@ -883,11 +901,10 @@ %0 = spv.Variable : !spv.ptr // expected-error @+1 {{invalid enclosed op}} - %2 = "spv.SpecConstantOperation"(%0) ({ - ^bb(%arg0 : !spv.ptr): + %2 = spv.SpecConstantOperation(%arg0 = %0) : (!spv.ptr) -> i32 { %ret = spv.Load "Function" %arg0 : i32 spv.mlir.yield %ret : i32 - }) : (!spv.ptr) -> i32 + } } } @@ -898,11 +915,10 @@ %0 = spv.Variable : !spv.ptr %1 = spv.Load "Function" %0 : i32 - // expected-error @+1 {{invalid operand}} - %2 = "spv.SpecConstantOperation"(%1, %1) ({ - ^bb(%lhs: i32, %rhs: i32): + // expected-error @+1 {{invalid operand, must be defined by a constant operation}} + %2 = spv.SpecConstantOperation(%lhs=%1, %rhs=%1) : (i32, i32) -> i32 { %ret = spv.IAdd %lhs, %lhs : i32 spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 + } } }