diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -608,7 +608,10 @@
   let autogenSerialization = 0;
 }
 
-def SPV_YieldOp : SPV_Op<"mlir.yield", [NoSideEffect, Terminator]> {
+def SPV_YieldOp : SPV_Op<"mlir.yield",
+                         [NoSideEffect,
+                          Terminator,
+                          HasParent<"SpecConstantOperationOp">]> {
   let summary = "Yields the result computed in `spv.SpecConstantOperation`'s"
                 "region back to the parent op.";
 
@@ -639,11 +642,14 @@
   let autogenSerialization = 0;
 
   let assemblyFormat = "attr-dict $operand `:` type($operand)";
+
+  let verifier = ?;
 }
 
-def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [
-                                         InFunctionScope, NoSideEffect,
-                                         IsolatedFromAbove]> {
+def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation",
+                                         [InFunctionScope,
+                                          NoSideEffect,
+                                          SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Declare a new specialization constant that results from doing an operation.";
 
   let description = [{
@@ -653,12 +659,8 @@
     In the `spv` dialect, this op is modelled as follows:
 
     ```
-    spv-spec-constant-operation-op ::= `"spv.SpecConstantOperation"`
-                                         `(`ssa-id (`, ` ssa-id)`)`
-                                       `({`
-                                         ssa-id = spirv-op
-                                         `spv.mlir.yield` ssa-id
-                                       `})` `:` function-type
+    spv-spec-constant-operation-op ::= `spv.SpecConstantOperation` `wraps`
+                                         generic-spirv-op `:` function-type
     ```
 
     In particular, an `spv.SpecConstantOperation` contains exactly one
@@ -713,16 +715,15 @@
 
     ```mlir
     %0 = spv.constant 1: i32
-    %1 = "spv.SpecConstantOperation"(%0) ({
-      %ret = spv.IAdd %0, %0 : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32) -> i32
+    %1 = spv.constant 1: i32
+
+    %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32
     ```
   }];
 
-  let arguments = (ins Variadic<AnyType>:$operands);
+  let arguments = (ins);
 
-  let results = (outs AnyType:$results);
+  let results = (outs AnyType:$result);
 
   let regions = (region SizedRegion<1>:$body);
 
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -147,6 +147,25 @@
   /// special or non-printable characters in it.
   virtual void printSymbolName(StringRef symbolRef) = 0;
 
+  /// Prints the initialization list in the form of
+  ///   (%inner = %outer, %inner2 = %outer2, <...>)
+  /// where 'inner' values are assumed to be region arguments and 'outer' values
+  /// are regular SSA values.
+  void printInitializationList(Block::BlockArgListType blocksArgs,
+                               ValueRange initializers, StringRef prefix = "") {
+    assert(blocksArgs.size() == initializers.size() &&
+           "expected same length of arguments and initializers");
+    if (initializers.empty())
+      return;
+
+    auto &os = getStream();
+    os << prefix << '(';
+    llvm::interleaveComma(
+        llvm::zip(blocksArgs, initializers), *this,
+        [&](auto it) { *this << std::get<0>(it) << " = " << std::get<1>(it); });
+    os << ")";
+  }
+
 private:
   OpAsmPrinter(const OpAsmPrinter &) = delete;
   void operator=(const OpAsmPrinter &) = delete;
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -141,32 +141,12 @@
   return RegionBranchOpInterface::verifyTypes(op);
 }
 
-/// Prints the initialization list in the form of
-///   (%inner = %outer, %inner2 = %outer2, <...>)
-/// where 'inner' values are assumed to be region arguments and 'outer' values
-/// are regular SSA values.
-static void printInitializationList(OpAsmPrinter &p,
-                                    Block::BlockArgListType blocksArgs,
-                                    ValueRange initializers,
-                                    StringRef prefix = "") {
-  assert(blocksArgs.size() == initializers.size() &&
-         "expected same length of arguments and initializers");
-  if (initializers.empty())
-    return;
-
-  p << prefix << '(';
-  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
-    p << std::get<0>(it) << " = " << std::get<1>(it);
-  });
-  p << ")";
-}
-
 static void print(OpAsmPrinter &p, ForOp op) {
   p << op.getOperationName() << " " << op.getInductionVar() << " = "
     << op.lowerBound() << " to " << op.upperBound() << " step "
     << op.step();
-  printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(),
-                          " iter_args");
+  p.printInitializationList(op.getRegionIterArgs(), op.getIterOperands(),
+                            " iter_args");
   if (!op.getIterOperands().empty())
     p << " -> (" << op.getIterOperands().getTypes() << ')';
   p.printRegion(op.region(),
@@ -1257,8 +1237,8 @@
 /// Prints a `while` op.
 static void print(OpAsmPrinter &p, scf::WhileOp op) {
   p << op.getOperationName();
-  printInitializationList(p, op.before().front().getArguments(), op.inits(),
-                          " ");
+  p.printInitializationList(op.before().front().getArguments(), op.inits(),
+                            " ");
   p << " : ";
   p.printFunctionalType(op.inits().getTypes(), op.results().getTypes());
   p.printRegion(op.before(), /*printEntryBlockArgs=*/false);
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -3396,35 +3396,52 @@
 }
 
 //===----------------------------------------------------------------------===//
-// spv.mlir.yield
+// spv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(spirv::YieldOp yieldOp) {
-  Operation *parentOp = yieldOp->getParentOp();
+/// Parses `spv.SpecConstantOperation wraps <generic-op> attr-dict?`. The
+/// wrapped op is placed in a fresh single-block region and followed by an
+/// `spv.mlir.yield` of its only result.
+static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
+                                                OperationState &state) {
+  Region *body = state.addRegion();
 
-  if (!parentOp || !isa<spirv::SpecConstantOperationOp>(parentOp))
-    return yieldOp.emitOpError(
-        "expected parent op to be 'spv.SpecConstantOperation'");
+  if (parser.parseKeyword("wraps"))
+    return failure();
 
-  Block &block = parentOp->getRegion(0).getBlocks().front();
-  Operation &enclosedOp = block.getOperations().front();
+  body->push_back(new Block);
+  Block &block = body->back();
+  llvm::SMLoc wrappedOpLoc = parser.getCurrentLocation();
+  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
 
-  if (yieldOp.getOperand().getDefiningOp() != &enclosedOp)
-    return yieldOp.emitOpError(
-        "expected operand to be defined by preceeding op");
+  if (!wrappedOp)
+    return failure();
 
-  return success();
-}
+  // The wrapped op's value becomes this op's result, so it must produce
+  // exactly one; otherwise `getResult(0)` below would be out of bounds.
+  if (wrappedOp->getNumResults() != 1)
+    return parser.emitError(
+        wrappedOpLoc, "expected wrapped op to produce exactly one result");
+
+  OpBuilder builder(parser.getBuilder().getContext());
+  builder.setInsertionPointToEnd(&block);
+  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
+  state.location = wrappedOp->getLoc();
 
-static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
-                                                OperationState &state) {
-  // TODO: For now, only generic form is supported.
-  return failure();
+  state.addTypes(wrappedOp->getResult(0).getType());
+
+  if (parser.parseOptionalAttrDict(state.attributes))
+    return failure();
+
+  return success();
 }
 
 static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) {
-  // TODO
-  printer.printGenericOp(op);
+  printer << op.getOperationName() << " wraps ";
+  printer.printGenericOp(&op.body().front().front());
+  // Print the attributes accepted by the parser's trailing attr-dict so they
+  // survive a round trip.
+  printer.printOptionalAttrDict(op.getAttrs());
 }
 
 static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
@@ -3433,11 +3450,6 @@
   if (block.getOperations().size() != 2)
     return constOp.emitOpError("expected exactly 2 nested ops");
 
-  Operation &yieldOp = block.getOperations().back();
-
-  if (!isa<spirv::YieldOp>(yieldOp))
-    return constOp.emitOpError("expected terminator to be a yield op");
-
   Operation &enclosedOp = block.getOperations().front();
 
   // TODO Add a `UsableInSpecConstantOp` trait and mark ops from the list below
@@ -3457,21 +3469,25 @@
            spirv::UGreaterThanEqualOp, spirv::SGreaterThanEqualOp>(enclosedOp))
     return constOp.emitOpError("invalid enclosed op");
 
-  if (enclosedOp.getNumOperands() != constOp.getOperands().size())
-    return constOp.emitOpError("invalid number of operands; expected ")
-           << enclosedOp.getNumOperands() << ", actual "
-           << constOp.getOperands().size();
-
-  if (enclosedOp.getNumOperands() != constOp.getRegion().getNumArguments())
-    return constOp.emitOpError("invalid number of region arguments; expected ")
-           << enclosedOp.getNumOperands() << ", actual "
-           << constOp.getRegion().getNumArguments();
-
-  for (auto operand : constOp.getOperands())
-    if (!isa<spirv::ConstantOp, spirv::SpecConstantOp,
-             spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>(
-            operand.getDefiningOp()))
-      return constOp.emitOpError("invalid operand");
+  // Operands may be block arguments (the region is not isolated from above),
+  // which have no defining op; `isa` must not be called on a null pointer.
+  for (Value operand : enclosedOp.getOperands()) {
+    Operation *definingOp = operand.getDefiningOp();
+    if (!definingOp ||
+        !isa<spirv::ConstantOp, spirv::SpecConstantOp,
+             spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>(
+            definingOp))
+      return constOp.emitOpError(
+          "invalid operand, must be defined by a constant operation");
+  }
+
+  // The terminator (an `spv.mlir.yield`, per `SingleBlockImplicitTerminator`)
+  // must forward the wrapped op's result rather than some unrelated value.
+  Operation &yieldOp = block.back();
+  if (yieldOp.getNumOperands() != 1 ||
+      yieldOp.getOperand(0).getDefiningOp() != &enclosedOp)
+    return constOp.emitOpError(
+        "expected yielded value to be defined by the enclosed op");
 
   return success();
 }
diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -757,6 +757,7 @@
   // expected-error @+1 {{unsupported composite type}}
   spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device>
 }
+
 //===----------------------------------------------------------------------===//
 // spv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
@@ -765,34 +766,15 @@
 
 spv.module Logical GLSL450 {
   spv.func @foo() -> i32 "None" {
+    // CHECK: [[LHS:%.*]] = spv.constant
     %0 = spv.constant 1: i32
-    %2 = spv.constant 1: i32
-
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-      ^bb(%lhs : i32, %rhs : i32):
-        %ret = spv.IAdd %lhs, %rhs : i32
-        spv.mlir.yield %ret : i32
-      }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-    %2 = spv.constant 1: i32
+    // CHECK: [[RHS:%.*]] = spv.constant
+    %1 = spv.constant 1: i32
 
-    // expected-error @+1 {{invalid number of operands; expected 2, actual 1}}
-    %1 = "spv.SpecConstantOperation"(%0) ({
-      ^bb(%lhs : i32, %rhs : i32):
-        %ret = spv.IAdd %lhs, %rhs : i32
-        spv.mlir.yield %ret : i32
-      }) : (i32) -> i32
+    // CHECK: spv.SpecConstantOperation wraps "spv.IAdd"([[LHS]], [[RHS]]) : (i32, i32) -> i32
+    %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32
 
-    spv.ReturnValue %1 : i32
+    spv.ReturnValue %2 : i32
   }
 }
 
@@ -801,93 +783,20 @@
 spv.module Logical GLSL450 {
   spv.func @foo() -> i32 "None" {
     %0 = spv.constant 1: i32
-    %2 = spv.constant 1: i32
-
-    // expected-error @+1 {{invalid number of region arguments; expected 2, actual 1}}
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-      ^bb(%lhs : i32):
-        %ret = spv.IAdd %lhs, %lhs : i32
-        spv.mlir.yield %ret : i32
-      }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-    // expected-error @+1 {{expected parent op to be 'spv.SpecConstantOperation'}}
+    // expected-error @+1 {{op expects parent op 'spv.SpecConstantOperation'}}
     spv.mlir.yield %0 : i32
   }
 }
 
 // -----
 
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-      ^bb(%lhs : i32, %rhs : i32):
-        %ret = spv.ISub %lhs, %rhs : i32
-        // expected-error @+1 {{expected operand to be defined by preceeding op}}
-        spv.mlir.yield %lhs : i32
-      }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-
-    // expected-error @+1 {{expected exactly 2 nested ops}}
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-      ^bb(%lhs : i32, %rhs : i32):
-        %ret = spv.IAdd %lhs, %rhs : i32
-        %ret2 = spv.IAdd %lhs, %rhs : i32
-        spv.mlir.yield %ret : i32
-      }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-
-    // expected-error @+1 {{expected terminator to be a yield op}}
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-      ^bb(%lhs : i32, %rhs : i32):
-        %ret = spv.IAdd %lhs, %rhs : i32
-        spv.ReturnValue %ret : i32
-      }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
 spv.module Logical GLSL450 {
   spv.func @foo() -> () "None" {
     %0 = spv.Variable : !spv.ptr<i32, Function>
 
     // expected-error @+1 {{invalid enclosed op}}
-    %2 = "spv.SpecConstantOperation"(%0) ({
-      ^bb(%arg0 : !spv.ptr<i32, Function>):
-        %ret = spv.Load "Function" %arg0 : i32
-        spv.mlir.yield %ret : i32
-      }) : (!spv.ptr<i32, Function>) -> i32
+    %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<i32, Function>) -> i32
+
     spv.Return
   }
 }
@@ -898,11 +807,46 @@
     %0 = spv.Variable : !spv.ptr<i32, Function>
     %1 = spv.Load "Function" %0 : i32
 
-    // expected-error @+1 {{invalid operand}}
-    %2 = "spv.SpecConstantOperation"(%1, %1) ({
-      ^bb(%lhs: i32, %rhs: i32):
-        %ret = spv.IAdd %lhs, %lhs : i32
-        spv.mlir.yield %ret : i32
-      }) : (i32, i32) -> i32
+    // expected-error @+1 {{invalid operand, must be defined by a constant operation}}
+    %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%1, %1) : (i32, i32) -> i32
+
+    spv.Return
   }
 }
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> () "None" {
+    // expected-error @+1 {{invalid operand, must be defined by a constant operation}}
+    %0 = spv.SpecConstantOperation wraps "spv.IAdd"(%arg0, %arg0) : (i32, i32) -> i32
+
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.func @foo() -> i32 "None" {
+    %0 = spv.constant 1: i32
+
+    // expected-error @+1 {{expected yielded value to be defined by the enclosed op}}
+    %1 = "spv.SpecConstantOperation"() ({
+      %ret = spv.IAdd %0, %0 : i32
+      spv.mlir.yield %0 : i32
+    }) : () -> i32
+
+    spv.ReturnValue %1 : i32
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.func @foo() -> () "None" {
+    // expected-error @+1 {{expected wrapped op to produce exactly one result}}
+    %0 = spv.SpecConstantOperation wraps "spv.Return"() : () -> ()
+    spv.Return
+  }
+}