diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -31,12 +31,10 @@ Op { // For every affine op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -112,6 +110,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def AffineForOp : Affine_Op<"for", @@ -350,6 +349,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def AffineIfOp : Affine_Op<"if", @@ -473,6 +473,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } class AffineLoadOpBase traits = []> : @@ -538,6 +539,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } class AffineMinMaxOpBase traits = []> : @@ -565,11 +567,11 @@ operands().end()}; } }]; - let verifier = [{ return ::verifyAffineMinMaxOp(*this); }]; let printer = [{ return ::printAffineMinMaxOp(p, *this); }]; let parser = [{ return ::parseAffineMinMaxOp<$cppClass>(parser, result); }]; let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; } def AffineMinOp : AffineMinMaxOpBase<"min", [NoSideEffect]> { @@ -753,6 +755,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } def AffinePrefetchOp : Affine_Op<"prefetch", @@ -832,6 +835,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } class AffineStoreOpBase traits = []> : @@ -896,6 +900,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator, ReturnLike, @@ -921,6 +926,7 @@ ]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let hasVerifier = 1; } def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> { @@ -984,6 +990,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> { @@ -1048,6 +1055,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } #endif // AFFINE_OPS diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -158,7 +158,6 @@ // of strings and signed/unsigned integers (for now) as an artefact of // splitting the Standard dialect. let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result); - let verifier = [{ return ::verify(*this); }]; let builders = [ OpBuilder<(ins "Attribute":$value), @@ -175,6 +174,7 @@ let hasFolder = 1; let assemblyFormat = "attr-dict $value"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -814,7 +814,7 @@ }]; let hasFolder = 1; - let verifier = [{ return verifyExtOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -845,7 +845,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - let verifier = [{ return verifyExtOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -859,8 +859,7 @@ The destination type must to be strictly wider than the source type. When operating on vectors, casts elementwise. }]; - - let verifier = [{ return verifyExtOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -887,7 +886,7 @@ }]; let hasFolder = 1; - let verifier = [{ return verifyTruncateOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -904,7 +903,7 @@ }]; let hasFolder = 1; - let verifier = [{ return verifyTruncateOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -82,9 +82,9 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = [{ return ::verify(*this); }]; let skipDefaultBuilders = 1; + let hasVerifier = 1; let builders = [ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$dependencies, "ValueRange":$operands, @@ -110,10 +110,8 @@ }]; let arguments = (ins Variadic:$operands); - let assemblyFormat = "($operands^ `:` type($operands))? attr-dict"; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Async_AwaitOp : Async_Op<"await"> { @@ -137,6 +135,7 @@ let results = (outs Optional:$result); let skipDefaultBuilders = 1; + let hasVerifier = 1; let builders = [ OpBuilder<(ins "Value":$operand, @@ -155,8 +154,6 @@ type($operand), type($result) ) attr-dict }]; - - let verifier = [{ return ::verify(*this); }]; } def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> { diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -58,7 +58,6 @@ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; let hasFolder = 1; - let verifier = ?; let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -97,8 +96,6 @@ let arguments = (ins Arg:$memref); let results = (outs AnyTensor:$result); - // MemrefToTensor is fully verified by traits. - let verifier = ?; let builders = [ OpBuilder<(ins "Value":$memref), [{ @@ -184,8 +181,6 @@ let arguments = (ins AnyTensor:$tensor); let results = (outs AnyRankedOrUnrankedMemRef:$memref); - // This op is fully verified by traits. - let verifier = ?; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -26,7 +26,6 @@ let arguments = (ins Complex:$lhs, Complex:$rhs); let results = (outs Complex:$result); let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; - let verifier = ?; } // Base class for standard unary operations on complex numbers with a @@ -36,7 +35,6 @@ Complex_Op { let arguments = (ins Complex:$complex); let assemblyFormat = "$complex attr-dict `:` type($complex)"; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -102,7 +100,7 @@ let assemblyFormat = "$value attr-dict `:` type($complex)"; let hasFolder = 1; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; let extraClassDeclaration = [{ /// Returns true if a constant operation can be built with the given value diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -24,9 +24,7 @@ // Base class for EmitC dialect ops. class EmitC_Op traits = []> - : Op { - let verifier = "return ::verify(*this);"; -} + : Op; def EmitC_ApplyOp : EmitC_Op<"apply", []> { let summary = "Apply operation"; @@ -54,6 +52,7 @@ let assemblyFormat = [{ $applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results) }]; + let hasVerifier = 1; } def EmitC_CallOp : EmitC_Op<"call", []> { @@ -85,6 +84,7 @@ let assemblyFormat = [{ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) }]; + let hasVerifier = 1; } def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { @@ -113,6 +113,7 @@ let results = (outs AnyType); let hasFolder = 1; + let hasVerifier = 1; } def EmitC_IncludeOp @@ -144,7 +145,6 @@ ); let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = ?; } #endif // MLIR_DIALECT_EMITC_IR_EMITC diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -110,7 +110,6 @@ }]; let assemblyFormat = "attr-dict `:` type($result)"; - let verifier = [{ return success(); }]; } def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [NoSideEffect]>, @@ -126,7 +125,6 @@ }]; let assemblyFormat = "attr-dict `:` type($result)"; - let verifier = [{ return success(); }]; } def GPU_SubgroupSizeOp : GPU_Op<"subgroup_size", [NoSideEffect]>, @@ -142,7 +140,6 @@ }]; let assemblyFormat = "attr-dict `:` type($result)"; - let verifier = [{ return success(); }]; } def GPU_GPUFuncOp : GPU_Op<"func", [ @@ -298,7 +295,6 @@ LogicalResult verifyBody(); }]; - // let verifier = [{ return ::verifFuncOpy(*this); }]; let printer = [{ printGPUFuncOp(p, *this); }]; let parser = [{ return parseGPUFuncOp(parser, result); }]; } @@ -434,7 +430,6 @@ static StringRef getKernelAttrName() { return "kernel"; } }]; - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ custom(type($asyncToken), $asyncDependencies) $kernel @@ -443,6 +438,7 @@ (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)? custom($operands, type($operands)) attr-dict }]; + let hasVerifier = 1; } def GPU_LaunchOp : GPU_Op<"launch">, @@ -562,8 +558,8 @@ let parser = [{ return parseLaunchOp(parser, result); }]; let printer = [{ printLaunchOp(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def GPU_PrintfOp : GPU_Op<"printf", [MemoryEffects<[MemWrite]>]>, @@ -595,7 +591,7 @@ let builders = [OpBuilder<(ins), [{ // empty}]>]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def GPU_TerminatorOp : GPU_Op<"terminator", [HasParent<"LaunchOp">, @@ -682,9 +678,9 @@ in convergence. }]; let regions = (region AnyRegion:$body); - let verifier = [{ return ::verifyAllReduce(*this); }]; let assemblyFormat = [{ custom($op) $value $body attr-dict `:` functional-type(operands, results) }]; + let hasVerifier = 1; } def GPU_ShuffleOpXor : I32EnumAttrCase<"XOR", 0, "xor">; @@ -822,7 +818,6 @@ }]; let assemblyFormat = "$value attr-dict `:` type($value)"; - let verifier = [{ return success(); }]; } def GPU_WaitOp : GPU_Op<"wait", [GPU_AsyncOpInterface]> { @@ -971,8 +966,8 @@ custom(type($asyncToken), $asyncDependencies) $dst`,` $src `:` type($dst)`,` type($src) attr-dict }]; - let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasVerifier = 1; } def GPU_MemsetOp : GPU_Op<"memset", @@ -1006,8 +1001,6 @@ custom(type($asyncToken), $asyncDependencies) $dst`,` $value `:` type($dst)`,` type($value) attr-dict }]; - // MemsetOp is fully verified by traits. - let verifier = [{ return success(); }]; let hasFolder = 1; } @@ -1048,8 +1041,7 @@ let assemblyFormat = [{ $srcMemref`[`$indices`]` attr-dict `:` type($srcMemref) `->` type($res) }]; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix", @@ -1086,8 +1078,7 @@ let assemblyFormat = [{ $src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref) }]; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", @@ -1125,8 +1116,7 @@ let assemblyFormat = [{ $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB) `->` type($res) }]; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix", diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -25,12 +25,10 @@ Op { // For every linalg op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -54,8 +52,6 @@ `:` type($result) }]; - let verifier = [{ return ::verify(*this); }]; - let extraClassDeclaration = [{ static StringRef getStaticSizesAttrName() { return "static_sizes"; @@ -127,6 +123,7 @@ ]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, @@ -144,6 +141,7 @@ ``` }]; let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; + let hasVerifier = 1; } def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [ @@ -426,6 +424,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def Linalg_IndexOp : Linalg_Op<"index", [NoSideEffect]>, @@ -469,6 +468,7 @@ }]; let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; + let hasVerifier = 1; } #endif // LINALG_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -101,10 +101,9 @@ OpBuilder<(ins "Value":$value, "Value":$output)> ]; - let verifier = [{ return ::verify(*this); }]; - let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -264,10 +263,9 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseGenericOp(parser, result); }]; - let verifier = [{ return ::verify(*this); }]; - let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -37,7 +37,6 @@ Op { let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -153,8 +152,6 @@ /// The i-th data operand passed. Value getDataOperand(unsigned i); }]; - - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -225,6 +222,7 @@ ( `attach` `(` $attachOperands^ `:` type($attachOperands) `)` )? $region attr-dict-with-keyword }]; + let hasVerifier = 1; } def OpenACC_TerminatorOp : OpenACC_Op<"terminator", [Terminator]> { @@ -237,8 +235,6 @@ to the enclosing op. }]; - let verifier = ?; - let assemblyFormat = "attr-dict"; } @@ -292,6 +288,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -342,6 +339,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -406,8 +404,7 @@ static StringRef getPrivateKeyword() { return "private"; } static StringRef getReductionKeyword() { return "reduction"; } }]; - - let verifier = [{ return ::verifyLoopOp(*this); }]; + let hasVerifier = 1; } // Yield operation for the acc.loop and acc.parallel operations. @@ -425,8 +422,6 @@ let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - let verifier = ?; - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } @@ -458,6 +453,7 @@ ( `device_num` `(` $deviceNumOperand^ `:` type($deviceNumOperand) `)` )? ( `if` `(` $ifCond^ `)` )? attr-dict-with-keyword }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -488,6 +484,7 @@ ( `device_num` `(` $deviceNumOperand^ `:` type($deviceNumOperand) `)` )? ( `if` `(` $ifCond^ `)` )? attr-dict-with-keyword }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -542,6 +539,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -575,6 +573,7 @@ ( `wait_devnum` `(` $waitDevnum^ `:` type($waitDevnum) `)` )? ( `if` `(` $ifCond^ `)` )? attr-dict-with-keyword }]; + let hasVerifier = 1; } #endif // OPENACC_OPS diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -128,7 +128,7 @@ ]; let parser = [{ return parseParallelOp(parser, result); }]; let printer = [{ return printParallelOp(p, *this); }]; - let verifier = [{ return ::verifyParallelOp(*this); }]; + let hasVerifier = 1; } def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> { @@ -217,7 +217,7 @@ let parser = [{ return parseSectionsOp(parser, result); }]; let printer = [{ return printSectionsOp(p, *this); }]; - let verifier = [{ return verifySectionsOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -336,7 +336,7 @@ }]; let parser = [{ return parseWsLoopOp(parser, result); }]; let printer = [{ return printWsLoopOp(p, *this); }]; - let verifier = [{ return ::verifyWsLoopOp(*this); }]; + let hasVerifier = 1; } def YieldOp : OpenMP_Op<"yield", @@ -457,8 +457,7 @@ let assemblyFormat = [{ $sym_name custom($hint) attr-dict }]; - - let verifier = "return verifyCriticalDeclareOp(*this);"; + let hasVerifier = 1; } @@ -476,8 +475,7 @@ let assemblyFormat = [{ (`(` $name^ `)`)? $region attr-dict }]; - - let verifier = "return ::verifyCriticalOp(*this);"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -540,8 +538,7 @@ ( `depend_vec` `(` $depend_vec_vars^ `:` type($depend_vec_vars) `)` )? attr-dict }]; - - let verifier = "return ::verifyOrderedOp(*this);"; + let hasVerifier = 1; } def OrderedRegionOp : OpenMP_Op<"ordered_region"> { @@ -561,8 +558,7 @@ let regions = (region AnyRegion:$region); let assemblyFormat = [{ ( `simd` $simd^ )? $region attr-dict}]; - - let verifier = "return ::verifyOrderedRegionOp(*this);"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -614,7 +610,7 @@ OptionalAttr:$memory_order); let parser = [{ return parseAtomicReadOp(parser, result); }]; let printer = [{ return printAtomicReadOp(p, *this); }]; - let verifier = [{ return verifyAtomicReadOp(*this); }]; + let hasVerifier = 1; } def AtomicWriteOp : OpenMP_Op<"atomic.write"> { @@ -643,7 +639,7 @@ OptionalAttr:$memory_order); let parser = [{ return parseAtomicWriteOp(parser, result); }]; let printer = [{ return printAtomicWriteOp(p, *this); }]; - let verifier = [{ return verifyAtomicWriteOp(*this); }]; + let hasVerifier = 1; } // TODO: autogenerate from OMP.td in future if possible. @@ -708,7 +704,7 @@ OptionalAttr:$memory_order); let parser = [{ return parseAtomicUpdateOp(parser, result); }]; let printer = [{ return printAtomicUpdateOp(p, *this); }]; - let verifier = [{ return verifyAtomicUpdateOp(*this); }]; + let hasVerifier = 1; } def AtomicCaptureOp : OpenMP_Op<"atomic.capture", @@ -752,7 +748,7 @@ let regions = (region SizedRegion<1>:$region); let parser = [{ return parseAtomicCaptureOp(parser, result); }]; let printer = [{ return printAtomicCaptureOp(p, *this); }]; - let verifier = [{ return verifyAtomicCaptureOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -789,7 +785,6 @@ let regions = (region AnyRegion:$initializerRegion, AnyRegion:$reductionRegion, AnyRegion:$atomicReductionRegion); - let verifier = "return ::verifyReductionDeclareOp(*this);"; let assemblyFormat = "$sym_name `:` $type attr-dict-with-keyword " "`init` $initializerRegion " @@ -804,6 +799,7 @@ return atomicReductionRegion().front().getArgument(0).getType(); } }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -826,7 +822,7 @@ let arguments= (ins AnyType:$operand, OpenMP_PointerLikeType:$accumulator); let assemblyFormat = "$operand `,` $accumulator attr-dict `:` type($accumulator)"; - let verifier = "return ::verifyReductionOp(*this);"; + let hasVerifier = 1; } #endif // OPENMP_OPS diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -26,7 +26,6 @@ : Op { let printer = [{ ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = [{ return ::verify(*this); }]; } //===----------------------------------------------------------------------===// @@ -66,6 +65,7 @@ params.empty() ? ArrayAttr() : $_builder.getArrayAttr(params)); }]>, ]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -118,6 +118,7 @@ $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -164,6 +165,7 @@ build($_builder, $_state, $_builder.getType(), Value(), attr); }]>, ]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -185,7 +187,6 @@ }]; let arguments = (ins PDL_Operation:$operation); let assemblyFormat = "$operation attr-dict"; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -224,6 +225,7 @@ build($_builder, $_state, $_builder.getType(), Value()); }]>, ]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -263,6 +265,7 @@ Value()); }]>, ]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -396,6 +399,7 @@ /// inference. bool hasTypeInference(); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -452,6 +456,7 @@ /// Returns the rewrite operation of this pattern. RewriteOp getRewriter(); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -492,6 +497,7 @@ $operation `with` (`(` $replValues^ `:` type($replValues) `)`)? ($replOperation^)? attr-dict }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -524,7 +530,6 @@ let arguments = (ins PDL_Operation:$parent, I32Attr:$index); let results = (outs PDL_Value:$val); let assemblyFormat = "$index `of` $parent attr-dict"; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -567,6 +572,7 @@ ($index^)? `of` $parent custom(ref($index), type($val)) attr-dict }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -629,6 +635,7 @@ ($body^)? attr-dict-with-keyword }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -657,6 +664,7 @@ let arguments = (ins OptionalAttr:$type); let results = (outs PDL_Type:$result); let assemblyFormat = "attr-dict (`:` $type^)?"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -685,6 +693,7 @@ let arguments = (ins OptionalAttr:$types); let results = (outs PDL_RangeOf:$result); let assemblyFormat = "attr-dict (`:` $types^)?"; + let hasVerifier = 1; } #endif // MLIR_DIALECT_PDL_IR_PDLOPS diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -75,19 +75,6 @@ PDLInterp_Op { let successors = (successor AnySuccessor:$defaultDest, VariadicSuccessor:$cases); - - let verifier = [{ - // Verify that the number of case destinations matches the number of case - // values. - size_t numDests = cases().size(); - size_t numValues = caseValues().size(); - if (numDests != numValues) { - return emitOpError("expected number of cases to match the number of case " - "values, got ") - << numDests << " but expected " << numValues; - } - return success(); - }]; } //===----------------------------------------------------------------------===// @@ -638,7 +625,7 @@ }]; let parser = [{ return ::parseForEachOp(parser, result); }]; let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1078,6 +1065,7 @@ build($_builder, $_state, attribute, $_builder.getArrayAttr(caseValues), defaultDest, dests); }]>]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1111,6 +1099,7 @@ build($_builder, $_state, operation, $_builder.getI32VectorAttr(counts), defaultDest, dests); }]>]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1148,6 +1137,7 @@ defaultDest, dests); }]>, ]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1181,6 +1171,7 @@ build($_builder, $_state, operation, $_builder.getI32VectorAttr(counts), defaultDest, dests); }]>]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1218,6 +1209,7 @@ let extraClassDeclaration = [{ auto getCaseTypes() { return caseValues().getAsValueRange(); } }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1259,6 +1251,7 @@ let extraClassDeclaration = [{ auto getCaseTypes() { return caseValues().getAsRange(); } }]; + let hasVerifier = 1; } #endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/QuantOps.td --- a/mlir/include/mlir/Dialect/Quant/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/QuantOps.td @@ -103,7 +103,7 @@ StrAttr:$logical_kernel); let results = (outs Variadic:$outputs); let regions = (region SizedRegion<1>:$body); - let verifier = [{ return verifyRegionOp(*this); }]; + let hasVerifier = 1; } def quant_ReturnOp : quant_Op<"return", [Terminator]> { @@ -227,43 +227,7 @@ OptionalAttr:$axisStats, OptionalAttr:$axis); let results = (outs quant_RealValueType); - - let verifier = [{ - auto tensorArg = arg().getType().dyn_cast(); - if (!tensorArg) return emitOpError("arg needs to be tensor type."); - - // Verify layerStats attribute. - { - auto layerStatsType = layerStats().getType(); - if (!layerStatsType.getElementType().isa()) { - return emitOpError( - "layerStats must have a floating point element type"); - } - if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { - return emitOpError("layerStats must have shape [2]"); - } - } - // Verify axisStats (optional) attribute. - if (axisStats()) { - if (!axis()) return emitOpError("axis must be specified for axisStats"); - - auto shape = tensorArg.getShape(); - auto argSliceSize = std::accumulate(std::next(shape.begin(), - *axis()), shape.end(), 1, std::multiplies()); - - auto axisStatsType = axisStats()->getType(); - if (!axisStatsType.getElementType().isa()) { - return emitOpError("axisStats must have a floating point element type"); - } - if (axisStatsType.getRank() != 2 || - axisStatsType.getDimSize(1) != 2 || - axisStatsType.getDimSize(0) != argSliceSize) { - return emitOpError("axisStats must have shape [N,2] " - "where N = the slice size defined by the axis dim"); - } - } - return success(); - }]; + let hasVerifier = 1; } def quant_CoupledRefOp : quant_Op<"coupled_ref", [SameOperandsAndResultType]> { diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -29,12 +29,10 @@ Op { // For every standard op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -56,9 +54,6 @@ let assemblyFormat = [{ `(` $condition `)` attr-dict ($args^ `:` type($args))? }]; - - // Override the default verifier, everything is checked by traits. - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -114,6 +109,7 @@ let hasCanonicalizer = 1; let hasFolder = 0; + let hasVerifier = 1; } def ForOp : SCF_Op<"for", @@ -312,6 +308,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def IfOp : SCF_Op<"if", @@ -404,6 +401,7 @@ }]; let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; } def ParallelOp : SCF_Op<"parallel", @@ -485,6 +483,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> { @@ -533,6 +532,7 @@ let arguments = (ins AnyType:$operand); let regions = (region SizedRegion<1>:$reductionOperator); + let hasVerifier = 1; } def ReduceReturnOp : @@ -551,6 +551,7 @@ let arguments = (ins AnyType:$result); let assemblyFormat = "$result attr-dict `:` type($result)"; + let hasVerifier = 1; } def WhileOp : SCF_Op<"while", @@ -683,6 +684,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator, @@ -706,10 +708,6 @@ let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }]; - - // Override default verifier (defined in SCF_Op), no custom verification - // needed. - let verifier = ?; } #endif // MLIR_DIALECT_SCF_SCFOPS diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -48,8 +48,6 @@ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; - let verifier = [{ return verifySizeOrIndexOp(*this); }]; - let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by // InferTypeOpInterface @@ -57,6 +55,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> { @@ -102,7 +101,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Shape_ConstShapeOp : Shape_Op<"const_shape", @@ -184,8 +183,8 @@ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -323,7 +322,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -377,7 +376,7 @@ }]; let hasFolder = 1; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + let hasVerifier = 1; } def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> { @@ -518,8 +517,8 @@ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -545,7 +544,7 @@ let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)"; let hasFolder = 1; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by // InferTypeOpInterface @@ -595,7 +594,7 @@ let builders = [OpBuilder<(ins "Value":$shape, "ValueRange":$initVals)>]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -614,9 +613,9 @@ let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)"; - let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -724,8 +723,8 @@ [{ build($_builder, $_state, llvm::None); }]> ]; - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let hasVerifier = 1; } // TODO: Add Ops: if_static, if_ranked @@ -859,8 +858,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Shape_AssumingOp : Shape_Op<"assuming", [ @@ -882,7 +880,6 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }]; let extraClassDeclaration = [{ // Inline the region into the region containing the AssumingOp and delete @@ -898,6 +895,7 @@ ]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", @@ -959,7 +957,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> { diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -35,12 +35,10 @@ Op { // For every standard op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -66,10 +64,6 @@ let arguments = (ins I1:$arg, StrAttr:$msg); let assemblyFormat = "$arg `,` $msg attr-dict"; - - // AssertOp is fully verified by its traits. - let verifier = ?; - let hasCanonicalizeMethod = 1; } @@ -107,9 +101,6 @@ $_state.addOperands(destOperands); }]>]; - // BranchOp is fully verified by traits. - let verifier = ?; - let extraClassDeclaration = [{ void setDest(Block *block); @@ -189,7 +180,6 @@ let assemblyFormat = [{ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) }]; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -250,7 +240,6 @@ CallInterfaceCallable getCallableForCallee() { return getCallee(); } }]; - let verifier = ?; let hasCanonicalizeMethod = 1; let assemblyFormat = @@ -311,9 +300,6 @@ falseOperands); }]>]; - // CondBranchOp is fully verified by traits. - let verifier = ?; - let extraClassDeclaration = [{ // These are the indices into the dests list. enum { trueIndex = 0, falseIndex = 1 }; @@ -434,6 +420,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -466,6 +453,7 @@ [{ build($_builder, $_state, llvm::None); }]>]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -516,6 +504,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -562,6 +551,7 @@ [{ build($_builder, $_state, aggregateType, element); }]>]; let hasFolder = 1; + let hasVerifier = 1; let assemblyFormat = "$input attr-dict `:` type($aggregate)"; } @@ -652,6 +642,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } #endif // STANDARD_OPS diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -82,8 +82,7 @@ ); let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; - - let verifier = [{ return verifyAveragePoolOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -116,8 +115,7 @@ ); let builders = [Tosa_ConvOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -149,8 +147,7 @@ ); let builders = [Tosa_ConvOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -183,8 +180,7 @@ ); let builders = [Tosa_ConvOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -212,8 +208,7 @@ ); let builders = [Tosa_FCOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -157,7 +157,7 @@ }]; let parser = [{ return ::parseFuncOp(parser, result); }]; let printer = [{ return ::print(*this, p); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -221,7 +221,7 @@ return "builtin"; } }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; // We need to ensure the block inside the region is properly terminated; // the auto-generated builders do not guarantee that. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2470,12 +2470,10 @@ // For every such op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; code extraBaseClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -524,18 +524,18 @@ p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"map"}); } -static LogicalResult verify(AffineApplyOp op) { +LogicalResult AffineApplyOp::verify() { // Check input and output dimensions match. - auto map = op.map(); + AffineMap affineMap = map(); // Verify that operand count matches affine map dimension and symbol count. - if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols()) - return op.emitOpError( + if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols()) + return emitOpError( "operand count and affine map dimension and symbol count must match"); // Verify that the map only produces one result. - if (map.getNumResults() != 1) - return op.emitOpError("mapping must produce one value"); + if (affineMap.getNumResults() != 1) + return emitOpError("mapping must produce one value"); return success(); } @@ -1306,41 +1306,38 @@ bodyBuilder); } -static LogicalResult verify(AffineForOp op) { +LogicalResult AffineForOp::verify() { // Check that the body defines as single block argument for the induction // variable. - auto *body = op.getBody(); + auto *body = getBody(); if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex()) - return op.emitOpError( - "expected body to have a single index argument for the " - "induction variable"); + return emitOpError("expected body to have a single index argument for the " + "induction variable"); // Verify that the bound operands are valid dimension/symbols. /// Lower bound. - if (op.getLowerBoundMap().getNumInputs() > 0) - if (failed( - verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), - op.getLowerBoundMap().getNumDims()))) + if (getLowerBoundMap().getNumInputs() > 0) + if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(), + getLowerBoundMap().getNumDims()))) return failure(); /// Upper bound. - if (op.getUpperBoundMap().getNumInputs() > 0) - if (failed( - verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), - op.getUpperBoundMap().getNumDims()))) + if (getUpperBoundMap().getNumInputs() > 0) + if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(), + getUpperBoundMap().getNumDims()))) return failure(); - unsigned opNumResults = op.getNumResults(); + unsigned opNumResults = getNumResults(); if (opNumResults == 0) return success(); // If ForOp defines values, check that the number and types of the defined // values match ForOp initial iter operands and backedge basic block // arguments. - if (op.getNumIterOperands() != opNumResults) - return op.emitOpError( + if (getNumIterOperands() != opNumResults) + return emitOpError( "mismatch between the number of loop-carried values and results"); - if (op.getNumRegionIterArgs() != opNumResults) - return op.emitOpError( + if (getNumRegionIterArgs() != opNumResults) + return emitOpError( "mismatch between the number of basic block args and results"); return success(); @@ -2063,23 +2060,22 @@ }; } // namespace -static LogicalResult verify(AffineIfOp op) { +LogicalResult AffineIfOp::verify() { // Verify that we have a condition attribute. + // FIXME: This should be specified in the arguments list in ODS. auto conditionAttr = - op->getAttrOfType(op.getConditionAttrName()); + (*this)->getAttrOfType(getConditionAttrName()); if (!conditionAttr) - return op.emitOpError( - "requires an integer set attribute named 'condition'"); + return emitOpError("requires an integer set attribute named 'condition'"); // Verify that there are enough operands for the condition. IntegerSet condition = conditionAttr.getValue(); - if (op.getNumOperands() != condition.getNumInputs()) - return op.emitOpError( - "operand count and condition integer set dimension and " - "symbol count must match"); + if (getNumOperands() != condition.getNumInputs()) + return emitOpError("operand count and condition integer set dimension and " + "symbol count must match"); // Verify that the operands are valid dimension/symbols. - if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(), + if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(), condition.getNumDims()))) return failure(); @@ -2325,16 +2321,16 @@ return success(); } -LogicalResult verify(AffineLoadOp op) { - auto memrefType = op.getMemRefType(); - if (op.getType() != memrefType.getElementType()) - return op.emitOpError("result type must match element type of memref"); +LogicalResult AffineLoadOp::verify() { + auto memrefType = getMemRefType(); + if (getType() != memrefType.getElementType()) + return emitOpError("result type must match element type of memref"); if (failed(verifyMemoryOpIndexing( - op.getOperation(), - op->getAttrOfType(op.getMapAttrName()), - op.getMapOperands(), memrefType, - /*numIndexOperands=*/op.getNumOperands() - 1))) + getOperation(), + (*this)->getAttrOfType(getMapAttrName()), + getMapOperands(), memrefType, + /*numIndexOperands=*/getNumOperands() - 1))) return failure(); return success(); @@ -2413,18 +2409,18 @@ p << " : " << op.getMemRefType(); } -LogicalResult verify(AffineStoreOp op) { +LogicalResult AffineStoreOp::verify() { // The value to store must have the same type as memref element type. - auto memrefType = op.getMemRefType(); - if (op.getValueToStore().getType() != memrefType.getElementType()) - return op.emitOpError( + auto memrefType = getMemRefType(); + if (getValueToStore().getType() != memrefType.getElementType()) + return emitOpError( "value to store must have the same type as memref element type"); if (failed(verifyMemoryOpIndexing( - op.getOperation(), - op->getAttrOfType(op.getMapAttrName()), - op.getMapOperands(), memrefType, - /*numIndexOperands=*/op.getNumOperands() - 2))) + getOperation(), + (*this)->getAttrOfType(getMapAttrName()), + getMapOperands(), memrefType, + /*numIndexOperands=*/getNumOperands() - 2))) return failure(); return success(); @@ -2672,6 +2668,8 @@ context); } +LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); } + //===----------------------------------------------------------------------===// // AffineMaxOp //===----------------------------------------------------------------------===// @@ -2691,6 +2689,8 @@ context); } +LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); } + //===----------------------------------------------------------------------===// // AffinePrefetchOp //===----------------------------------------------------------------------===// @@ -2764,24 +2764,24 @@ p << " : " << op.getMemRefType(); } -static LogicalResult verify(AffinePrefetchOp op) { - auto mapAttr = op->getAttrOfType(op.getMapAttrName()); +LogicalResult AffinePrefetchOp::verify() { + auto mapAttr = (*this)->getAttrOfType(getMapAttrName()); if (mapAttr) { AffineMap map = mapAttr.getValue(); - if (map.getNumResults() != op.getMemRefType().getRank()) - return op.emitOpError("affine.prefetch affine map num results must equal" - " memref rank"); - if (map.getNumInputs() + 1 != op.getNumOperands()) - return op.emitOpError("too few operands"); + if (map.getNumResults() != getMemRefType().getRank()) + return emitOpError("affine.prefetch affine map num results must equal" + " memref rank"); + if (map.getNumInputs() + 1 != getNumOperands()) + return emitOpError("too few operands"); } else { - if (op.getNumOperands() != 1) - return op.emitOpError("too few operands"); + if (getNumOperands() != 1) + return emitOpError("too few operands"); } - Region *scope = getAffineScope(op); - for (auto idx : op.getMapOperands()) { + Region *scope = getAffineScope(*this); + for (auto idx : getMapOperands()) { if (!isValidAffineIndexOperand(idx, scope)) - return op.emitOpError("index must be a dimension or symbol identifier"); + return emitOpError("index must be a dimension or symbol identifier"); } return success(); } @@ -3018,53 +3018,52 @@ stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps)); } -static LogicalResult verify(AffineParallelOp op) { - auto numDims = op.getNumDims(); - if (op.lowerBoundsGroups().getNumElements() != numDims || - op.upperBoundsGroups().getNumElements() != numDims || - op.steps().size() != numDims || - op.getBody()->getNumArguments() != numDims) { - return op.emitOpError() - << "the number of region arguments (" - << op.getBody()->getNumArguments() - << ") and the number of map groups for lower (" - << op.lowerBoundsGroups().getNumElements() << ") and upper bound (" - << op.upperBoundsGroups().getNumElements() - << "), and the number of steps (" << op.steps().size() - << ") must all match"; +LogicalResult AffineParallelOp::verify() { + auto numDims = getNumDims(); + if (lowerBoundsGroups().getNumElements() != numDims || + upperBoundsGroups().getNumElements() != numDims || + steps().size() != numDims || getBody()->getNumArguments() != numDims) { + return emitOpError() << "the number of region arguments (" + << getBody()->getNumArguments() + << ") and the number of map groups for lower (" + << lowerBoundsGroups().getNumElements() + << ") and upper bound (" + << upperBoundsGroups().getNumElements() + << "), and the number of steps (" << steps().size() + << ") must all match"; } unsigned expectedNumLBResults = 0; - for (APInt v : op.lowerBoundsGroups()) + for (APInt v : lowerBoundsGroups()) expectedNumLBResults += v.getZExtValue(); - if (expectedNumLBResults != op.lowerBoundsMap().getNumResults()) - return op.emitOpError() << "expected lower bounds map to have " - << expectedNumLBResults << " results"; + if (expectedNumLBResults != lowerBoundsMap().getNumResults()) + return emitOpError() << "expected lower bounds map to have " + << expectedNumLBResults << " results"; unsigned expectedNumUBResults = 0; - for (APInt v : op.upperBoundsGroups()) + for (APInt v : upperBoundsGroups()) expectedNumUBResults += v.getZExtValue(); - if (expectedNumUBResults != op.upperBoundsMap().getNumResults()) - return op.emitOpError() << "expected upper bounds map to have " - << expectedNumUBResults << " results"; + if (expectedNumUBResults != upperBoundsMap().getNumResults()) + return emitOpError() << "expected upper bounds map to have " + << expectedNumUBResults << " results"; - if (op.reductions().size() != op.getNumResults()) - return op.emitOpError("a reduction must be specified for each output"); + if (reductions().size() != getNumResults()) + return emitOpError("a reduction must be specified for each output"); // Verify reduction ops are all valid - for (Attribute attr : op.reductions()) { + for (Attribute attr : reductions()) { auto intAttr = attr.dyn_cast(); if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt())) - return op.emitOpError("invalid reduction attribute"); + return emitOpError("invalid reduction attribute"); } // Verify that the bound operands are valid dimension/symbols. /// Lower bounds. - if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(), - op.lowerBoundsMap().getNumDims()))) + if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(), + lowerBoundsMap().getNumDims()))) return failure(); /// Upper bounds. - if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(), - op.upperBoundsMap().getNumDims()))) + if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(), + upperBoundsMap().getNumDims()))) return failure(); return success(); } @@ -3412,20 +3411,19 @@ // AffineYieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(AffineYieldOp op) { - auto *parentOp = op->getParentOp(); +LogicalResult AffineYieldOp::verify() { + auto *parentOp = (*this)->getParentOp(); auto results = parentOp->getResults(); - auto operands = op.getOperands(); + auto operands = getOperands(); if (!isa(parentOp)) - return op.emitOpError() << "only terminates affine.if/for/parallel regions"; - if (parentOp->getNumResults() != op.getNumOperands()) - return op.emitOpError() << "parent of yield must have same number of " - "results as the yield operands"; + return emitOpError() << "only terminates affine.if/for/parallel regions"; + if (parentOp->getNumResults() != getNumOperands()) + return emitOpError() << "parent of yield must have same number of " + "results as the yield operands"; for (auto it : llvm::zip(results, operands)) { if (std::get<0>(it).getType() != std::get<1>(it).getType()) - return op.emitOpError() - << "types mismatch between yield op and its parent"; + return emitOpError() << "types mismatch between yield op and its parent"; } return success(); @@ -3516,17 +3514,16 @@ return success(); } -static LogicalResult verify(AffineVectorLoadOp op) { - MemRefType memrefType = op.getMemRefType(); +LogicalResult AffineVectorLoadOp::verify() { + MemRefType memrefType = getMemRefType(); if (failed(verifyMemoryOpIndexing( - op.getOperation(), - op->getAttrOfType(op.getMapAttrName()), - op.getMapOperands(), memrefType, - /*numIndexOperands=*/op.getNumOperands() - 1))) + getOperation(), + (*this)->getAttrOfType(getMapAttrName()), + getMapOperands(), memrefType, + /*numIndexOperands=*/getNumOperands() - 1))) return failure(); - if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, - op.getVectorType()))) + if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType()))) return failure(); return success(); @@ -3599,17 +3596,15 @@ p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType(); } -static LogicalResult verify(AffineVectorStoreOp op) { - MemRefType memrefType = op.getMemRefType(); +LogicalResult AffineVectorStoreOp::verify() { + MemRefType memrefType = getMemRefType(); if (failed(verifyMemoryOpIndexing( - op.getOperation(), - op->getAttrOfType(op.getMapAttrName()), - op.getMapOperands(), memrefType, - /*numIndexOperands=*/op.getNumOperands() - 2))) + *this, (*this)->getAttrOfType(getMapAttrName()), + getMapOperands(), memrefType, + /*numIndexOperands=*/getNumOperands() - 2))) return failure(); - if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, - op.getVectorType()))) + if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType()))) return failure(); return success(); diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -107,19 +107,19 @@ /// TODO: disallow arith.constant to return anything other than signless integer /// or float like. -static LogicalResult verify(arith::ConstantOp op) { - auto type = op.getType(); +LogicalResult arith::ConstantOp::verify() { + auto type = getType(); // The value's type must match the return type. - if (op.getValue().getType() != type) { - return op.emitOpError() << "value type " << op.getValue().getType() - << " must match return type: " << type; + if (getValue().getType() != type) { + return emitOpError() << "value type " << getValue().getType() + << " must match return type: " << type; } // Integer values must be signless. if (type.isa() && !type.cast().isSignless()) - return op.emitOpError("integer return type must be signless"); + return emitOpError("integer return type must be signless"); // Any float or elements attribute are acceptable. - if (!op.getValue().isa()) { - return op.emitOpError( + if (!getValue().isa()) { + return emitOpError( "value must be an integer, float, or elements attribute"); } return success(); @@ -886,6 +886,10 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::ExtUIOp::verify() { + return verifyExtOp(*this); +} + //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// @@ -912,6 +916,10 @@ patterns.insert(context); } +LogicalResult arith::ExtSIOp::verify() { + return verifyExtOp(*this); +} + //===----------------------------------------------------------------------===// // ExtFOp //===----------------------------------------------------------------------===// @@ -920,6 +928,8 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::ExtFOp::verify() { return verifyExtOp(*this); } + //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// @@ -954,6 +964,10 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::TruncIOp::verify() { + return verifyTruncateOp(*this); +} + //===----------------------------------------------------------------------===// // TruncFOp //===----------------------------------------------------------------------===// @@ -983,6 +997,10 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::TruncFOp::verify() { + return verifyTruncateOp(*this); +} + //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -33,17 +33,17 @@ // YieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(YieldOp op) { +LogicalResult YieldOp::verify() { // Get the underlying value types from async values returned from the // parent `async.execute` operation. - auto executeOp = op->getParentOfType(); + auto executeOp = (*this)->getParentOfType(); auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { return result.getType().cast().getValueType(); }); - if (op.getOperandTypes() != types) - return op.emitOpError("operand types do not match the types returned from " - "the parent ExecuteOp"); + if (getOperandTypes() != types) + return emitOpError("operand types do not match the types returned from " + "the parent ExecuteOp"); return success(); } @@ -228,16 +228,16 @@ return success(); } -static LogicalResult verify(ExecuteOp op) { +LogicalResult ExecuteOp::verify() { // Unwrap async.execute value operands types. - auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { + auto unwrappedTypes = llvm::map_range(operands(), [](Value operand) { return operand.getType().cast().getValueType(); }); // Verify that unwrapped argument types matches the body region arguments. - if (op.body().getArgumentTypes() != unwrappedTypes) - return op.emitOpError("async body region argument types do not match the " - "execute operation arguments types"); + if (body().getArgumentTypes() != unwrappedTypes) + return emitOpError("async body region argument types do not match the " + "execute operation arguments types"); return success(); } @@ -303,19 +303,19 @@ p << operandType; } -static LogicalResult verify(AwaitOp op) { - Type argType = op.operand().getType(); +LogicalResult AwaitOp::verify() { + Type argType = operand().getType(); // Awaiting on a token does not have any results. - if (argType.isa() && !op.getResultTypes().empty()) - return op.emitOpError("awaiting on a token must have empty result"); + if (argType.isa() && !getResultTypes().empty()) + return emitOpError("awaiting on a token must have empty result"); // Awaiting on a value unwraps the async value type. if (auto value = argType.dyn_cast()) { - if (*op.getResultType() != value.getValueType()) - return op.emitOpError() - << "result type " << *op.getResultType() - << " does not match async value type " << value.getValueType(); + if (*getResultType() != value.getValueType()) + return emitOpError() << "result type " << *getResultType() + << " does not match async value type " + << value.getValueType(); } return success(); diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -38,18 +38,18 @@ return false; } -static LogicalResult verify(ConstantOp op) { - ArrayAttr arrayAttr = op.getValue(); +LogicalResult ConstantOp::verify() { + ArrayAttr arrayAttr = getValue(); if (arrayAttr.size() != 2) { - return op.emitOpError( + return emitOpError( "requires 'value' to be a complex constant, represented as array of " "two values"); } - auto complexEltTy = op.getType().getElementType(); + auto complexEltTy = getType().getElementType(); if (complexEltTy != arrayAttr[0].getType() || complexEltTy != arrayAttr[1].getType()) { - return op.emitOpError() + return emitOpError() << "requires attribute's element types (" << arrayAttr[0].getType() << ", " << arrayAttr[1].getType() << ") to match the element type of the op's return type (" diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -48,16 +48,16 @@ // ApplyOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ApplyOp op) { - StringRef applicableOperator = op.applicableOperator(); +LogicalResult ApplyOp::verify() { + StringRef applicableOperatorStr = applicableOperator(); // Applicable operator must not be empty. - if (applicableOperator.empty()) - return op.emitOpError("applicable operator must not be empty"); + if (applicableOperatorStr.empty()) + return emitOpError("applicable operator must not be empty"); // Only `*` and `&` are supported. - if (applicableOperator != "&" && applicableOperator != "*") - return op.emitOpError("applicable operator is illegal"); + if (applicableOperatorStr != "&" && applicableOperatorStr != "*") + return emitOpError("applicable operator is illegal"); return success(); } @@ -66,32 +66,32 @@ // CallOp //===----------------------------------------------------------------------===// -static LogicalResult verify(emitc::CallOp op) { +LogicalResult emitc::CallOp::verify() { // Callee must not be empty. - if (op.callee().empty()) - return op.emitOpError("callee must not be empty"); + if (callee().empty()) + return emitOpError("callee must not be empty"); - if (Optional argsAttr = op.args()) { + if (Optional argsAttr = args()) { for (Attribute arg : argsAttr.getValue()) { if (arg.getType().isa()) { int64_t index = arg.cast().getInt(); // Args with elements of type index must be in range // [0..operands.size). - if ((index < 0) || (index >= static_cast(op.getNumOperands()))) - return op.emitOpError("index argument is out of range"); + if ((index < 0) || (index >= static_cast(getNumOperands()))) + return emitOpError("index argument is out of range"); // Args with elements of type ArrayAttr must have a type. } else if (arg.isa() && arg.getType().isa()) { - return op.emitOpError("array argument has no type"); + return emitOpError("array argument has no type"); } } } - if (Optional templateArgsAttr = op.template_args()) { + if (Optional templateArgsAttr = template_args()) { for (Attribute tArg : templateArgsAttr.getValue()) { if (!tArg.isa() && !tArg.isa() && !tArg.isa() && !tArg.isa()) - return op.emitOpError("template argument has invalid type"); + return emitOpError("template argument has invalid type"); } } @@ -103,12 +103,12 @@ //===----------------------------------------------------------------------===// /// The constant op requires that the attribute's type matches the return type. -static LogicalResult verify(emitc::ConstantOp &op) { - Attribute value = op.value(); - Type type = op.getType(); +LogicalResult emitc::ConstantOp::verify() { + Attribute value = valueAttr(); + Type type = getType(); if (!value.getType().isa() && type != value.getType()) - return op.emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; + return emitOpError() << "requires attribute's type (" << value.getType() + << ") to match op's return type (" << type << ")"; return success(); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -270,46 +270,37 @@ return walkResult.wasInterrupted() ? failure() : success(); } -template -static LogicalResult verifyIndexOp(T op) { - auto dimension = op.dimension(); - if (dimension != "x" && dimension != "y" && dimension != "z") - return op.emitError("dimension \"") << dimension << "\" is invalid"; - return success(); -} - -static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) { - if (allReduce.body().empty() != allReduce.op().hasValue()) - return allReduce.emitError( - "expected either an op attribute or a non-empty body"); - if (!allReduce.body().empty()) { - if (allReduce.body().getNumArguments() != 2) - return allReduce.emitError("expected two region arguments"); - for (auto argument : allReduce.body().getArguments()) { - if (argument.getType() != allReduce.getType()) - return allReduce.emitError("incorrect region argument type"); +LogicalResult gpu::AllReduceOp::verify() { + if (body().empty() != op().hasValue()) + return emitError("expected either an op attribute or a non-empty body"); + if (!body().empty()) { + if (body().getNumArguments() != 2) + return emitError("expected two region arguments"); + for (auto argument : body().getArguments()) { + if (argument.getType() != getType()) + return emitError("incorrect region argument type"); } unsigned yieldCount = 0; - for (Block &block : allReduce.body()) { + for (Block &block : body()) { if (auto yield = dyn_cast(block.getTerminator())) { if (yield.getNumOperands() != 1) - return allReduce.emitError("expected one gpu.yield operand"); - if (yield.getOperand(0).getType() != allReduce.getType()) - return allReduce.emitError("incorrect gpu.yield type"); + return emitError("expected one gpu.yield operand"); + if (yield.getOperand(0).getType() != getType()) + return emitError("incorrect gpu.yield type"); ++yieldCount; } } if (yieldCount == 0) - return allReduce.emitError("expected gpu.yield op in region"); + return emitError("expected gpu.yield op in region"); } else { - gpu::AllReduceOperation opName = *allReduce.op(); + gpu::AllReduceOperation opName = *op(); if ((opName == gpu::AllReduceOperation::AND || opName == gpu::AllReduceOperation::OR || opName == gpu::AllReduceOperation::XOR) && - !allReduce.getType().isa()) { - return allReduce.emitError() - << '`' << gpu::stringifyAllReduceOperation(opName) << '`' - << " accumulator is only compatible with Integer type"; + !getType().isa()) { + return emitError() + << '`' << gpu::stringifyAllReduceOperation(opName) + << "` accumulator is only compatible with Integer type"; } } return success(); @@ -411,20 +402,20 @@ return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; } -static LogicalResult verify(LaunchOp op) { +LogicalResult LaunchOp::verify() { // Kernel launch takes kNumConfigOperands leading operands for grid/block // sizes and transforms them into kNumConfigRegionAttributes region arguments // for block/thread identifiers and grid/block sizes. - if (!op.body().empty()) { - if (op.body().getNumArguments() != - LaunchOp::kNumConfigOperands + op.getNumOperands() - - (op.dynamicSharedMemorySize() ? 1 : 0)) - return op.emitOpError("unexpected number of region arguments"); + if (!body().empty()) { + if (body().getNumArguments() != LaunchOp::kNumConfigOperands + + getNumOperands() - + (dynamicSharedMemorySize() ? 1 : 0)) + return emitOpError("unexpected number of region arguments"); } // Block terminators without successors are expected to exit the kernel region // and must be `gpu.terminator`. - for (Block &block : op.body()) { + for (Block &block : body()) { if (block.empty()) continue; if (block.back().getNumSuccessors() != 0) @@ -434,7 +425,7 @@ .emitError() .append("expected '", gpu::TerminatorOp::getOperationName(), "' or a terminator with successors") - .attachNote(op.getLoc()) + .attachNote(getLoc()) .append("in '", LaunchOp::getOperationName(), "' body region"); } } @@ -650,21 +641,21 @@ return KernelDim3{operands[3], operands[4], operands[5]}; } -static LogicalResult verify(LaunchFuncOp op) { - auto module = op->getParentOfType(); +LogicalResult LaunchFuncOp::verify() { + auto module = (*this)->getParentOfType(); if (!module) - return op.emitOpError("expected to belong to a module"); + return emitOpError("expected to belong to a module"); if (!module->getAttrOfType( GPUDialect::getContainerModuleAttrName())) - return op.emitOpError( - "expected the closest surrounding module to have the '" + - GPUDialect::getContainerModuleAttrName() + "' attribute"); + return emitOpError("expected the closest surrounding module to have the '" + + GPUDialect::getContainerModuleAttrName() + + "' attribute"); - auto kernelAttr = op->getAttrOfType(op.getKernelAttrName()); + auto kernelAttr = (*this)->getAttrOfType(getKernelAttrName()); if (!kernelAttr) - return op.emitOpError("symbol reference attribute '" + - op.getKernelAttrName() + "' must be specified"); + return emitOpError("symbol reference attribute '" + getKernelAttrName() + + "' must be specified"); return success(); } @@ -945,25 +936,25 @@ // ReturnOp //===----------------------------------------------------------------------===// -static LogicalResult verify(gpu::ReturnOp returnOp) { - GPUFuncOp function = returnOp->getParentOfType(); +LogicalResult gpu::ReturnOp::verify() { + GPUFuncOp function = (*this)->getParentOfType(); FunctionType funType = function.getType(); - if (funType.getNumResults() != returnOp.operands().size()) - return returnOp.emitOpError() + if (funType.getNumResults() != operands().size()) + return emitOpError() .append("expected ", funType.getNumResults(), " result operands") .attachNote(function.getLoc()) .append("return type declared here"); for (const auto &pair : llvm::enumerate( - llvm::zip(function.getType().getResults(), returnOp.operands()))) { + llvm::zip(function.getType().getResults(), operands()))) { Type type; Value operand; std::tie(type, operand) = pair.value(); if (type != operand.getType()) - return returnOp.emitOpError() << "unexpected type `" << operand.getType() - << "' for operand #" << pair.index(); + return emitOpError() << "unexpected type `" << operand.getType() + << "' for operand #" << pair.index(); } return success(); } @@ -1014,15 +1005,15 @@ // GPUMemcpyOp //===----------------------------------------------------------------------===// -static LogicalResult verify(MemcpyOp op) { - auto srcType = op.src().getType(); - auto dstType = op.dst().getType(); +LogicalResult MemcpyOp::verify() { + auto srcType = src().getType(); + auto dstType = dst().getType(); if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType)) - return op.emitOpError("arguments have incompatible element type"); + return emitOpError("arguments have incompatible element type"); if (failed(verifyCompatibleShape(srcType, dstType))) - return op.emitOpError("arguments have incompatible shape"); + return emitOpError("arguments have incompatible shape"); return success(); } @@ -1056,26 +1047,26 @@ // GPU_SubgroupMmaLoadMatrixOp //===----------------------------------------------------------------------===// -static LogicalResult verify(SubgroupMmaLoadMatrixOp op) { - auto srcType = op.srcMemref().getType(); - auto resType = op.res().getType(); +LogicalResult SubgroupMmaLoadMatrixOp::verify() { + auto srcType = srcMemref().getType(); + auto resType = res().getType(); auto resMatrixType = resType.cast(); auto operand = resMatrixType.getOperand(); auto srcMemrefType = srcType.cast(); auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt(); if (!srcMemrefType.getLayout().isIdentity()) - return op.emitError("expected identity layout map for source memref"); + return emitError("expected identity layout map for source memref"); if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace && srcMemSpace != kGlobalMemorySpace) - return op.emitError( + return emitError( "source memorySpace kGenericMemorySpace, kSharedMemorySpace or " "kGlobalMemorySpace only allowed"); if (!operand.equals("AOp") && !operand.equals("BOp") && !operand.equals("COp")) - return op.emitError("only AOp, BOp and COp can be loaded"); + return emitError("only AOp, BOp and COp can be loaded"); return success(); } @@ -1084,23 +1075,22 @@ // GPU_SubgroupMmaStoreMatrixOp //===----------------------------------------------------------------------===// -static LogicalResult verify(SubgroupMmaStoreMatrixOp op) { - auto srcType = op.src().getType(); - auto dstType = op.dstMemref().getType(); +LogicalResult SubgroupMmaStoreMatrixOp::verify() { + auto srcType = src().getType(); + auto dstType = dstMemref().getType(); auto srcMatrixType = srcType.cast(); auto dstMemrefType = dstType.cast(); auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt(); if (!dstMemrefType.getLayout().isIdentity()) - return op.emitError("expected identity layout map for destination memref"); + return emitError("expected identity layout map for destination memref"); if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace && dstMemSpace != kGlobalMemorySpace) - return op.emitError( - "destination memorySpace of kGenericMemorySpace, " - "kGlobalMemorySpace or kSharedMemorySpace only allowed"); + return emitError("destination memorySpace of kGenericMemorySpace, " + "kGlobalMemorySpace or kSharedMemorySpace only allowed"); if (!srcMatrixType.getOperand().equals("COp")) - return op.emitError( + return emitError( "expected the operand matrix being stored to have 'COp' operand type"); return success(); @@ -1110,21 +1100,17 @@ // GPU_SubgroupMmaComputeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(SubgroupMmaComputeOp op) { +LogicalResult SubgroupMmaComputeOp::verify() { enum OperandMap { A, B, C }; SmallVector opTypes; - - auto populateOpInfo = [&opTypes, &op]() { - opTypes.push_back(op.opA().getType().cast()); - opTypes.push_back(op.opB().getType().cast()); - opTypes.push_back(op.opC().getType().cast()); - }; - populateOpInfo(); + opTypes.push_back(opA().getType().cast()); + opTypes.push_back(opB().getType().cast()); + opTypes.push_back(opC().getType().cast()); if (!opTypes[A].getOperand().equals("AOp") || !opTypes[B].getOperand().equals("BOp") || !opTypes[C].getOperand().equals("COp")) - return op.emitError("operands must be in the order AOp, BOp, COp"); + return emitError("operands must be in the order AOp, BOp, COp"); ArrayRef aShape, bShape, cShape; aShape = opTypes[A].getShape(); @@ -1133,7 +1119,7 @@ if (aShape[1] != bShape[0] || aShape[0] != cShape[0] || bShape[1] != cShape[1]) - return op.emitError("operand shapes do not satisfy matmul constraints"); + return emitError("operand shapes do not satisfy matmul constraints"); return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -400,11 +400,11 @@ /// FillOp region is elided when printing. void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} -static LogicalResult verify(FillOp op) { - OpOperand *output = op.getOutputOperand(0); - Type fillType = op.value().getType(); +LogicalResult FillOp::verify() { + OpOperand *output = getOutputOperand(0); + Type fillType = value().getType(); if (getElementTypeOrSelf(output->get()) != fillType) - return op.emitOpError("expects fill type to match view elemental type"); + return emitOpError("expects fill type to match view elemental type"); return success(); } @@ -635,7 +635,7 @@ return success(); } -static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); } +LogicalResult GenericOp::verify() { return verifyGenericOp(*this); } namespace { // Deduplicate redundant args of a linalg generic op. @@ -811,25 +811,24 @@ result.addAttributes(attrs); } -static LogicalResult verify(InitTensorOp op) { - RankedTensorType resultType = op.getType(); +LogicalResult InitTensorOp::verify() { + RankedTensorType resultType = getType(); SmallVector staticSizes = llvm::to_vector<4>(llvm::map_range( - op.static_sizes().cast(), + static_sizes().cast(), [](Attribute a) -> int64_t { return a.cast().getInt(); })); - if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(), - op.static_sizes(), op.sizes(), - ShapedType::isDynamic))) + if (failed(verifyListOfOperandsOrIntegers( + *this, "sizes", resultType.getRank(), static_sizes(), sizes(), + ShapedType::isDynamic))) return failure(); - if (op.static_sizes().size() != static_cast(resultType.getRank())) - return op->emitError("expected ") - << resultType.getRank() << " sizes values"; + if (static_sizes().size() != static_cast(resultType.getRank())) + return emitError("expected ") << resultType.getRank() << " sizes values"; Type expectedType = InitTensorOp::inferResultType( staticSizes, resultType.getElementType(), resultType.getEncoding()); if (resultType != expectedType) { - return op.emitError("specified type ") + return emitError("specified type ") << resultType << " does not match the inferred type " << expectedType; } @@ -1030,13 +1029,13 @@ return success(); } -static LogicalResult verify(linalg::YieldOp op) { - auto *parentOp = op->getParentOp(); +LogicalResult linalg::YieldOp::verify() { + auto *parentOp = (*this)->getParentOp(); if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) - return op.emitOpError("expected single non-empty parent region"); + return emitOpError("expected single non-empty parent region"); if (auto linalgOp = dyn_cast(parentOp)) - return verifyYield(op, cast(parentOp)); + return verifyYield(*this, cast(parentOp)); if (auto tiledLoopOp = dyn_cast(parentOp)) { // Check if output args with tensor types match results types. @@ -1044,25 +1043,25 @@ llvm::copy_if( tiledLoopOp.outputs(), std::back_inserter(tensorOuts), [&](Value out) { return out.getType().isa(); }); - if (tensorOuts.size() != op.values().size()) - return op.emitOpError("expected number of tensor output args = ") - << tensorOuts.size() << " to match the number of yield operands = " - << op.values().size(); + if (tensorOuts.size() != values().size()) + return emitOpError("expected number of tensor output args = ") + << tensorOuts.size() + << " to match the number of yield operands = " << values().size(); TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts)); for (auto &item : - llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) { + llvm::enumerate(llvm::zip(tensorTypes, getOperandTypes()))) { Type outType, resultType; unsigned index = item.index(); std::tie(outType, resultType) = item.value(); if (outType != resultType) - return op.emitOpError("expected yield operand ") + return emitOpError("expected yield operand ") << index << " with type = " << resultType << " to match output arg type = " << outType; } return success(); } - return op.emitOpError("expected parent op with LinalgOp interface"); + return emitOpError("expected parent op with LinalgOp interface"); } //===----------------------------------------------------------------------===// @@ -1316,37 +1315,37 @@ return !region().isAncestor(value.getParentRegion()); } -static LogicalResult verify(TiledLoopOp op) { +LogicalResult TiledLoopOp::verify() { // Check if iterator types are provided for every loop dimension. - if (op.iterator_types().size() != op.getNumLoops()) - return op.emitOpError("expected iterator types array attribute size = ") - << op.iterator_types().size() - << " to match the number of loops = " << op.getNumLoops(); + if (iterator_types().size() != getNumLoops()) + return emitOpError("expected iterator types array attribute size = ") + << iterator_types().size() + << " to match the number of loops = " << getNumLoops(); // Check if types of input arguments match region args types. for (auto &item : - llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) { + llvm::enumerate(llvm::zip(inputs(), getRegionInputArgs()))) { Value input, inputRegionArg; unsigned index = item.index(); std::tie(input, inputRegionArg) = item.value(); if (input.getType() != inputRegionArg.getType()) - return op.emitOpError("expected input arg ") + return emitOpError("expected input arg ") << index << " with type = " << input.getType() - << " to match region arg " << index + op.getNumLoops() + << " to match region arg " << index + getNumLoops() << " type = " << inputRegionArg.getType(); } // Check if types of input arguments match region args types. for (auto &item : - llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) { + llvm::enumerate(llvm::zip(outputs(), getRegionOutputArgs()))) { Value output, outputRegionArg; unsigned index = item.index(); std::tie(output, outputRegionArg) = item.value(); if (output.getType() != outputRegionArg.getType()) - return op.emitOpError("expected output arg ") + return emitOpError("expected output arg ") << index << " with type = " << output.getType() << " to match region arg " - << index + op.getNumLoops() + op.inputs().size() + << index + getNumLoops() + inputs().size() << " type = " << outputRegionArg.getType(); } return success(); @@ -1667,13 +1666,13 @@ // IndexOp //===----------------------------------------------------------------------===// -static LogicalResult verify(IndexOp op) { - auto linalgOp = dyn_cast(op->getParentOp()); +LogicalResult IndexOp::verify() { + auto linalgOp = dyn_cast((*this)->getParentOp()); if (!linalgOp) - return op.emitOpError("expected parent op with LinalgOp interface"); - if (linalgOp.getNumLoops() <= op.dim()) - return op.emitOpError("expected dim (") - << op.dim() << ") to be lower than the number of loops (" + return emitOpError("expected parent op with LinalgOp interface"); + if (linalgOp.getNumLoops() <= dim()) + return emitOpError("expected dim (") + << dim() << ") to be lower than the number of loops (" << linalgOp.getNumLoops() << ") of the enclosing LinalgOp"; return success(); } diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -668,28 +668,22 @@ LoopOp::getOperandSegmentSizeAttr()}); } -static LogicalResult verifyLoopOp(acc::LoopOp loopOp) { +LogicalResult acc::LoopOp::verify() { // auto, independent and seq attribute are mutually exclusive. - if ((loopOp.auto_() && (loopOp.independent() || loopOp.seq())) || - (loopOp.independent() && loopOp.seq())) { - loopOp.emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " + + if ((auto_() && (independent() || seq())) || (independent() && seq())) { + return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " + acc::LoopOp::getIndependentAttrName() + ", " + acc::LoopOp::getSeqAttrName() + " can be present at the same time"); - return failure(); } // Gang, worker and vector are incompatible with seq. - if (loopOp.seq() && loopOp.exec_mapping() != OpenACCExecMapping::NONE) { - loopOp.emitError("gang, worker or vector cannot appear with the seq attr"); - return failure(); - } + if (seq() && exec_mapping() != OpenACCExecMapping::NONE) + return emitError("gang, worker or vector cannot appear with the seq attr"); // Check non-empty body(). - if (loopOp.region().empty()) { - loopOp.emitError("expected non-empty body."); - return failure(); - } + if (region().empty()) + return emitError("expected non-empty body."); return success(); } @@ -698,13 +692,13 @@ // DataOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::DataOp dataOp) { +LogicalResult acc::DataOp::verify() { // 2.6.5. Data Construct restriction // At least one copy, copyin, copyout, create, no_create, present, deviceptr, // attach, or default clause must appear on a data construct. - if (dataOp.getOperands().empty() && !dataOp.defaultAttr()) - return dataOp.emitError("at least one operand or the default attribute " - "must appear on the data operation"); + if (getOperands().empty() && !defaultAttr()) + return emitError("at least one operand or the default attribute " + "must appear on the data operation"); return success(); } @@ -726,28 +720,28 @@ // ExitDataOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::ExitDataOp op) { +LogicalResult acc::ExitDataOp::verify() { // 2.6.6. Data Exit Directive restriction // At least one copyout, delete, or detach clause must appear on an exit data // directive. - if (op.copyoutOperands().empty() && op.deleteOperands().empty() && - op.detachOperands().empty()) - return op.emitError( + if (copyoutOperands().empty() && deleteOperands().empty() && + detachOperands().empty()) + return emitError( "at least one operand in copyout, delete or detach must appear on the " "exit data operation"); // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. - if (op.asyncOperand() && op.async()) - return op.emitError("async attribute cannot appear with asyncOperand"); + if (asyncOperand() && async()) + return emitError("async attribute cannot appear with asyncOperand"); // The wait attribute represent the wait clause without values. Therefore the // attribute and operands cannot appear at the same time. - if (!op.waitOperands().empty() && op.wait()) - return op.emitError("wait attribute cannot appear with waitOperands"); + if (!waitOperands().empty() && wait()) + return emitError("wait attribute cannot appear with waitOperands"); - if (op.waitDevnum() && op.waitOperands().empty()) - return op.emitError("wait_devnum cannot appear without waitOperands"); + if (waitDevnum() && waitOperands().empty()) + return emitError("wait_devnum cannot appear without waitOperands"); return success(); } @@ -773,28 +767,28 @@ // EnterDataOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::EnterDataOp op) { +LogicalResult acc::EnterDataOp::verify() { // 2.6.6. Data Enter Directive restriction // At least one copyin, create, or attach clause must appear on an enter data // directive. - if (op.copyinOperands().empty() && op.createOperands().empty() && - op.createZeroOperands().empty() && op.attachOperands().empty()) - return op.emitError( + if (copyinOperands().empty() && createOperands().empty() && + createZeroOperands().empty() && attachOperands().empty()) + return emitError( "at least one operand in copyin, create, " "create_zero or attach must appear on the enter data operation"); // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. - if (op.asyncOperand() && op.async()) - return op.emitError("async attribute cannot appear with asyncOperand"); + if (asyncOperand() && async()) + return emitError("async attribute cannot appear with asyncOperand"); // The wait attribute represent the wait clause without values. Therefore the // attribute and operands cannot appear at the same time. - if (!op.waitOperands().empty() && op.wait()) - return op.emitError("wait attribute cannot appear with waitOperands"); + if (!waitOperands().empty() && wait()) + return emitError("wait attribute cannot appear with waitOperands"); - if (op.waitDevnum() && op.waitOperands().empty()) - return op.emitError("wait_devnum cannot appear without waitOperands"); + if (waitDevnum() && waitOperands().empty()) + return emitError("wait_devnum cannot appear without waitOperands"); return success(); } @@ -820,12 +814,11 @@ // InitOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::InitOp initOp) { - Operation *currOp = initOp; - while ((currOp = currOp->getParentOp())) { +LogicalResult acc::InitOp::verify() { + Operation *currOp = *this; + while ((currOp = currOp->getParentOp())) if (isComputeOperation(currOp)) - return initOp.emitOpError("cannot be nested in a compute operation"); - } + return emitOpError("cannot be nested in a compute operation"); return success(); } @@ -833,12 +826,11 @@ // ShutdownOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::ShutdownOp op) { - Operation *currOp = op; - while ((currOp = currOp->getParentOp())) { +LogicalResult acc::ShutdownOp::verify() { + Operation *currOp = *this; + while ((currOp = currOp->getParentOp())) if (isComputeOperation(currOp)) - return op.emitOpError("cannot be nested in a compute operation"); - } + return emitOpError("cannot be nested in a compute operation"); return success(); } @@ -846,25 +838,24 @@ // UpdateOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::UpdateOp updateOp) { +LogicalResult acc::UpdateOp::verify() { // At least one of host or device should have a value. - if (updateOp.hostOperands().empty() && updateOp.deviceOperands().empty()) - return updateOp.emitError("at least one value must be present in" - " hostOperands or deviceOperands"); + if (hostOperands().empty() && deviceOperands().empty()) + return emitError( + "at least one value must be present in hostOperands or deviceOperands"); // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. - if (updateOp.asyncOperand() && updateOp.async()) - return updateOp.emitError("async attribute cannot appear with " - " asyncOperand"); + if (asyncOperand() && async()) + return emitError("async attribute cannot appear with asyncOperand"); // The wait attribute represent the wait clause without values. Therefore the // attribute and operands cannot appear at the same time. - if (!updateOp.waitOperands().empty() && updateOp.wait()) - return updateOp.emitError("wait attribute cannot appear with waitOperands"); + if (!waitOperands().empty() && wait()) + return emitError("wait attribute cannot appear with waitOperands"); - if (updateOp.waitDevnum() && updateOp.waitOperands().empty()) - return updateOp.emitError("wait_devnum cannot appear without waitOperands"); + if (waitDevnum() && waitOperands().empty()) + return emitError("wait_devnum cannot appear without waitOperands"); return success(); } @@ -890,14 +881,14 @@ // WaitOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::WaitOp waitOp) { +LogicalResult acc::WaitOp::verify() { // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. - if (waitOp.asyncOperand() && waitOp.async()) - return waitOp.emitError("async attribute cannot appear with asyncOperand"); + if (asyncOperand() && async()) + return emitError("async attribute cannot appear with asyncOperand"); - if (waitOp.waitDevnum() && waitOp.waitOperands().empty()) - return waitOp.emitError("wait_devnum cannot appear without waitOperands"); + if (waitDevnum() && waitOperands().empty()) + return emitError("wait_devnum cannot appear without waitOperands"); return success(); } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -165,9 +165,9 @@ } } -static LogicalResult verifyParallelOp(ParallelOp op) { - if (op.allocate_vars().size() != op.allocators_vars().size()) - return op.emitError( +LogicalResult ParallelOp::verify() { + if (allocate_vars().size() != allocators_vars().size()) + return emitError( "expected equal sizes for allocate and allocator variables"); return success(); } @@ -1072,31 +1072,31 @@ p.printRegion(op.region()); } -static LogicalResult verifySectionsOp(SectionsOp op) { - +LogicalResult SectionsOp::verify() { // A list item may not appear in more than one clause on the same directive, // except that it may be specified in both firstprivate and lastprivate // clauses. - for (auto var : op.private_vars()) { - if (llvm::is_contained(op.firstprivate_vars(), var)) - return op.emitOpError() + for (auto var : private_vars()) { + if (llvm::is_contained(firstprivate_vars(), var)) + return emitOpError() << "operand used in both private and firstprivate clauses"; - if (llvm::is_contained(op.lastprivate_vars(), var)) - return op.emitOpError() + if (llvm::is_contained(lastprivate_vars(), var)) + return emitOpError() << "operand used in both private and lastprivate clauses"; } - if (op.allocate_vars().size() != op.allocators_vars().size()) - return op.emitError( + if (allocate_vars().size() != allocators_vars().size()) + return emitError( "expected equal sizes for allocate and allocator variables"); - for (auto &inst : *op.region().begin()) { - if (!(isa(inst) || isa(inst))) - op.emitOpError() - << "expected omp.section op or terminator op inside region"; + for (auto &inst : *region().begin()) { + if (!(isa(inst) || isa(inst))) { + return emitOpError() + << "expected omp.section op or terminator op inside region"; + } } - return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); + return verifyReductionVarList(*this, reductions(), reduction_vars()); } /// Parses an OpenMP Workshare Loop operation @@ -1224,65 +1224,65 @@ printer.printRegion(region); } -static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { - if (op.initializerRegion().empty()) - return op.emitOpError() << "expects non-empty initializer region"; - Block &initializerEntryBlock = op.initializerRegion().front(); +LogicalResult ReductionDeclareOp::verify() { + if (initializerRegion().empty()) + return emitOpError() << "expects non-empty initializer region"; + Block &initializerEntryBlock = initializerRegion().front(); if (initializerEntryBlock.getNumArguments() != 1 || - initializerEntryBlock.getArgument(0).getType() != op.type()) { - return op.emitOpError() << "expects initializer region with one argument " - "of the reduction type"; + initializerEntryBlock.getArgument(0).getType() != type()) { + return emitOpError() << "expects initializer region with one argument " + "of the reduction type"; } - for (YieldOp yieldOp : op.initializerRegion().getOps()) { + for (YieldOp yieldOp : initializerRegion().getOps()) { if (yieldOp.results().size() != 1 || - yieldOp.results().getTypes()[0] != op.type()) - return op.emitOpError() << "expects initializer region to yield a value " - "of the reduction type"; + yieldOp.results().getTypes()[0] != type()) + return emitOpError() << "expects initializer region to yield a value " + "of the reduction type"; } - if (op.reductionRegion().empty()) - return op.emitOpError() << "expects non-empty reduction region"; - Block &reductionEntryBlock = op.reductionRegion().front(); + if (reductionRegion().empty()) + return emitOpError() << "expects non-empty reduction region"; + Block &reductionEntryBlock = reductionRegion().front(); if (reductionEntryBlock.getNumArguments() != 2 || reductionEntryBlock.getArgumentTypes()[0] != reductionEntryBlock.getArgumentTypes()[1] || - reductionEntryBlock.getArgumentTypes()[0] != op.type()) - return op.emitOpError() << "expects reduction region with two arguments of " - "the reduction type"; - for (YieldOp yieldOp : op.reductionRegion().getOps()) { + reductionEntryBlock.getArgumentTypes()[0] != type()) + return emitOpError() << "expects reduction region with two arguments of " + "the reduction type"; + for (YieldOp yieldOp : reductionRegion().getOps()) { if (yieldOp.results().size() != 1 || - yieldOp.results().getTypes()[0] != op.type()) - return op.emitOpError() << "expects reduction region to yield a value " - "of the reduction type"; + yieldOp.results().getTypes()[0] != type()) + return emitOpError() << "expects reduction region to yield a value " + "of the reduction type"; } - if (op.atomicReductionRegion().empty()) + if (atomicReductionRegion().empty()) return success(); - Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); + Block &atomicReductionEntryBlock = atomicReductionRegion().front(); if (atomicReductionEntryBlock.getNumArguments() != 2 || atomicReductionEntryBlock.getArgumentTypes()[0] != atomicReductionEntryBlock.getArgumentTypes()[1]) - return op.emitOpError() << "expects atomic reduction region with two " - "arguments of the same type"; + return emitOpError() << "expects atomic reduction region with two " + "arguments of the same type"; auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] .dyn_cast(); - if (!ptrType || ptrType.getElementType() != op.type()) - return op.emitOpError() << "expects atomic reduction region arguments to " - "be accumulators containing the reduction type"; + if (!ptrType || ptrType.getElementType() != type()) + return emitOpError() << "expects atomic reduction region arguments to " + "be accumulators containing the reduction type"; return success(); } -static LogicalResult verifyReductionOp(ReductionOp op) { +LogicalResult ReductionOp::verify() { // TODO: generalize this to an op interface when there is more than one op // that supports reductions. - auto container = op->getParentOfType(); + auto container = (*this)->getParentOfType(); for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) - if (container.reduction_vars()[i] == op.accumulator()) + if (container.reduction_vars()[i] == accumulator()) return success(); - return op.emitOpError() << "the accumulator is not used by the parent"; + return emitOpError() << "the accumulator is not used by the parent"; } //===----------------------------------------------------------------------===// @@ -1368,27 +1368,26 @@ } } -static LogicalResult verifyWsLoopOp(WsLoopOp op) { - return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); +LogicalResult WsLoopOp::verify() { + return verifyReductionVarList(*this, reductions(), reduction_vars()); } //===----------------------------------------------------------------------===// // Verifier for critical construct (2.17.1) //===----------------------------------------------------------------------===// -static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) { - return verifySynchronizationHint(op, op.hint()); +LogicalResult CriticalDeclareOp::verify() { + return verifySynchronizationHint(*this, hint()); } -static LogicalResult verifyCriticalOp(CriticalOp op) { - - if (op.nameAttr()) { - auto symbolRef = op.nameAttr().cast(); - auto decl = - SymbolTable::lookupNearestSymbolFrom(op, symbolRef); +LogicalResult CriticalOp::verify() { + if (nameAttr()) { + SymbolRefAttr symbolRef = nameAttr(); + auto decl = SymbolTable::lookupNearestSymbolFrom( + *this, symbolRef); if (!decl) { - return op.emitOpError() << "expected symbol reference " << symbolRef - << " to point to a critical declaration"; + return emitOpError() << "expected symbol reference " << symbolRef + << " to point to a critical declaration"; } } @@ -1399,34 +1398,34 @@ // Verifier for ordered construct //===----------------------------------------------------------------------===// -static LogicalResult verifyOrderedOp(OrderedOp op) { - auto container = op->getParentOfType(); +LogicalResult OrderedOp::verify() { + auto container = (*this)->getParentOfType(); if (!container || !container.ordered_valAttr() || container.ordered_valAttr().getInt() == 0) - return op.emitOpError() << "ordered depend directive must be closely " - << "nested inside a worksharing-loop with ordered " - << "clause with parameter present"; + return emitOpError() << "ordered depend directive must be closely " + << "nested inside a worksharing-loop with ordered " + << "clause with parameter present"; if (container.ordered_valAttr().getInt() != - (int64_t)op.num_loops_val().getValue()) - return op.emitOpError() << "number of variables in depend clause does not " - << "match number of iteration variables in the " - << "doacross loop"; + (int64_t)num_loops_val().getValue()) + return emitOpError() << "number of variables in depend clause does not " + << "match number of iteration variables in the " + << "doacross loop"; return success(); } -static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) { +LogicalResult OrderedRegionOp::verify() { // TODO: The code generation for ordered simd directive is not supported yet. - if (op.simd()) + if (simd()) return failure(); - if (auto container = op->getParentOfType()) { + if (auto container = (*this)->getParentOfType()) { if (!container.ordered_valAttr() || container.ordered_valAttr().getInt() != 0) - return op.emitOpError() << "ordered region must be closely nested inside " - << "a worksharing-loop region with an ordered " - << "clause without parameter present"; + return emitOpError() << "ordered region must be closely nested inside " + << "a worksharing-loop region with an ordered " + << "clause without parameter present"; } return success(); @@ -1468,18 +1467,18 @@ } /// Verifier for AtomicReadOp -static LogicalResult verifyAtomicReadOp(AtomicReadOp op) { - if (auto mo = op.memory_order()) { +LogicalResult AtomicReadOp::verify() { + if (auto mo = memory_order()) { if (*mo == ClauseMemoryOrderKind::acq_rel || *mo == ClauseMemoryOrderKind::release) { - return op.emitError( + return emitError( "memory-order must not be acq_rel or release for atomic reads"); } } - if (op.x() == op.v()) - return op.emitError( + if (x() == v()) + return emitError( "read and write must not be to the same location for atomic reads"); - return verifySynchronizationHint(op, op.hint()); + return verifySynchronizationHint(*this, hint()); } //===----------------------------------------------------------------------===// @@ -1521,15 +1520,15 @@ } /// Verifier for AtomicWriteOp -static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) { - if (auto mo = op.memory_order()) { +LogicalResult AtomicWriteOp::verify() { + if (auto mo = memory_order()) { if (*mo == ClauseMemoryOrderKind::acq_rel || *mo == ClauseMemoryOrderKind::acquire) { - return op.emitError( + return emitError( "memory-order must not be acq_rel or acquire for atomic writes"); } } - return verifySynchronizationHint(op, op.hint()); + return verifySynchronizationHint(*this, hint()); } //===----------------------------------------------------------------------===// @@ -1601,11 +1600,11 @@ } /// Verifier for AtomicUpdateOp -static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) { - if (auto mo = op.memory_order()) { +LogicalResult AtomicUpdateOp::verify() { + if (auto mo = memory_order()) { if (*mo == ClauseMemoryOrderKind::acq_rel || *mo == ClauseMemoryOrderKind::acquire) { - return op.emitError( + return emitError( "memory-order must not be acq_rel or acquire for atomic updates"); } } @@ -1637,10 +1636,10 @@ } /// Verifier for AtomicCaptureOp -static LogicalResult verifyAtomicCaptureOp(AtomicCaptureOp op) { - Block::OpListType &ops = op.region().front().getOperations(); +LogicalResult AtomicCaptureOp::verify() { + Block::OpListType &ops = region().front().getOperations(); if (ops.size() != 3) - return emitError(op.getLoc()) + return emitError() << "expected three operations in omp.atomic.capture region (one " "terminator, and two atomic ops)"; auto &firstOp = ops.front(); @@ -1654,21 +1653,21 @@ if (!((firstUpdateStmt && secondReadStmt) || (firstReadStmt && secondUpdateStmt) || (firstReadStmt && secondWriteStmt))) - return emitError(ops.front().getLoc()) + return ops.front().emitError() << "invalid sequence of operations in the capture region"; if (firstUpdateStmt && secondReadStmt && firstUpdateStmt.x() != secondReadStmt.x()) - return emitError(firstUpdateStmt.getLoc()) + return firstUpdateStmt.emitError() << "updated variable in omp.atomic.update must be captured in " "second operation"; if (firstReadStmt && secondUpdateStmt && firstReadStmt.x() != secondUpdateStmt.x()) - return emitError(firstReadStmt.getLoc()) + return firstReadStmt.emitError() << "captured variable in omp.atomic.read must be updated in second " "operation"; if (firstReadStmt && secondWriteStmt && firstReadStmt.x() != secondWriteStmt.address()) - return emitError(firstReadStmt.getLoc()) + return firstReadStmt.emitError() << "captured variable in omp.atomic.read must be updated in " "second operation"; return success(); diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -90,9 +90,9 @@ // pdl::ApplyNativeConstraintOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ApplyNativeConstraintOp op) { - if (op.getNumOperands() == 0) - return op.emitOpError("expected at least one argument"); +LogicalResult ApplyNativeConstraintOp::verify() { + if (getNumOperands() == 0) + return emitOpError("expected at least one argument"); return success(); } @@ -100,9 +100,9 @@ // pdl::ApplyNativeRewriteOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ApplyNativeRewriteOp op) { - if (op.getNumOperands() == 0 && op.getNumResults() == 0) - return op.emitOpError("expected at least one argument or result"); +LogicalResult ApplyNativeRewriteOp::verify() { + if (getNumOperands() == 0 && getNumResults() == 0) + return emitOpError("expected at least one argument or result"); return success(); } @@ -110,18 +110,18 @@ // pdl::AttributeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(AttributeOp op) { - Value attrType = op.type(); - Optional attrValue = op.value(); +LogicalResult AttributeOp::verify() { + Value attrType = type(); + Optional attrValue = value(); if (!attrValue) { - if (isa(op->getParentOp())) - return op.emitOpError("expected constant value when specified within a " - "`pdl.rewrite`"); - return verifyHasBindingUse(op); + if (isa((*this)->getParentOp())) + return emitOpError( + "expected constant value when specified within a `pdl.rewrite`"); + return verifyHasBindingUse(*this); } if (attrType) - return op.emitOpError("expected only one of [`type`, `value`] to be set"); + return emitOpError("expected only one of [`type`, `value`] to be set"); return success(); } @@ -129,13 +129,13 @@ // pdl::OperandOp //===----------------------------------------------------------------------===// -static LogicalResult verify(OperandOp op) { return verifyHasBindingUse(op); } +LogicalResult OperandOp::verify() { return verifyHasBindingUse(*this); } //===----------------------------------------------------------------------===// // pdl::OperandsOp //===----------------------------------------------------------------------===// -static LogicalResult verify(OperandsOp op) { return verifyHasBindingUse(op); } +LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); } //===----------------------------------------------------------------------===// // pdl::OperationOp @@ -230,15 +230,15 @@ return success(); } -static LogicalResult verify(OperationOp op) { - bool isWithinRewrite = isa(op->getParentOp()); - if (isWithinRewrite && !op.name()) - return op.emitOpError("must have an operation name when nested within " - "a `pdl.rewrite`"); - ArrayAttr attributeNames = op.attributeNames(); - auto attributeValues = op.attributes(); +LogicalResult OperationOp::verify() { + bool isWithinRewrite = isa((*this)->getParentOp()); + if (isWithinRewrite && !name()) + return emitOpError("must have an operation name when nested within " + "a `pdl.rewrite`"); + ArrayAttr attributeNames = attributeNamesAttr(); + auto attributeValues = attributes(); if (attributeNames.size() != attributeValues.size()) { - return op.emitOpError() + return emitOpError() << "expected the same number of attribute values and attribute " "names, got " << attributeNames.size() << " names and " << attributeValues.size() @@ -247,12 +247,12 @@ // If the operation is within a rewrite body and doesn't have type inference, // ensure that the result types can be resolved. - if (isWithinRewrite && !op.hasTypeInference()) { - if (failed(verifyResultTypesAreInferrable(op, op.types()))) + if (isWithinRewrite && !hasTypeInference()) { + if (failed(verifyResultTypesAreInferrable(*this, types()))) return failure(); } - return verifyHasBindingUse(op); + return verifyHasBindingUse(*this); } bool OperationOp::hasTypeInference() { @@ -269,12 +269,12 @@ // pdl::PatternOp //===----------------------------------------------------------------------===// -static LogicalResult verify(PatternOp pattern) { - Region &body = pattern.body(); +LogicalResult PatternOp::verify() { + Region &body = getBodyRegion(); Operation *term = body.front().getTerminator(); auto rewriteOp = dyn_cast(term); if (!rewriteOp) { - return pattern.emitOpError("expected body to terminate with `pdl.rewrite`") + return emitOpError("expected body to terminate with `pdl.rewrite`") .attachNote(term->getLoc()) .append("see terminator defined here"); } @@ -283,8 +283,7 @@ // dialect. WalkResult result = body.walk([&](Operation *op) -> WalkResult { if (!isa_and_nonnull(op->getDialect())) { - pattern - .emitOpError("expected only `pdl` operations within the pattern body") + emitOpError("expected only `pdl` operations within the pattern body") .attachNote(op->getLoc()) .append("see non-`pdl` operation defined here"); return WalkResult::interrupt(); @@ -296,8 +295,7 @@ // Check that there is at least one operation. if (body.front().getOps().empty()) - return pattern.emitOpError( - "the pattern must contain at least one `pdl.operation`"); + return emitOpError("the pattern must contain at least one `pdl.operation`"); // Determine if the operations within the pdl.pattern form a connected // component. This is determined by starting the search from the first @@ -333,8 +331,7 @@ first = false; } else if (!visited.count(&op)) { // For the subsequent operations, check if already visited. - return pattern - .emitOpError("the operations must form a connected component") + return emitOpError("the operations must form a connected component") .attachNote(op.getLoc()) .append("see a disconnected value / operation here"); } @@ -364,10 +361,10 @@ // pdl::ReplaceOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReplaceOp op) { - if (op.replOperation() && !op.replValues().empty()) - return op.emitOpError() << "expected no replacement values to be provided" - " when the replacement operation is present"; +LogicalResult ReplaceOp::verify() { + if (replOperation() && !replValues().empty()) + return emitOpError() << "expected no replacement values to be provided" + " when the replacement operation is present"; return success(); } @@ -392,11 +389,11 @@ p << " -> " << resultType; } -static LogicalResult verify(ResultsOp op) { - if (!op.index() && op.getType().isa()) { - return op.emitOpError() << "expected `pdl.range` result type when " - "no index is specified, but got: " - << op.getType(); +LogicalResult ResultsOp::verify() { + if (!index() && getType().isa()) { + return emitOpError() << "expected `pdl.range` result type when " + "no index is specified, but got: " + << getType(); } return success(); } @@ -405,13 +402,13 @@ // pdl::RewriteOp //===----------------------------------------------------------------------===// -static LogicalResult verify(RewriteOp op) { - Region &rewriteRegion = op.body(); +LogicalResult RewriteOp::verify() { + Region &rewriteRegion = body(); // Handle the case where the rewrite is external. - if (op.name()) { + if (name()) { if (!rewriteRegion.empty()) { - return op.emitOpError() + return emitOpError() << "expected rewrite region to be empty when rewrite is external"; } return success(); @@ -419,18 +416,18 @@ // Otherwise, check that the rewrite region only contains a single block. if (rewriteRegion.empty()) { - return op.emitOpError() << "expected rewrite region to be non-empty if " - "external name is not specified"; + return emitOpError() << "expected rewrite region to be non-empty if " + "external name is not specified"; } // Check that no additional arguments were provided. - if (!op.externalArgs().empty()) { - return op.emitOpError() << "expected no external arguments when the " - "rewrite is specified inline"; + if (!externalArgs().empty()) { + return emitOpError() << "expected no external arguments when the " + "rewrite is specified inline"; } - if (op.externalConstParams()) { - return op.emitOpError() << "expected no external constant parameters when " - "the rewrite is specified inline"; + if (externalConstParams()) { + return emitOpError() << "expected no external constant parameters when " + "the rewrite is specified inline"; } return success(); @@ -445,9 +442,9 @@ // pdl::TypeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(TypeOp op) { - if (!op.typeAttr()) - return verifyHasBindingUse(op); +LogicalResult TypeOp::verify() { + if (!typeAttr()) + return verifyHasBindingUse(*this); return success(); } @@ -455,9 +452,9 @@ // pdl::TypesOp //===----------------------------------------------------------------------===// -static LogicalResult verify(TypesOp op) { - if (!op.typesAttr()) - return verifyHasBindingUse(op); +LogicalResult TypesOp::verify() { + if (!typesAttr()) + return verifyHasBindingUse(*this); return success(); } diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -27,6 +27,21 @@ >(); } +template +static LogicalResult verifySwitchOp(OpT op) { + // Verify that the number of case destinations matches the number of case + // values. + size_t numDests = op.cases().size(); + size_t numValues = op.caseValues().size(); + if (numDests != numValues) { + return op.emitOpError( + "expected number of cases to match the number of case " + "values, got ") + << numDests << " but expected " << numValues; + } + return success(); +} + //===----------------------------------------------------------------------===// // pdl_interp::CreateOperationOp //===----------------------------------------------------------------------===// @@ -131,17 +146,17 @@ p.printSuccessor(op.successor()); } -static LogicalResult verify(ForEachOp op) { +LogicalResult ForEachOp::verify() { // Verify that the operation has exactly one argument. - if (op.region().getNumArguments() != 1) - return op.emitOpError("requires exactly one argument"); + if (region().getNumArguments() != 1) + return emitOpError("requires exactly one argument"); // Verify that the loop variable and the operand (value range) // have compatible types. - BlockArgument arg = op.getLoopVariable(); + BlockArgument arg = getLoopVariable(); Type rangeType = pdl::RangeType::get(arg.getType()); - if (rangeType != op.values().getType()) - return op.emitOpError("operand must be a range of loop variable type"); + if (rangeType != values().getType()) + return emitOpError("operand must be a range of loop variable type"); return success(); } @@ -156,6 +171,42 @@ return type.isa() ? pdl::RangeType::get(valueTy) : valueTy; } +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchAttributeOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); } + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchOperandCountOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); } + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchOperationNameOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); } + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchResultCountOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); } + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchTypeOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); } + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchTypesOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); } + //===----------------------------------------------------------------------===// // TableGen Auto-Generated Op and Interface Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -65,29 +65,67 @@ return false; } -static LogicalResult verifyRegionOp(QuantizeRegionOp op) { +LogicalResult QuantizeRegionOp::verify() { // There are specifications for both inputs and outputs. - if (op.getNumOperands() != op.input_specs().size() || - op.getNumResults() != op.output_specs().size()) - return op.emitOpError( + if (getNumOperands() != input_specs().size() || + getNumResults() != output_specs().size()) + return emitOpError( "has unmatched operands/results number and spec attributes number"); // Verify that quantization specifications are valid. - for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) { + for (auto input : llvm::zip(getOperandTypes(), input_specs())) { Type inputType = std::get<0>(input); Attribute inputSpec = std::get<1>(input); if (!isValidQuantizationSpec(inputSpec, inputType)) { - return op.emitOpError() << "has incompatible specification " << inputSpec - << " and input type " << inputType; + return emitOpError() << "has incompatible specification " << inputSpec + << " and input type " << inputType; } } - for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) { + for (auto result : llvm::zip(getResultTypes(), output_specs())) { Type outputType = std::get<0>(result); Attribute outputSpec = std::get<1>(result); if (!isValidQuantizationSpec(outputSpec, outputType)) { - return op.emitOpError() << "has incompatible specification " << outputSpec - << " and output type " << outputType; + return emitOpError() << "has incompatible specification " << outputSpec + << " and output type " << outputType; + } + } + return success(); +} + +LogicalResult StatisticsOp::verify() { + auto tensorArg = arg().getType().dyn_cast(); + if (!tensorArg) + return emitOpError("arg needs to be tensor type."); + + // Verify layerStats attribute. + { + auto layerStatsType = layerStats().getType(); + if (!layerStatsType.getElementType().isa()) { + return emitOpError("layerStats must have a floating point element type"); + } + if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { + return emitOpError("layerStats must have shape [2]"); + } + } + // Verify axisStats (optional) attribute. + if (axisStats()) { + if (!axis()) + return emitOpError("axis must be specified for axisStats"); + + auto shape = tensorArg.getShape(); + auto argSliceSize = + std::accumulate(std::next(shape.begin(), *axis()), shape.end(), 1, + std::multiplies()); + + auto axisStatsType = axisStats()->getType(); + if (!axisStatsType.getElementType().isa()) { + return emitOpError("axisStats must have a floating point element type"); + } + if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 || + axisStatsType.getDimSize(0) != argSliceSize) { + return emitOpError("axisStats must have shape [N,2] " + "where N = the slice size defined by the axis dim"); } } return success(); diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -125,11 +125,11 @@ p.printOptionalAttrDict(op->getAttrs()); } -static LogicalResult verify(ExecuteRegionOp op) { - if (op.getRegion().empty()) - return op.emitOpError("region needs to have at least one block"); - if (op.getRegion().front().getNumArguments() > 0) - return op.emitOpError("region cannot have any arguments"); +LogicalResult ExecuteRegionOp::verify() { + if (getRegion().empty()) + return emitOpError("region needs to have at least one block"); + if (getRegion().front().getNumArguments() > 0) + return emitOpError("region cannot have any arguments"); return success(); } @@ -276,47 +276,47 @@ } } -static LogicalResult verify(ForOp op) { - if (auto cst = op.getStep().getDefiningOp()) +LogicalResult ForOp::verify() { + if (auto cst = getStep().getDefiningOp()) if (cst.value() <= 0) - return op.emitOpError("constant step operand must be positive"); + return emitOpError("constant step operand must be positive"); // Check that the body defines as single block argument for the induction // variable. - auto *body = op.getBody(); + auto *body = getBody(); if (!body->getArgument(0).getType().isIndex()) - return op.emitOpError( + return emitOpError( "expected body first argument to be an index argument for " "the induction variable"); - auto opNumResults = op.getNumResults(); + auto opNumResults = getNumResults(); if (opNumResults == 0) return success(); // If ForOp defines values, check that the number and types of // the defined values match ForOp initial iter operands and backedge // basic block arguments. - if (op.getNumIterOperands() != opNumResults) - return op.emitOpError( + if (getNumIterOperands() != opNumResults) + return emitOpError( "mismatch in number of loop-carried values and defined values"); - if (op.getNumRegionIterArgs() != opNumResults) - return op.emitOpError( + if (getNumRegionIterArgs() != opNumResults) + return emitOpError( "mismatch in number of basic block args and defined values"); - auto iterOperands = op.getIterOperands(); - auto iterArgs = op.getRegionIterArgs(); - auto opResults = op.getResults(); + auto iterOperands = getIterOperands(); + auto iterArgs = getRegionIterArgs(); + auto opResults = getResults(); unsigned i = 0; for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { if (std::get<0>(e).getType() != std::get<2>(e).getType()) - return op.emitOpError() << "types mismatch between " << i - << "th iter operand and defined value"; + return emitOpError() << "types mismatch between " << i + << "th iter operand and defined value"; if (std::get<1>(e).getType() != std::get<2>(e).getType()) - return op.emitOpError() << "types mismatch between " << i - << "th iter region arg and defined value"; + return emitOpError() << "types mismatch between " << i + << "th iter region arg and defined value"; i++; } - return RegionBranchOpInterface::verifyTypes(op); + return RegionBranchOpInterface::verifyTypes(*this); } /// Prints the initialization list in the form of @@ -1062,11 +1062,11 @@ build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder); } -static LogicalResult verify(IfOp op) { - if (op.getNumResults() != 0 && op.getElseRegion().empty()) - return op.emitOpError("must have an else block if defining values"); +LogicalResult IfOp::verify() { + if (getNumResults() != 0 && getElseRegion().empty()) + return emitOpError("must have an else block if defining values"); - return RegionBranchOpInterface::verifyTypes(op); + return RegionBranchOpInterface::verifyTypes(*this); } static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) { @@ -1723,32 +1723,31 @@ wrapper); } -static LogicalResult verify(ParallelOp op) { +LogicalResult ParallelOp::verify() { // Check that there is at least one value in lowerBound, upperBound and step. // It is sufficient to test only step, because it is ensured already that the // number of elements in lowerBound, upperBound and step are the same. - Operation::operand_range stepValues = op.getStep(); + Operation::operand_range stepValues = getStep(); if (stepValues.empty()) - return op.emitOpError( + return emitOpError( "needs at least one tuple element for lowerBound, upperBound and step"); // Check whether all constant step values are positive. for (Value stepValue : stepValues) if (auto cst = stepValue.getDefiningOp()) if (cst.value() <= 0) - return op.emitOpError("constant step operand must be positive"); + return emitOpError("constant step operand must be positive"); // Check that the body defines the same number of block arguments as the // number of tuple elements in step. - Block *body = op.getBody(); + Block *body = getBody(); if (body->getNumArguments() != stepValues.size()) - return op.emitOpError() - << "expects the same number of induction variables: " - << body->getNumArguments() - << " as bound and step values: " << stepValues.size(); + return emitOpError() << "expects the same number of induction variables: " + << body->getNumArguments() + << " as bound and step values: " << stepValues.size(); for (auto arg : body->getArguments()) if (!arg.getType().isIndex()) - return op.emitOpError( + return emitOpError( "expects arguments for the induction variable to be of index type"); // Check that the yield has no results @@ -1759,20 +1758,20 @@ // Check that the number of results is the same as the number of ReduceOps. SmallVector reductions(body->getOps()); - auto resultsSize = op.getResults().size(); + auto resultsSize = getResults().size(); auto reductionsSize = reductions.size(); - auto initValsSize = op.getInitVals().size(); + auto initValsSize = getInitVals().size(); if (resultsSize != reductionsSize) - return op.emitOpError() - << "expects number of results: " << resultsSize - << " to be the same as number of reductions: " << reductionsSize; + return emitOpError() << "expects number of results: " << resultsSize + << " to be the same as number of reductions: " + << reductionsSize; if (resultsSize != initValsSize) - return op.emitOpError() - << "expects number of results: " << resultsSize - << " to be the same as number of initial values: " << initValsSize; + return emitOpError() << "expects number of results: " << resultsSize + << " to be the same as number of initial values: " + << initValsSize; // Check that the types of the results and reductions are the same. - for (auto resultAndReduce : llvm::zip(op.getResults(), reductions)) { + for (auto resultAndReduce : llvm::zip(getResults(), reductions)) { auto resultType = std::get<0>(resultAndReduce).getType(); auto reduceOp = std::get<1>(resultAndReduce); auto reduceType = reduceOp.getOperand().getType(); @@ -2075,23 +2074,23 @@ body->getArgument(1)); } -static LogicalResult verify(ReduceOp op) { +LogicalResult ReduceOp::verify() { // The region of a ReduceOp has two arguments of the same type as its operand. - auto type = op.getOperand().getType(); - Block &block = op.getReductionOperator().front(); + auto type = getOperand().getType(); + Block &block = getReductionOperator().front(); if (block.empty()) - return op.emitOpError("the block inside reduce should not be empty"); + return emitOpError("the block inside reduce should not be empty"); if (block.getNumArguments() != 2 || llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) { return arg.getType() != type; })) - return op.emitOpError() - << "expects two arguments to reduce block of type " << type; + return emitOpError() << "expects two arguments to reduce block of type " + << type; // Check that the block is terminated by a ReduceReturnOp. if (!isa(block.getTerminator())) - return op.emitOpError("the block inside reduce should be terminated with a " - "'scf.reduce.return' op"); + return emitOpError("the block inside reduce should be terminated with a " + "'scf.reduce.return' op"); return success(); } @@ -2127,14 +2126,14 @@ // ReduceReturnOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReduceReturnOp op) { +LogicalResult ReduceReturnOp::verify() { // The type of the return value should be the same type as the type of the // operand of the enclosing ReduceOp. - auto reduceOp = cast(op->getParentOp()); + auto reduceOp = cast((*this)->getParentOp()); Type reduceType = reduceOp.getOperand().getType(); - if (reduceType != op.getResult().getType()) - return op.emitOpError() << "needs to have type " << reduceType - << " (the type of the enclosing ReduceOp)"; + if (reduceType != getResult().getType()) + return emitOpError() << "needs to have type " << reduceType + << " (the type of the enclosing ReduceOp)"; return success(); } @@ -2278,18 +2277,18 @@ return nullptr; } -static LogicalResult verify(scf::WhileOp op) { - if (failed(RegionBranchOpInterface::verifyTypes(op))) +LogicalResult scf::WhileOp::verify() { + if (failed(RegionBranchOpInterface::verifyTypes(*this))) return failure(); auto beforeTerminator = verifyAndGetTerminator( - op, op.getBefore(), + *this, getBefore(), "expects the 'before' region to terminate with 'scf.condition'"); if (!beforeTerminator) return failure(); auto afterTerminator = verifyAndGetTerminator( - op, op.getAfter(), + *this, getAfter(), "expects the 'after' region to terminate with 'scf.yield'"); return success(afterTerminator != nullptr); } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -419,6 +419,10 @@ result.addTypes(assumingTypes); } +LogicalResult AssumingOp::verify() { + return RegionBranchOpInterface::verifyTypes(*this); +} + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -449,6 +453,8 @@ operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); } +LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // AssumingAllOp //===----------------------------------------------------------------------===// @@ -529,10 +535,10 @@ return BoolAttr::get(getContext(), true); } -static LogicalResult verify(AssumingAllOp op) { +LogicalResult AssumingAllOp::verify() { // Ensure that AssumingAllOp contains at least one operand - if (op.getNumOperands() == 0) - return op.emitOpError("no operands specified"); + if (getNumOperands() == 0) + return emitOpError("no operands specified"); return success(); } @@ -575,8 +581,8 @@ return builder.getIndexTensorAttr(resultShape); } -static LogicalResult verify(BroadcastOp op) { - return verifyShapeOrExtentTensorOp(op); +LogicalResult BroadcastOp::verify() { + return verifyShapeOrExtentTensorOp(*this); } namespace { @@ -912,10 +918,10 @@ return nullptr; } -static LogicalResult verify(CstrBroadcastableOp op) { +LogicalResult CstrBroadcastableOp::verify() { // Ensure that AssumingAllOp contains at least one operand - if (op.getNumOperands() < 2) - return op.emitOpError("required at least 2 input shapes"); + if (getNumOperands() < 2) + return emitOpError("required at least 2 input shapes"); return success(); } @@ -1016,6 +1022,8 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // ShapeEqOp //===----------------------------------------------------------------------===// @@ -1172,6 +1180,8 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // IsBroadcastableOp //===----------------------------------------------------------------------===// @@ -1298,6 +1308,8 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // NumElementsOp //===----------------------------------------------------------------------===// @@ -1333,6 +1345,10 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult shape::NumElementsOp::verify() { + return verifySizeOrIndexOp(*this); +} + //===----------------------------------------------------------------------===// // MaxOp //===----------------------------------------------------------------------===// @@ -1429,6 +1445,9 @@ // SizeType is compatible with IndexType. return eachHasOnlyOneOfTypes(l, r); } + +LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// @@ -1535,6 +1554,10 @@ return false; } +LogicalResult shape::ShapeOfOp::verify() { + return verifyShapeOrExtentTensorOp(*this); +} + //===----------------------------------------------------------------------===// // SizeToIndexOp //===----------------------------------------------------------------------===// @@ -1556,18 +1579,17 @@ // YieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(shape::YieldOp op) { - auto *parentOp = op->getParentOp(); +LogicalResult shape::YieldOp::verify() { + auto *parentOp = (*this)->getParentOp(); auto results = parentOp->getResults(); - auto operands = op.getOperands(); + auto operands = getOperands(); - if (parentOp->getNumResults() != op.getNumOperands()) - return op.emitOpError() << "number of operands does not match number of " - "results of its parent"; + if (parentOp->getNumResults() != getNumOperands()) + return emitOpError() << "number of operands does not match number of " + "results of its parent"; for (auto e : llvm::zip(results, operands)) if (std::get<0>(e).getType() != std::get<1>(e).getType()) - return op.emitOpError() - << "types mismatch between yield op and its parent"; + return emitOpError() << "types mismatch between yield op and its parent"; return success(); } @@ -1639,41 +1661,42 @@ } } -static LogicalResult verify(ReduceOp op) { +LogicalResult ReduceOp::verify() { // Verify block arg types. - Block &block = op.getRegion().front(); + Block &block = getRegion().front(); // The block takes index, extent, and aggregated values as arguments. - auto blockArgsCount = op.getInitVals().size() + 2; + auto blockArgsCount = getInitVals().size() + 2; if (block.getNumArguments() != blockArgsCount) - return op.emitOpError() << "ReduceOp body is expected to have " - << blockArgsCount << " arguments"; + return emitOpError() << "ReduceOp body is expected to have " + << blockArgsCount << " arguments"; // The first block argument is the index and must always be of type `index`. if (!block.getArgument(0).getType().isa()) - return op.emitOpError( + return emitOpError( "argument 0 of ReduceOp body is expected to be of IndexType"); // The second block argument is the extent and must be of type `size` or // `index`, depending on whether the reduce operation is applied to a shape or // to an extent tensor. Type extentTy = block.getArgument(1).getType(); - if (op.getShape().getType().isa()) { + if (getShape().getType().isa()) { if (!extentTy.isa()) - return op.emitOpError("argument 1 of ReduceOp body is expected to be of " - "SizeType if the ReduceOp operates on a ShapeType"); + return emitOpError("argument 1 of ReduceOp body is expected to be of " + "SizeType if the ReduceOp operates on a ShapeType"); } else { if (!extentTy.isa()) - return op.emitOpError( + return emitOpError( "argument 1 of ReduceOp body is expected to be of IndexType if the " "ReduceOp operates on an extent tensor"); } - for (const auto &type : llvm::enumerate(op.getInitVals())) + for (const auto &type : llvm::enumerate(getInitVals())) if (block.getArgument(type.index() + 2).getType() != type.value().getType()) - return op.emitOpError() - << "type mismatch between argument " << type.index() + 2 - << " of ReduceOp body and initial value " << type.index(); + return emitOpError() << "type mismatch between argument " + << type.index() + 2 + << " of ReduceOp body and initial value " + << type.index(); return success(); } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -612,31 +612,31 @@ /// The constant op requires an attribute, and furthermore requires that it /// matches the return type. -static LogicalResult verify(ConstantOp &op) { - auto value = op.getValue(); +LogicalResult ConstantOp::verify() { + auto value = getValue(); if (!value) - return op.emitOpError("requires a 'value' attribute"); + return emitOpError("requires a 'value' attribute"); - Type type = op.getType(); + Type type = getType(); if (!value.getType().isa() && type != value.getType()) - return op.emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; + return emitOpError() << "requires attribute's type (" << value.getType() + << ") to match op's return type (" << type << ")"; if (type.isa()) { auto fnAttr = value.dyn_cast(); if (!fnAttr) - return op.emitOpError("requires 'value' to be a function reference"); + return emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. - auto fn = - op->getParentOfType().lookupSymbol(fnAttr.getValue()); + auto fn = (*this)->getParentOfType().lookupSymbol( + fnAttr.getValue()); if (!fn) - return op.emitOpError() - << "reference to undefined function '" << fnAttr.getValue() << "'"; + return emitOpError() << "reference to undefined function '" + << fnAttr.getValue() << "'"; // Check that the referenced function has the correct type. if (fn.getType() != type) - return op.emitOpError("reference to function with mismatched type"); + return emitOpError("reference to function with mismatched type"); return success(); } @@ -644,7 +644,7 @@ if (type.isa() && value.isa()) return success(); - return op.emitOpError("unsupported 'value' attribute: ") << value; + return emitOpError("unsupported 'value' attribute: ") << value; } OpFoldResult ConstantOp::fold(ArrayRef operands) { @@ -676,23 +676,23 @@ // ReturnOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReturnOp op) { - auto function = cast(op->getParentOp()); +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getType().getResults(); - if (op.getNumOperands() != results.size()) - return op.emitOpError("has ") - << op.getNumOperands() << " operands, but enclosing function (@" + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) - if (op.getOperand(i).getType() != results[i]) - return op.emitError() - << "type of return operand " << i << " (" - << op.getOperand(i).getType() - << ") doesn't match function result type (" << results[i] << ")" - << " in function @" << function.getName(); + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); return success(); } @@ -843,24 +843,23 @@ parser.getNameLoc(), result.operands); } -static LogicalResult verify(SelectOp op) { - Type conditionType = op.getCondition().getType(); +LogicalResult SelectOp::verify() { + Type conditionType = getCondition().getType(); if (conditionType.isSignlessInteger(1)) return success(); // If the result type is a vector or tensor, the type can be a mask with the // same elements. - Type resultType = op.getType(); + Type resultType = getType(); if (!resultType.isa()) - return op.emitOpError() - << "expected condition to be a signless i1, but got " - << conditionType; + return emitOpError() << "expected condition to be a signless i1, but got " + << conditionType; Type shapedConditionType = getI1SameShape(resultType); if (conditionType != shapedConditionType) - return op.emitOpError() - << "expected condition type to have the same shape " - "as the result type, expected " - << shapedConditionType << ", but got " << conditionType; + return emitOpError() << "expected condition type to have the same shape " + "as the result type, expected " + << shapedConditionType << ", but got " + << conditionType; return success(); } @@ -868,11 +867,10 @@ // SplatOp //===----------------------------------------------------------------------===// -static LogicalResult verify(SplatOp op) { +LogicalResult SplatOp::verify() { // TODO: we could replace this by a trait. - if (op.getOperand().getType() != - op.getType().cast().getElementType()) - return op.emitError("operand should be of elemental type of result type"); + if (getOperand().getType() != getType().cast().getElementType()) + return emitError("operand should be of elemental type of result type"); return success(); } @@ -995,26 +993,26 @@ p.printNewline(); } -static LogicalResult verify(SwitchOp op) { - auto caseValues = op.getCaseValues(); - auto caseDestinations = op.getCaseDestinations(); +LogicalResult SwitchOp::verify() { + auto caseValues = getCaseValues(); + auto caseDestinations = getCaseDestinations(); if (!caseValues && caseDestinations.empty()) return success(); - Type flagType = op.getFlag().getType(); + Type flagType = getFlag().getType(); Type caseValueType = caseValues->getType().getElementType(); if (caseValueType != flagType) - return op.emitOpError() - << "'flag' type (" << flagType << ") should match case value type (" - << caseValueType << ")"; + return emitOpError() << "'flag' type (" << flagType + << ") should match case value type (" << caseValueType + << ")"; if (caseValues && caseValues->size() != static_cast(caseDestinations.size())) - return op.emitOpError() << "number of case values (" << caseValues->size() - << ") should match number of " - "case destinations (" - << caseDestinations.size() << ")"; + return emitOpError() << "number of case values (" << caseValues->size() + << ") should match number of " + "case destinations (" + << caseDestinations.size() << ")"; return success(); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -701,9 +701,9 @@ return success(); } -static LogicalResult verifyAveragePoolOp(tosa::AvgPool2dOp op) { - auto inputETy = op.input().getType().cast().getElementType(); - auto resultETy = op.getType().cast().getElementType(); +LogicalResult tosa::AvgPool2dOp::verify() { + auto inputETy = input().getType().cast().getElementType(); + auto resultETy = getType().cast().getElementType(); if (auto quantType = inputETy.dyn_cast()) inputETy = quantType.getStorageType(); @@ -718,7 +718,7 @@ if (inputETy.isInteger(16) && resultETy.isInteger(16)) return success(); - return op.emitOpError("input/output element types are incompatible."); + return emitOpError("input/output element types are incompatible."); } //===----------------------------------------------------------------------===// @@ -1010,6 +1010,8 @@ return success(); } +LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); } + LogicalResult tosa::MatMulOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1632,6 +1634,8 @@ return success(); } +LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); } + LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1708,6 +1712,8 @@ return success(); } +LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); } + LogicalResult AvgPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1800,6 +1806,8 @@ return success(); } +LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); } + LogicalResult TransposeConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -128,19 +128,19 @@ p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); } -static LogicalResult verify(FuncOp op) { +LogicalResult FuncOp::verify() { // If this function is external there is nothing to do. - if (op.isExternal()) + if (isExternal()) return success(); // Verify that the argument list of the function and the arg list of the entry // block line up. The trait already verified that the number of arguments is // the same between the signature and the block. - auto fnInputTypes = op.getType().getInputs(); - Block &entryBlock = op.front(); + auto fnInputTypes = getType().getInputs(); + Block &entryBlock = front(); for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) - return op.emitOpError("type of entry block argument #") + return emitOpError("type of entry block argument #") << i << '(' << entryBlock.getArgument(i).getType() << ") must match the type of the corresponding argument in " << "function signature(" << fnInputTypes[i] << ')'; @@ -245,28 +245,28 @@ return {}; } -static LogicalResult verify(ModuleOp op) { +LogicalResult ModuleOp::verify() { // Check that none of the attributes are non-dialect attributes, except for // the symbol related attributes. - for (auto attr : op->getAttrs()) { + for (auto attr : (*this)->getAttrs()) { if (!attr.getName().strref().contains('.') && !llvm::is_contained( ArrayRef{mlir::SymbolTable::getSymbolAttrName(), mlir::SymbolTable::getVisibilityAttrName()}, attr.getName().strref())) - return op.emitOpError() << "can only contain attributes with " - "dialect-prefixed names, found: '" - << attr.getName().getValue() << "'"; + return emitOpError() << "can only contain attributes with " + "dialect-prefixed names, found: '" + << attr.getName().getValue() << "'"; } // Check that there is at most one data layout spec attribute. StringRef layoutSpecAttrName; DataLayoutSpecInterface layoutSpec; - for (const NamedAttribute &na : op->getAttrs()) { + for (const NamedAttribute &na : (*this)->getAttrs()) { if (auto spec = na.getValue().dyn_cast()) { if (layoutSpec) { InFlightDiagnostic diag = - op.emitOpError() << "expects at most one data layout attribute"; + emitOpError() << "expects at most one data layout attribute"; diag.attachNote() << "'" << layoutSpecAttrName << "' is a data layout attribute"; diag.attachNote() << "'" << na.getName().getValue() diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -90,7 +90,7 @@ %cst = arith.constant 1 : index %value = memref.alloc() : memref<10xf32> -// expected-error@+1 {{async attribute cannot appear with asyncOperand}} +// expected-error@+1 {{async attribute cannot appear with asyncOperand}} acc.update async(%cst: index) host(%value: memref<10xf32>) attributes {async} // ----- diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -40,10 +40,10 @@ OpBuilder<(ins CArg<"int", "0">:$integer)>]; let parser = [{ foo }]; let printer = [{ bar }]; - let verifier = [{ baz }]; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Display a graph for debugging purposes.