diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -158,7 +158,6 @@ // of strings and signed/unsigned integers (for now) as an artefact of // splitting the Standard dialect. let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result); - let verifier = [{ return ::verify(*this); }]; let builders = [ OpBuilder<(ins "Attribute":$value), @@ -175,6 +174,7 @@ let hasFolder = 1; let assemblyFormat = "attr-dict $value"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -814,7 +814,7 @@ }]; let hasFolder = 1; - let verifier = [{ return verifyExtOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -845,7 +845,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - let verifier = [{ return verifyExtOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -859,8 +859,7 @@ The destination type must to be strictly wider than the source type. When operating on vectors, casts elementwise. }]; - - let verifier = [{ return verifyExtOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -887,7 +886,7 @@ }]; let hasFolder = 1; - let verifier = [{ return verifyTruncateOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -904,7 +903,7 @@ }]; let hasFolder = 1; - let verifier = [{ return verifyTruncateOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -83,9 +83,9 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = [{ return ::verify(*this); }]; let skipDefaultBuilders = 1; + let hasVerifier = 1; let builders = [ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$dependencies, "ValueRange":$operands, @@ -111,10 +111,8 @@ }]; let arguments = (ins Variadic:$operands); - let assemblyFormat = "($operands^ `:` type($operands))? attr-dict"; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Async_AwaitOp : Async_Op<"await"> { @@ -138,6 +136,7 @@ let results = (outs Optional:$result); let skipDefaultBuilders = 1; + let hasVerifier = 1; let builders = [ OpBuilder<(ins "Value":$operand, @@ -156,8 +155,6 @@ type($operand), type($result) ) attr-dict }]; - - let verifier = [{ return ::verify(*this); }]; } def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> { diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -58,7 +58,6 @@ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; let hasFolder = 1; - let verifier = ?; let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -97,8 +96,6 @@ let arguments = (ins Arg:$memref); let results = (outs AnyTensor:$result); - // MemrefToTensor is fully verified by traits. - let verifier = ?; let builders = [ OpBuilder<(ins "Value":$memref), [{ @@ -184,8 +181,6 @@ let arguments = (ins AnyTensor:$tensor); let results = (outs AnyRankedOrUnrankedMemRef:$memref); - // This op is fully verified by traits. - let verifier = ?; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -26,7 +26,6 @@ let arguments = (ins Complex:$lhs, Complex:$rhs); let results = (outs Complex:$result); let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; - let verifier = ?; } // Base class for standard unary operations on complex numbers with a @@ -36,7 +35,6 @@ Complex_Op { let arguments = (ins Complex:$complex); let assemblyFormat = "$complex attr-dict `:` type($complex)"; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -102,7 +100,7 @@ let assemblyFormat = "$value attr-dict `:` type($complex)"; let hasFolder = 1; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; let extraClassDeclaration = [{ /// Returns true if a constant operation can be built with the given value diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -24,9 +24,7 @@ // Base class for EmitC dialect ops. class EmitC_Op traits = []> - : Op { - let verifier = "return ::verify(*this);"; -} + : Op; def EmitC_ApplyOp : EmitC_Op<"apply", []> { let summary = "Apply operation"; @@ -54,6 +52,7 @@ let assemblyFormat = [{ $applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results) }]; + let hasVerifier = 1; } def EmitC_CallOp : EmitC_Op<"call", []> { @@ -85,6 +84,7 @@ let assemblyFormat = [{ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) }]; + let hasVerifier = 1; } def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { @@ -113,6 +113,7 @@ let results = (outs AnyType); let hasFolder = 1; + let hasVerifier = 1; } def EmitC_IncludeOp @@ -144,7 +145,6 @@ ); let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = ?; } #endif // MLIR_DIALECT_EMITC_IR_EMITC diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -25,12 +25,10 @@ Op { // For every linalg op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -54,8 +52,6 @@ `:` type($result) }]; - let verifier = [{ return ::verify(*this); }]; - let extraClassDeclaration = [{ static StringRef getStaticSizesAttrName() { return "static_sizes"; @@ -127,6 +123,7 @@ ]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, @@ -144,6 +141,7 @@ ``` }]; let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; + let hasVerifier = 1; } def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [ @@ -426,6 +424,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def Linalg_IndexOp : Linalg_Op<"index", [NoSideEffect]>, @@ -469,6 +468,7 @@ }]; let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; + let hasVerifier = 1; } #endif // LINALG_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -101,10 +101,9 @@ OpBuilder<(ins "Value":$value, "Value":$output)> ]; - let verifier = [{ return ::verify(*this); }]; - let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -264,10 +263,9 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseGenericOp(parser, result); }]; - let verifier = [{ return ::verify(*this); }]; - let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/QuantOps.td --- a/mlir/include/mlir/Dialect/Quant/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/QuantOps.td @@ -103,7 +103,7 @@ StrAttr:$logical_kernel); let results = (outs Variadic:$outputs); let regions = (region SizedRegion<1>:$body); - let verifier = [{ return verifyRegionOp(*this); }]; + let hasVerifier = 1; } def quant_ReturnOp : quant_Op<"return", [Terminator]> { @@ -227,43 +227,7 @@ OptionalAttr:$axisStats, OptionalAttr:$axis); let results = (outs quant_RealValueType); - - let verifier = [{ - auto tensorArg = arg().getType().dyn_cast(); - if (!tensorArg) return emitOpError("arg needs to be tensor type."); - - // Verify layerStats attribute. - { - auto layerStatsType = layerStats().getType(); - if (!layerStatsType.getElementType().isa()) { - return emitOpError( - "layerStats must have a floating point element type"); - } - if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { - return emitOpError("layerStats must have shape [2]"); - } - } - // Verify axisStats (optional) attribute. - if (axisStats()) { - if (!axis()) return emitOpError("axis must be specified for axisStats"); - - auto shape = tensorArg.getShape(); - auto argSliceSize = std::accumulate(std::next(shape.begin(), - *axis()), shape.end(), 1, std::multiplies()); - - auto axisStatsType = axisStats()->getType(); - if (!axisStatsType.getElementType().isa()) { - return emitOpError("axisStats must have a floating point element type"); - } - if (axisStatsType.getRank() != 2 || - axisStatsType.getDimSize(1) != 2 || - axisStatsType.getDimSize(0) != argSliceSize) { - return emitOpError("axisStats must have shape [N,2] " - "where N = the slice size defined by the axis dim"); - } - } - return success(); - }]; + let hasVerifier = 1; } def quant_CoupledRefOp : quant_Op<"coupled_ref", [SameOperandsAndResultType]> { diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -48,8 +48,6 @@ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; - let verifier = [{ return verifySizeOrIndexOp(*this); }]; - let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by // InferTypeOpInterface @@ -57,6 +55,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> { @@ -102,7 +101,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Shape_ConstShapeOp : Shape_Op<"const_shape", @@ -184,8 +183,8 @@ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -323,7 +322,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -377,7 +376,7 @@ }]; let hasFolder = 1; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + let hasVerifier = 1; } def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> { @@ -518,8 +517,8 @@ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -545,7 +544,7 @@ let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)"; let hasFolder = 1; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by // InferTypeOpInterface @@ -595,7 +594,7 @@ let builders = [OpBuilder<(ins "Value":$shape, "ValueRange":$initVals)>]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -614,9 +613,9 @@ let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)"; - let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -724,8 +723,8 @@ [{ build($_builder, $_state, llvm::None); }]> ]; - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let hasVerifier = 1; } // TODO: Add Ops: if_static, if_ranked @@ -859,8 +858,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Shape_AssumingOp : Shape_Op<"assuming", [ @@ -882,7 +880,6 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }]; let extraClassDeclaration = [{ // Inline the region into the region containing the AssumingOp and delete @@ -898,6 +895,7 @@ ]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", @@ -959,7 +957,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> { diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -35,12 +35,10 @@ Op { // For every standard op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -66,10 +64,6 @@ let arguments = (ins I1:$arg, StrAttr:$msg); let assemblyFormat = "$arg `,` $msg attr-dict"; - - // AssertOp is fully verified by its traits. - let verifier = ?; - let hasCanonicalizeMethod = 1; } @@ -107,9 +101,6 @@ $_state.addOperands(destOperands); }]>]; - // BranchOp is fully verified by traits. - let verifier = ?; - let extraClassDeclaration = [{ void setDest(Block *block); @@ -189,7 +180,6 @@ let assemblyFormat = [{ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) }]; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -250,7 +240,6 @@ CallInterfaceCallable getCallableForCallee() { return getCallee(); } }]; - let verifier = ?; let hasCanonicalizeMethod = 1; let assemblyFormat = @@ -311,9 +300,6 @@ falseOperands); }]>]; - // CondBranchOp is fully verified by traits. - let verifier = ?; - let extraClassDeclaration = [{ // These are the indices into the dests list. enum { trueIndex = 0, falseIndex = 1 }; @@ -434,6 +420,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -466,6 +453,7 @@ [{ build($_builder, $_state, llvm::None); }]>]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -516,6 +504,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -562,6 +551,7 @@ [{ build($_builder, $_state, aggregateType, element); }]>]; let hasFolder = 1; + let hasVerifier = 1; let assemblyFormat = "$input attr-dict `:` type($aggregate)"; } @@ -652,6 +642,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } #endif // STANDARD_OPS diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -82,8 +82,7 @@ ); let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; - - let verifier = [{ return verifyAveragePoolOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -116,8 +115,7 @@ ); let builders = [Tosa_ConvOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -149,8 +147,7 @@ ); let builders = [Tosa_ConvOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -183,8 +180,7 @@ ); let builders = [Tosa_ConvOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -212,8 +208,7 @@ ); let builders = [Tosa_FCOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -157,7 +157,7 @@ }]; let parser = [{ return ::parseFuncOp(parser, result); }]; let printer = [{ return ::print(*this, p); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -221,7 +221,7 @@ return "builtin"; } }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; // We need to ensure the block inside the region is properly terminated; // the auto-generated builders do not guarantee that. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2495,12 +2495,10 @@ // For every such op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; code extraBaseClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -107,19 +107,19 @@ /// TODO: disallow arith.constant to return anything other than signless integer /// or float like. -static LogicalResult verify(arith::ConstantOp op) { - auto type = op.getType(); +LogicalResult arith::ConstantOp::verify() { + auto type = getType(); // The value's type must match the return type. - if (op.getValue().getType() != type) { - return op.emitOpError() << "value type " << op.getValue().getType() - << " must match return type: " << type; + if (getValue().getType() != type) { + return emitOpError() << "value type " << getValue().getType() + << " must match return type: " << type; } // Integer values must be signless. if (type.isa() && !type.cast().isSignless()) - return op.emitOpError("integer return type must be signless"); + return emitOpError("integer return type must be signless"); // Any float or elements attribute are acceptable. - if (!op.getValue().isa()) { - return op.emitOpError( + if (!getValue().isa()) { + return emitOpError( "value must be an integer, float, or elements attribute"); } return success(); @@ -886,6 +886,10 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::ExtUIOp::verify() { + return verifyExtOp(*this); +} + //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// @@ -912,6 +916,10 @@ patterns.insert(context); } +LogicalResult arith::ExtSIOp::verify() { + return verifyExtOp(*this); +} + //===----------------------------------------------------------------------===// // ExtFOp //===----------------------------------------------------------------------===// @@ -920,6 +928,8 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::ExtFOp::verify() { return verifyExtOp(*this); } + //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// @@ -954,6 +964,10 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::TruncIOp::verify() { + return verifyTruncateOp(*this); +} + //===----------------------------------------------------------------------===// // TruncFOp //===----------------------------------------------------------------------===// @@ -983,6 +997,10 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::TruncFOp::verify() { + return verifyTruncateOp(*this); +} + //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -33,17 +33,17 @@ // YieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(YieldOp op) { +LogicalResult YieldOp::verify() { // Get the underlying value types from async values returned from the // parent `async.execute` operation. - auto executeOp = op->getParentOfType(); + auto executeOp = (*this)->getParentOfType(); auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { return result.getType().cast().getValueType(); }); - if (op.getOperandTypes() != types) - return op.emitOpError("operand types do not match the types returned from " - "the parent ExecuteOp"); + if (getOperandTypes() != types) + return emitOpError("operand types do not match the types returned from " + "the parent ExecuteOp"); return success(); } @@ -228,16 +228,16 @@ return success(); } -static LogicalResult verify(ExecuteOp op) { +LogicalResult ExecuteOp::verify() { // Unwrap async.execute value operands types. - auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { + auto unwrappedTypes = llvm::map_range(operands(), [](Value operand) { return operand.getType().cast().getValueType(); }); // Verify that unwrapped argument types matches the body region arguments. - if (op.body().getArgumentTypes() != unwrappedTypes) - return op.emitOpError("async body region argument types do not match the " - "execute operation arguments types"); + if (body().getArgumentTypes() != unwrappedTypes) + return emitOpError("async body region argument types do not match the " + "execute operation arguments types"); return success(); } @@ -303,19 +303,19 @@ p << operandType; } -static LogicalResult verify(AwaitOp op) { - Type argType = op.operand().getType(); +LogicalResult AwaitOp::verify() { + Type argType = operand().getType(); // Awaiting on a token does not have any results. - if (argType.isa() && !op.getResultTypes().empty()) - return op.emitOpError("awaiting on a token must have empty result"); + if (argType.isa() && !getResultTypes().empty()) + return emitOpError("awaiting on a token must have empty result"); // Awaiting on a value unwraps the async value type. if (auto value = argType.dyn_cast()) { - if (*op.getResultType() != value.getValueType()) - return op.emitOpError() - << "result type " << *op.getResultType() - << " does not match async value type " << value.getValueType(); + if (*getResultType() != value.getValueType()) + return emitOpError() << "result type " << *getResultType() + << " does not match async value type " + << value.getValueType(); } return success(); diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -38,18 +38,18 @@ return false; } -static LogicalResult verify(ConstantOp op) { - ArrayAttr arrayAttr = op.getValue(); +LogicalResult ConstantOp::verify() { + ArrayAttr arrayAttr = getValue(); if (arrayAttr.size() != 2) { - return op.emitOpError( + return emitOpError( "requires 'value' to be a complex constant, represented as array of " "two values"); } - auto complexEltTy = op.getType().getElementType(); + auto complexEltTy = getType().getElementType(); if (complexEltTy != arrayAttr[0].getType() || complexEltTy != arrayAttr[1].getType()) { - return op.emitOpError() + return emitOpError() << "requires attribute's element types (" << arrayAttr[0].getType() << ", " << arrayAttr[1].getType() << ") to match the element type of the op's return type (" diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -48,16 +48,16 @@ // ApplyOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ApplyOp op) { - StringRef applicableOperator = op.applicableOperator(); +LogicalResult ApplyOp::verify() { + StringRef applicableOperatorStr = applicableOperator(); // Applicable operator must not be empty. - if (applicableOperator.empty()) - return op.emitOpError("applicable operator must not be empty"); + if (applicableOperatorStr.empty()) + return emitOpError("applicable operator must not be empty"); // Only `*` and `&` are supported. - if (applicableOperator != "&" && applicableOperator != "*") - return op.emitOpError("applicable operator is illegal"); + if (applicableOperatorStr != "&" && applicableOperatorStr != "*") + return emitOpError("applicable operator is illegal"); return success(); } @@ -66,32 +66,32 @@ // CallOp //===----------------------------------------------------------------------===// -static LogicalResult verify(emitc::CallOp op) { +LogicalResult emitc::CallOp::verify() { // Callee must not be empty. - if (op.callee().empty()) - return op.emitOpError("callee must not be empty"); + if (callee().empty()) + return emitOpError("callee must not be empty"); - if (Optional argsAttr = op.args()) { + if (Optional argsAttr = args()) { for (Attribute arg : argsAttr.getValue()) { if (arg.getType().isa()) { int64_t index = arg.cast().getInt(); // Args with elements of type index must be in range // [0..operands.size). - if ((index < 0) || (index >= static_cast(op.getNumOperands()))) - return op.emitOpError("index argument is out of range"); + if ((index < 0) || (index >= static_cast(getNumOperands()))) + return emitOpError("index argument is out of range"); // Args with elements of type ArrayAttr must have a type. } else if (arg.isa() && arg.getType().isa()) { - return op.emitOpError("array argument has no type"); + return emitOpError("array argument has no type"); } } } - if (Optional templateArgsAttr = op.template_args()) { + if (Optional templateArgsAttr = template_args()) { for (Attribute tArg : templateArgsAttr.getValue()) { if (!tArg.isa() && !tArg.isa() && !tArg.isa() && !tArg.isa()) - return op.emitOpError("template argument has invalid type"); + return emitOpError("template argument has invalid type"); } } @@ -103,12 +103,12 @@ //===----------------------------------------------------------------------===// /// The constant op requires that the attribute's type matches the return type. -static LogicalResult verify(emitc::ConstantOp &op) { - Attribute value = op.value(); - Type type = op.getType(); +LogicalResult emitc::ConstantOp::verify() { + Attribute value = valueAttr(); + Type type = getType(); if (!value.getType().isa() && type != value.getType()) - return op.emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; + return emitOpError() << "requires attribute's type (" << value.getType() + << ") to match op's return type (" << type << ")"; return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -400,11 +400,11 @@ /// FillOp region is elided when printing. void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} -static LogicalResult verify(FillOp op) { - OpOperand *output = op.getOutputOperand(0); - Type fillType = op.value().getType(); +LogicalResult FillOp::verify() { + OpOperand *output = getOutputOperand(0); + Type fillType = value().getType(); if (getElementTypeOrSelf(output->get()) != fillType) - return op.emitOpError("expects fill type to match view elemental type"); + return emitOpError("expects fill type to match view elemental type"); return success(); } @@ -635,7 +635,7 @@ return success(); } -static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); } +LogicalResult GenericOp::verify() { return verifyGenericOp(*this); } namespace { // Deduplicate redundant args of a linalg generic op. @@ -811,25 +811,24 @@ result.addAttributes(attrs); } -static LogicalResult verify(InitTensorOp op) { - RankedTensorType resultType = op.getType(); +LogicalResult InitTensorOp::verify() { + RankedTensorType resultType = getType(); SmallVector staticSizes = llvm::to_vector<4>(llvm::map_range( - op.static_sizes().cast(), + static_sizes().cast(), [](Attribute a) -> int64_t { return a.cast().getInt(); })); - if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(), - op.static_sizes(), op.sizes(), - ShapedType::isDynamic))) + if (failed(verifyListOfOperandsOrIntegers( + *this, "sizes", resultType.getRank(), static_sizes(), sizes(), + ShapedType::isDynamic))) return failure(); - if (op.static_sizes().size() != static_cast(resultType.getRank())) - return op->emitError("expected ") - << resultType.getRank() << " sizes values"; + if (static_sizes().size() != static_cast(resultType.getRank())) + return emitError("expected ") << resultType.getRank() << " sizes values"; Type expectedType = InitTensorOp::inferResultType( staticSizes, resultType.getElementType(), resultType.getEncoding()); if (resultType != expectedType) { - return op.emitError("specified type ") + return emitError("specified type ") << resultType << " does not match the inferred type " << expectedType; } @@ -1030,13 +1029,13 @@ return success(); } -static LogicalResult verify(linalg::YieldOp op) { - auto *parentOp = op->getParentOp(); +LogicalResult linalg::YieldOp::verify() { + auto *parentOp = (*this)->getParentOp(); if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) - return op.emitOpError("expected single non-empty parent region"); + return emitOpError("expected single non-empty parent region"); if (auto linalgOp = dyn_cast(parentOp)) - return verifyYield(op, cast(parentOp)); + return verifyYield(*this, cast(parentOp)); if (auto tiledLoopOp = dyn_cast(parentOp)) { // Check if output args with tensor types match results types. @@ -1044,25 +1043,25 @@ llvm::copy_if( tiledLoopOp.outputs(), std::back_inserter(tensorOuts), [&](Value out) { return out.getType().isa(); }); - if (tensorOuts.size() != op.values().size()) - return op.emitOpError("expected number of tensor output args = ") - << tensorOuts.size() << " to match the number of yield operands = " - << op.values().size(); + if (tensorOuts.size() != values().size()) + return emitOpError("expected number of tensor output args = ") + << tensorOuts.size() + << " to match the number of yield operands = " << values().size(); TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts)); for (auto &item : - llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) { + llvm::enumerate(llvm::zip(tensorTypes, getOperandTypes()))) { Type outType, resultType; unsigned index = item.index(); std::tie(outType, resultType) = item.value(); if (outType != resultType) - return op.emitOpError("expected yield operand ") + return emitOpError("expected yield operand ") << index << " with type = " << resultType << " to match output arg type = " << outType; } return success(); } - return op.emitOpError("expected parent op with LinalgOp interface"); + return emitOpError("expected parent op with LinalgOp interface"); } //===----------------------------------------------------------------------===// @@ -1316,37 +1315,37 @@ return !region().isAncestor(value.getParentRegion()); } -static LogicalResult verify(TiledLoopOp op) { +LogicalResult TiledLoopOp::verify() { // Check if iterator types are provided for every loop dimension. - if (op.iterator_types().size() != op.getNumLoops()) - return op.emitOpError("expected iterator types array attribute size = ") - << op.iterator_types().size() - << " to match the number of loops = " << op.getNumLoops(); + if (iterator_types().size() != getNumLoops()) + return emitOpError("expected iterator types array attribute size = ") + << iterator_types().size() + << " to match the number of loops = " << getNumLoops(); // Check if types of input arguments match region args types. for (auto &item : - llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) { + llvm::enumerate(llvm::zip(inputs(), getRegionInputArgs()))) { Value input, inputRegionArg; unsigned index = item.index(); std::tie(input, inputRegionArg) = item.value(); if (input.getType() != inputRegionArg.getType()) - return op.emitOpError("expected input arg ") + return emitOpError("expected input arg ") << index << " with type = " << input.getType() - << " to match region arg " << index + op.getNumLoops() + << " to match region arg " << index + getNumLoops() << " type = " << inputRegionArg.getType(); } // Check if types of input arguments match region args types. for (auto &item : - llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) { + llvm::enumerate(llvm::zip(outputs(), getRegionOutputArgs()))) { Value output, outputRegionArg; unsigned index = item.index(); std::tie(output, outputRegionArg) = item.value(); if (output.getType() != outputRegionArg.getType()) - return op.emitOpError("expected output arg ") + return emitOpError("expected output arg ") << index << " with type = " << output.getType() << " to match region arg " - << index + op.getNumLoops() + op.inputs().size() + << index + getNumLoops() + inputs().size() << " type = " << outputRegionArg.getType(); } return success(); @@ -1667,13 +1666,13 @@ // IndexOp //===----------------------------------------------------------------------===// -static LogicalResult verify(IndexOp op) { - auto linalgOp = dyn_cast(op->getParentOp()); +LogicalResult IndexOp::verify() { + auto linalgOp = dyn_cast((*this)->getParentOp()); if (!linalgOp) - return op.emitOpError("expected parent op with LinalgOp interface"); - if (linalgOp.getNumLoops() <= op.dim()) - return op.emitOpError("expected dim (") - << op.dim() << ") to be lower than the number of loops (" + return emitOpError("expected parent op with LinalgOp interface"); + if (linalgOp.getNumLoops() <= dim()) + return emitOpError("expected dim (") + << dim() << ") to be lower than the number of loops (" << linalgOp.getNumLoops() << ") of the enclosing LinalgOp"; return success(); } diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -65,29 +65,67 @@ return false; } -static LogicalResult verifyRegionOp(QuantizeRegionOp op) { +LogicalResult QuantizeRegionOp::verify() { // There are specifications for both inputs and outputs. - if (op.getNumOperands() != op.input_specs().size() || - op.getNumResults() != op.output_specs().size()) - return op.emitOpError( + if (getNumOperands() != input_specs().size() || + getNumResults() != output_specs().size()) + return emitOpError( "has unmatched operands/results number and spec attributes number"); // Verify that quantization specifications are valid. - for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) { + for (auto input : llvm::zip(getOperandTypes(), input_specs())) { Type inputType = std::get<0>(input); Attribute inputSpec = std::get<1>(input); if (!isValidQuantizationSpec(inputSpec, inputType)) { - return op.emitOpError() << "has incompatible specification " << inputSpec - << " and input type " << inputType; + return emitOpError() << "has incompatible specification " << inputSpec + << " and input type " << inputType; } } - for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) { + for (auto result : llvm::zip(getResultTypes(), output_specs())) { Type outputType = std::get<0>(result); Attribute outputSpec = std::get<1>(result); if (!isValidQuantizationSpec(outputSpec, outputType)) { - return op.emitOpError() << "has incompatible specification " << outputSpec - << " and output type " << outputType; + return emitOpError() << "has incompatible specification " << outputSpec + << " and output type " << outputType; + } + } + return success(); +} + +LogicalResult StatisticsOp::verify() { + auto tensorArg = arg().getType().dyn_cast(); + if (!tensorArg) + return emitOpError("arg needs to be tensor type."); + + // Verify layerStats attribute. + { + auto layerStatsType = layerStats().getType(); + if (!layerStatsType.getElementType().isa()) { + return emitOpError("layerStats must have a floating point element type"); + } + if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { + return emitOpError("layerStats must have shape [2]"); + } + } + // Verify axisStats (optional) attribute. + if (axisStats()) { + if (!axis()) + return emitOpError("axis must be specified for axisStats"); + + auto shape = tensorArg.getShape(); + auto argSliceSize = + std::accumulate(std::next(shape.begin(), *axis()), shape.end(), 1, + std::multiplies()); + + auto axisStatsType = axisStats()->getType(); + if (!axisStatsType.getElementType().isa()) { + return emitOpError("axisStats must have a floating point element type"); + } + if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 || + axisStatsType.getDimSize(0) != argSliceSize) { + return emitOpError("axisStats must have shape [N,2] " + "where N = the slice size defined by the axis dim"); } } return success(); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -419,6 +419,10 @@ result.addTypes(assumingTypes); } +LogicalResult AssumingOp::verify() { + return RegionBranchOpInterface::verifyTypes(*this); +} + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -449,6 +453,8 @@ operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); } +LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // AssumingAllOp //===----------------------------------------------------------------------===// @@ -529,10 +535,10 @@ return BoolAttr::get(getContext(), true); } -static LogicalResult verify(AssumingAllOp op) { +LogicalResult AssumingAllOp::verify() { // Ensure that AssumingAllOp contains at least one operand - if (op.getNumOperands() == 0) - return op.emitOpError("no operands specified"); + if (getNumOperands() == 0) + return emitOpError("no operands specified"); return success(); } @@ -575,8 +581,8 @@ return builder.getIndexTensorAttr(resultShape); } -static LogicalResult verify(BroadcastOp op) { - return verifyShapeOrExtentTensorOp(op); +LogicalResult BroadcastOp::verify() { + return verifyShapeOrExtentTensorOp(*this); } namespace { @@ -912,10 +918,10 @@ return nullptr; } -static LogicalResult verify(CstrBroadcastableOp op) { +LogicalResult CstrBroadcastableOp::verify() { // Ensure that AssumingAllOp contains at least one operand - if (op.getNumOperands() < 2) - return op.emitOpError("required at least 2 input shapes"); + if (getNumOperands() < 2) + return emitOpError("required at least 2 input shapes"); return success(); } @@ -1016,6 +1022,8 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // ShapeEqOp //===----------------------------------------------------------------------===// @@ -1172,6 +1180,8 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // IsBroadcastableOp //===----------------------------------------------------------------------===// @@ -1298,6 +1308,8 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // NumElementsOp //===----------------------------------------------------------------------===// @@ -1333,6 +1345,10 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult shape::NumElementsOp::verify() { + return verifySizeOrIndexOp(*this); +} + //===----------------------------------------------------------------------===// // MaxOp //===----------------------------------------------------------------------===// @@ -1429,6 +1445,9 @@ // SizeType is compatible with IndexType. return eachHasOnlyOneOfTypes(l, r); } + +LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// @@ -1535,6 +1554,10 @@ return false; } +LogicalResult shape::ShapeOfOp::verify() { + return verifyShapeOrExtentTensorOp(*this); +} + //===----------------------------------------------------------------------===// // SizeToIndexOp //===----------------------------------------------------------------------===// @@ -1556,18 +1579,17 @@ // YieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(shape::YieldOp op) { - auto *parentOp = op->getParentOp(); +LogicalResult shape::YieldOp::verify() { + auto *parentOp = (*this)->getParentOp(); auto results = parentOp->getResults(); - auto operands = op.getOperands(); + auto operands = getOperands(); - if (parentOp->getNumResults() != op.getNumOperands()) - return op.emitOpError() << "number of operands does not match number of " - "results of its parent"; + if (parentOp->getNumResults() != getNumOperands()) + return emitOpError() << "number of operands does not match number of " + "results of its parent"; for (auto e : llvm::zip(results, operands)) if (std::get<0>(e).getType() != std::get<1>(e).getType()) - return op.emitOpError() - << "types mismatch between yield op and its parent"; + return emitOpError() << "types mismatch between yield op and its parent"; return success(); } @@ -1639,41 +1661,42 @@ } } -static LogicalResult verify(ReduceOp op) { +LogicalResult ReduceOp::verify() { // Verify block arg types. - Block &block = op.getRegion().front(); + Block &block = getRegion().front(); // The block takes index, extent, and aggregated values as arguments. - auto blockArgsCount = op.getInitVals().size() + 2; + auto blockArgsCount = getInitVals().size() + 2; if (block.getNumArguments() != blockArgsCount) - return op.emitOpError() << "ReduceOp body is expected to have " - << blockArgsCount << " arguments"; + return emitOpError() << "ReduceOp body is expected to have " + << blockArgsCount << " arguments"; // The first block argument is the index and must always be of type `index`. if (!block.getArgument(0).getType().isa()) - return op.emitOpError( + return emitOpError( "argument 0 of ReduceOp body is expected to be of IndexType"); // The second block argument is the extent and must be of type `size` or // `index`, depending on whether the reduce operation is applied to a shape or // to an extent tensor. Type extentTy = block.getArgument(1).getType(); - if (op.getShape().getType().isa()) { + if (getShape().getType().isa()) { if (!extentTy.isa()) - return op.emitOpError("argument 1 of ReduceOp body is expected to be of " - "SizeType if the ReduceOp operates on a ShapeType"); + return emitOpError("argument 1 of ReduceOp body is expected to be of " + "SizeType if the ReduceOp operates on a ShapeType"); } else { if (!extentTy.isa()) - return op.emitOpError( + return emitOpError( "argument 1 of ReduceOp body is expected to be of IndexType if the " "ReduceOp operates on an extent tensor"); } - for (const auto &type : llvm::enumerate(op.getInitVals())) + for (const auto &type : llvm::enumerate(getInitVals())) if (block.getArgument(type.index() + 2).getType() != type.value().getType()) - return op.emitOpError() - << "type mismatch between argument " << type.index() + 2 - << " of ReduceOp body and initial value " << type.index(); + return emitOpError() << "type mismatch between argument " + << type.index() + 2 + << " of ReduceOp body and initial value " + << type.index(); return success(); } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -612,31 +612,31 @@ /// The constant op requires an attribute, and furthermore requires that it /// matches the return type. -static LogicalResult verify(ConstantOp &op) { - auto value = op.getValue(); +LogicalResult ConstantOp::verify() { + auto value = getValue(); if (!value) - return op.emitOpError("requires a 'value' attribute"); + return emitOpError("requires a 'value' attribute"); - Type type = op.getType(); + Type type = getType(); if (!value.getType().isa() && type != value.getType()) - return op.emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; + return emitOpError() << "requires attribute's type (" << value.getType() + << ") to match op's return type (" << type << ")"; if (type.isa()) { auto fnAttr = value.dyn_cast(); if (!fnAttr) - return op.emitOpError("requires 'value' to be a function reference"); + return emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. - auto fn = - op->getParentOfType().lookupSymbol(fnAttr.getValue()); + auto fn = (*this)->getParentOfType().lookupSymbol( + fnAttr.getValue()); if (!fn) - return op.emitOpError() - << "reference to undefined function '" << fnAttr.getValue() << "'"; + return emitOpError() << "reference to undefined function '" + << fnAttr.getValue() << "'"; // Check that the referenced function has the correct type. if (fn.getType() != type) - return op.emitOpError("reference to function with mismatched type"); + return emitOpError("reference to function with mismatched type"); return success(); } @@ -644,7 +644,7 @@ if (type.isa() && value.isa()) return success(); - return op.emitOpError("unsupported 'value' attribute: ") << value; + return emitOpError("unsupported 'value' attribute: ") << value; } OpFoldResult ConstantOp::fold(ArrayRef operands) { @@ -676,23 +676,23 @@ // ReturnOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReturnOp op) { - auto function = cast(op->getParentOp()); +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getType().getResults(); - if (op.getNumOperands() != results.size()) - return op.emitOpError("has ") - << op.getNumOperands() << " operands, but enclosing function (@" + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) - if (op.getOperand(i).getType() != results[i]) - return op.emitError() - << "type of return operand " << i << " (" - << op.getOperand(i).getType() - << ") doesn't match function result type (" << results[i] << ")" - << " in function @" << function.getName(); + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); return success(); } @@ -843,24 +843,23 @@ parser.getNameLoc(), result.operands); } -static LogicalResult verify(SelectOp op) { - Type conditionType = op.getCondition().getType(); +LogicalResult SelectOp::verify() { + Type conditionType = getCondition().getType(); if (conditionType.isSignlessInteger(1)) return success(); // If the result type is a vector or tensor, the type can be a mask with the // same elements. - Type resultType = op.getType(); + Type resultType = getType(); if (!resultType.isa()) - return op.emitOpError() - << "expected condition to be a signless i1, but got " - << conditionType; + return emitOpError() << "expected condition to be a signless i1, but got " + << conditionType; Type shapedConditionType = getI1SameShape(resultType); if (conditionType != shapedConditionType) - return op.emitOpError() - << "expected condition type to have the same shape " - "as the result type, expected " - << shapedConditionType << ", but got " << conditionType; + return emitOpError() << "expected condition type to have the same shape " + "as the result type, expected " + << shapedConditionType << ", but got " + << conditionType; return success(); } @@ -868,11 +867,10 @@ // SplatOp //===----------------------------------------------------------------------===// -static LogicalResult verify(SplatOp op) { +LogicalResult SplatOp::verify() { // TODO: we could replace this by a trait. - if (op.getOperand().getType() != - op.getType().cast().getElementType()) - return op.emitError("operand should be of elemental type of result type"); + if (getOperand().getType() != getType().cast().getElementType()) + return emitError("operand should be of elemental type of result type"); return success(); } @@ -995,26 +993,26 @@ p.printNewline(); } -static LogicalResult verify(SwitchOp op) { - auto caseValues = op.getCaseValues(); - auto caseDestinations = op.getCaseDestinations(); +LogicalResult SwitchOp::verify() { + auto caseValues = getCaseValues(); + auto caseDestinations = getCaseDestinations(); if (!caseValues && caseDestinations.empty()) return success(); - Type flagType = op.getFlag().getType(); + Type flagType = getFlag().getType(); Type caseValueType = caseValues->getType().getElementType(); if (caseValueType != flagType) - return op.emitOpError() - << "'flag' type (" << flagType << ") should match case value type (" - << caseValueType << ")"; + return emitOpError() << "'flag' type (" << flagType + << ") should match case value type (" << caseValueType + << ")"; if (caseValues && caseValues->size() != static_cast(caseDestinations.size())) - return op.emitOpError() << "number of case values (" << caseValues->size() - << ") should match number of " - "case destinations (" - << caseDestinations.size() << ")"; + return emitOpError() << "number of case values (" << caseValues->size() + << ") should match number of " + "case destinations (" + << caseDestinations.size() << ")"; return success(); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -701,9 +701,9 @@ return success(); } -static LogicalResult verifyAveragePoolOp(tosa::AvgPool2dOp op) { - auto inputETy = op.input().getType().cast().getElementType(); - auto resultETy = op.getType().cast().getElementType(); +LogicalResult tosa::AvgPool2dOp::verify() { + auto inputETy = input().getType().cast().getElementType(); + auto resultETy = getType().cast().getElementType(); if (auto quantType = inputETy.dyn_cast()) inputETy = quantType.getStorageType(); @@ -718,7 +718,7 @@ if (inputETy.isInteger(16) && resultETy.isInteger(16)) return success(); - return op.emitOpError("input/output element types are incompatible."); + return emitOpError("input/output element types are incompatible."); } //===----------------------------------------------------------------------===// @@ -1010,6 +1010,8 @@ return success(); } +LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); } + LogicalResult tosa::MatMulOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1632,6 +1634,8 @@ return success(); } +LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); } + LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1708,6 +1712,8 @@ return success(); } +LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); } + LogicalResult AvgPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1800,6 +1806,8 @@ return success(); } +LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); } + LogicalResult TransposeConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -128,19 +128,19 @@ p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); } -static LogicalResult verify(FuncOp op) { +LogicalResult FuncOp::verify() { // If this function is external there is nothing to do. - if (op.isExternal()) + if (isExternal()) return success(); // Verify that the argument list of the function and the arg list of the entry // block line up. The trait already verified that the number of arguments is // the same between the signature and the block. - auto fnInputTypes = op.getType().getInputs(); - Block &entryBlock = op.front(); + auto fnInputTypes = getType().getInputs(); + Block &entryBlock = front(); for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) - return op.emitOpError("type of entry block argument #") + return emitOpError("type of entry block argument #") << i << '(' << entryBlock.getArgument(i).getType() << ") must match the type of the corresponding argument in " << "function signature(" << fnInputTypes[i] << ')'; @@ -245,28 +245,28 @@ return {}; } -static LogicalResult verify(ModuleOp op) { +LogicalResult ModuleOp::verify() { // Check that none of the attributes are non-dialect attributes, except for // the symbol related attributes. - for (auto attr : op->getAttrs()) { + for (auto attr : (*this)->getAttrs()) { if (!attr.getName().strref().contains('.') && !llvm::is_contained( ArrayRef{mlir::SymbolTable::getSymbolAttrName(), mlir::SymbolTable::getVisibilityAttrName()}, attr.getName().strref())) - return op.emitOpError() << "can only contain attributes with " - "dialect-prefixed names, found: '" - << attr.getName().getValue() << "'"; + return emitOpError() << "can only contain attributes with " + "dialect-prefixed names, found: '" + << attr.getName().getValue() << "'"; } // Check that there is at most one data layout spec attribute. StringRef layoutSpecAttrName; DataLayoutSpecInterface layoutSpec; - for (const NamedAttribute &na : op->getAttrs()) { + for (const NamedAttribute &na : (*this)->getAttrs()) { if (auto spec = na.getValue().dyn_cast()) { if (layoutSpec) { InFlightDiagnostic diag = - op.emitOpError() << "expects at most one data layout attribute"; + emitOpError() << "expects at most one data layout attribute"; diag.attachNote() << "'" << layoutSpecAttrName << "' is a data layout attribute"; diag.attachNote() << "'" << na.getName().getValue() diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -40,10 +40,10 @@ OpBuilder<(ins CArg<"int", "0">:$integer)>]; let parser = [{ foo }]; let printer = [{ bar }]; - let verifier = [{ baz }]; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Display a graph for debugging purposes.