diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -608,9 +608,12 @@ let autogenSerialization = 0; } -def SPV_YieldOp : SPV_Op<"mlir.yield", [NoSideEffect, Terminator]> { - let summary = "Yields the result computed in `spv.SpecConstantOperation`'s" - "region back to the parent op."; +def SPV_YieldOp : SPV_Op<"mlir.yield", [ + HasParent<"SpecConstantOperationOp">, NoSideEffect, Terminator]> { + let summary = [{ + Yields the result computed in `spv.SpecConstantOperation`'s + region back to the parent op. + }]; let description = [{ This op is a special terminator whose only purpose is to terminate @@ -639,12 +642,16 @@ let autogenSerialization = 0; let assemblyFormat = "attr-dict $operand `:` type($operand)"; + + let verifier = [{ return success(); }]; } def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [ - InFunctionScope, NoSideEffect, - IsolatedFromAbove]> { - let summary = "Declare a new specialization constant that results from doing an operation."; + NoSideEffect, InFunctionScope, + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = [{ + Declare a new specialization constant that results from doing an operation. + }]; let description = [{ This op declares a SPIR-V specialization constant that results from @@ -653,12 +660,8 @@ In the `spv` dialect, this op is modelled as follows: ``` - spv-spec-constant-operation-op ::= `"spv.SpecConstantOperation"` - `(`ssa-id (`, ` ssa-id)`)` - `({` - ssa-id = spirv-op - `spv.mlir.yield` ssa-id - `})` `:` function-type + spv-spec-constant-operation-op ::= `spv.SpecConstantOperation` `wraps` + generic-spirv-op `:` function-type ``` In particular, an `spv.SpecConstantOperation` contains exactly one @@ -712,17 +715,15 @@ #### Example: ```mlir %0 = spv.constant 1: i32 + %1 = spv.constant 1: i32 - %1 = "spv.SpecConstantOperation"(%0) ({ - %ret = spv.IAdd %0, %0 : i32 - spv.mlir.yield %ret : i32 - }) : (i32) -> i32 + %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32 ``` }]; - let arguments = (ins Variadic:$operands); + let arguments = (ins); - let results = (outs AnyType:$results); + let results = (outs AnyType:$result); let regions = (region SizedRegion<1>:$body); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -3396,35 +3396,39 @@ } //===----------------------------------------------------------------------===// -// spv.mlir.yield +// spv.SpecConstantOperation //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::YieldOp yieldOp) { - Operation *parentOp = yieldOp->getParentOp(); +static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser, + OperationState &state) { + Region *body = state.addRegion(); - if (!parentOp || !isa(parentOp)) - return yieldOp.emitOpError( - "expected parent op to be 'spv.SpecConstantOperation'"); + if (parser.parseKeyword("wraps")) + return failure(); - Block &block = parentOp->getRegion(0).getBlocks().front(); - Operation &enclosedOp = block.getOperations().front(); + body->push_back(new Block); + Block &block = body->back(); + Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); - if (yieldOp.getOperand().getDefiningOp() != &enclosedOp) - return yieldOp.emitOpError( - "expected operand to be defined by preceeding op"); + if (!wrappedOp) + return failure(); - return success(); -} + OpBuilder builder(parser.getBuilder().getContext()); + builder.setInsertionPointToEnd(&block); + builder.create(wrappedOp->getLoc(), wrappedOp->getResult(0)); + state.location = wrappedOp->getLoc(); -static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser, - OperationState &state) { - // TODO: For now, only generic form is supported. - return failure(); + state.addTypes(wrappedOp->getResult(0).getType()); + + if (parser.parseOptionalAttrDict(state.attributes)) + return failure(); + + return success(); } static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) { - // TODO - printer.printGenericOp(op); + printer << op.getOperationName() << " wraps "; + printer.printGenericOp(&op.body().front().front()); } static LogicalResult verify(spirv::SpecConstantOperationOp constOp) { @@ -3433,11 +3437,6 @@ if (block.getOperations().size() != 2) return constOp.emitOpError("expected exactly 2 nested ops"); - Operation &yieldOp = block.getOperations().back(); - - if (!isa(yieldOp)) - return constOp.emitOpError("expected terminator to be a yield op"); - Operation &enclosedOp = block.getOperations().front(); // TODO Add a `UsableInSpecConstantOp` trait and mark ops from the list below @@ -3457,21 +3456,12 @@ spirv::UGreaterThanEqualOp, spirv::SGreaterThanEqualOp>(enclosedOp)) return constOp.emitOpError("invalid enclosed op"); - if (enclosedOp.getNumOperands() != constOp.getOperands().size()) - return constOp.emitOpError("invalid number of operands; expected ") - << enclosedOp.getNumOperands() << ", actual " - << constOp.getOperands().size(); - - if (enclosedOp.getNumOperands() != constOp.getRegion().getNumArguments()) - return constOp.emitOpError("invalid number of region arguments; expected ") - << enclosedOp.getNumOperands() << ", actual " - << constOp.getRegion().getNumArguments(); - - for (auto operand : constOp.getOperands()) + for (auto operand : enclosedOp.getOperands()) if (!isa( operand.getDefiningOp())) - return constOp.emitOpError("invalid operand"); + return constOp.emitOpError( + "invalid operand, must be defined by a constant operation"); return success(); } diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -757,6 +757,7 @@ // expected-error @+1 {{unsupported composite type}} spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device> } + //===----------------------------------------------------------------------===// // spv.SpecConstantOperation //===----------------------------------------------------------------------===// @@ -765,34 +766,15 @@ spv.module Logical GLSL450 { spv.func @foo() -> i32 "None" { + // CHECK: [[LHS:%.*]] = spv.constant %0 = spv.constant 1: i32 - %2 = spv.constant 1: i32 - - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): - %ret = spv.IAdd %lhs, %rhs : i32 - spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 - - spv.ReturnValue %1 : i32 - } -} - -// ----- - -spv.module Logical GLSL450 { - spv.func @foo() -> i32 "None" { - %0 = spv.constant 1: i32 - %2 = spv.constant 1: i32 + // CHECK: [[RHS:%.*]] = spv.constant + %1 = spv.constant 1: i32 - // expected-error @+1 {{invalid number of operands; expected 2, actual 1}} - %1 = "spv.SpecConstantOperation"(%0) ({ - ^bb(%lhs : i32, %rhs : i32): - %ret = spv.IAdd %lhs, %rhs : i32 - spv.mlir.yield %ret : i32 - }) : (i32) -> i32 + // CHECK: spv.SpecConstantOperation wraps "spv.IAdd"([[LHS]], [[RHS]]) : (i32, i32) -> i32 + %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32 - spv.ReturnValue %1 : i32 + spv.ReturnValue %2 : i32 } } @@ -801,93 +783,20 @@ spv.module Logical GLSL450 { spv.func @foo() -> i32 "None" { %0 = spv.constant 1: i32 - %2 = spv.constant 1: i32 - - // expected-error @+1 {{invalid number of region arguments; expected 2, actual 1}} - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32): - %ret = spv.IAdd %lhs, %lhs : i32 - spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 - - spv.ReturnValue %1 : i32 - } -} - -// ----- - -spv.module Logical GLSL450 { - spv.func @foo() -> i32 "None" { - %0 = spv.constant 1: i32 - // expected-error @+1 {{expected parent op to be 'spv.SpecConstantOperation'}} + // expected-error @+1 {{op expects parent op 'spv.SpecConstantOperation'}} spv.mlir.yield %0 : i32 } } // ----- -spv.module Logical GLSL450 { - spv.func @foo() -> i32 "None" { - %0 = spv.constant 1: i32 - - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): - %ret = spv.ISub %lhs, %rhs : i32 - // expected-error @+1 {{expected operand to be defined by preceeding op}} - spv.mlir.yield %lhs : i32 - }) : (i32, i32) -> i32 - - spv.ReturnValue %1 : i32 - } -} - -// ----- - -spv.module Logical GLSL450 { - spv.func @foo() -> i32 "None" { - %0 = spv.constant 1: i32 - - // expected-error @+1 {{expected exactly 2 nested ops}} - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): - %ret = spv.IAdd %lhs, %rhs : i32 - %ret2 = spv.IAdd %lhs, %rhs : i32 - spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 - - spv.ReturnValue %1 : i32 - } -} - -// ----- - -spv.module Logical GLSL450 { - spv.func @foo() -> i32 "None" { - %0 = spv.constant 1: i32 - - // expected-error @+1 {{expected terminator to be a yield op}} - %1 = "spv.SpecConstantOperation"(%0, %0) ({ - ^bb(%lhs : i32, %rhs : i32): - %ret = spv.IAdd %lhs, %rhs : i32 - spv.ReturnValue %ret : i32 - }) : (i32, i32) -> i32 - - spv.ReturnValue %1 : i32 - } -} - -// ----- - spv.module Logical GLSL450 { spv.func @foo() -> () "None" { %0 = spv.Variable : !spv.ptr // expected-error @+1 {{invalid enclosed op}} - %2 = "spv.SpecConstantOperation"(%0) ({ - ^bb(%arg0 : !spv.ptr): - %ret = spv.Load "Function" %arg0 : i32 - spv.mlir.yield %ret : i32 - }) : (!spv.ptr) -> i32 + %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr) -> i32 + spv.Return } } @@ -898,11 +807,9 @@ %0 = spv.Variable : !spv.ptr %1 = spv.Load "Function" %0 : i32 - // expected-error @+1 {{invalid operand}} - %2 = "spv.SpecConstantOperation"(%1, %1) ({ - ^bb(%lhs: i32, %rhs: i32): - %ret = spv.IAdd %lhs, %lhs : i32 - spv.mlir.yield %ret : i32 - }) : (i32, i32) -> i32 + // expected-error @+1 {{invalid operand, must be defined by a constant operation}} + %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%1, %1) : (i32, i32) -> i32 + + spv.Return } }