diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -91,7 +91,6 @@ %0 = amx.tile_zero : vector<16x16xbf16> ``` }]; - let verifier = [{ return ::verify(*this); }]; let results = (outs VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res); let extraClassDeclaration = [{ @@ -100,6 +99,7 @@ } }]; let assemblyFormat = "attr-dict `:` type($res)"; + let hasVerifier = 1; } // @@ -120,7 +120,6 @@ %0 = amx.tile_load %arg0[%c0, %c0] : memref into vector<16x64xi8> ``` }]; - let verifier = [{ return ::verify(*this); }]; let arguments = (ins Arg:$base, Variadic:$indices); let results = (outs @@ -135,6 +134,7 @@ }]; let assemblyFormat = "$base `[` $indices `]` attr-dict `:` " "type($base) `into` type($res)"; + let hasVerifier = 1; } def TileStoreOp : AMX_Op<"tile_store"> { @@ -151,7 +151,6 @@ amx.tile_store %arg1[%c0, %c0], %0 : memref, vector<16x64xi8> ``` }]; - let verifier = [{ return ::verify(*this); }]; let arguments = (ins Arg:$base, Variadic:$indices, VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val); @@ -165,6 +164,7 @@ }]; let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` " "type($base) `,` type($val)"; + let hasVerifier = 1; } // @@ -186,7 +186,6 @@ : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> ``` }]; - let verifier = [{ return ::verify(*this); }]; let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs, VectorOfRankAndType<[2], [F32, BF16]>:$rhs, VectorOfRankAndType<[2], [F32, BF16]>:$acc); @@ -204,6 +203,7 @@ }]; let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` " "type($lhs) `,` type($rhs) `,` type($acc) "; + let hasVerifier = 1; } def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> { @@ -224,7 +224,6 @@ : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> ``` }]; - let verifier = [{ return ::verify(*this); }]; let arguments = (ins VectorOfRankAndType<[2], 
[I32, I8]>:$lhs, VectorOfRankAndType<[2], [I32, I8]>:$rhs, VectorOfRankAndType<[2], [I32, I8]>:$acc, @@ -245,6 +244,7 @@ }]; let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` " "type($lhs) `,` type($rhs) `,` type($acc) "; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -31,12 +31,10 @@ Op { // For every affine op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -112,6 +110,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def AffineForOp : Affine_Op<"for", @@ -350,6 +349,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def AffineIfOp : Affine_Op<"if", @@ -473,6 +473,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } class AffineLoadOpBase traits = []> : @@ -538,6 +539,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } class AffineMinMaxOpBase traits = []> : @@ -565,11 +567,11 @@ operands().end()}; } }]; - let verifier = [{ return ::verifyAffineMinMaxOp(*this); }]; let printer = [{ return ::printAffineMinMaxOp(p, *this); }]; let parser = [{ return ::parseAffineMinMaxOp<$cppClass>(parser, result); }]; let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; } def AffineMinOp : AffineMinMaxOpBase<"min", [NoSideEffect]> { @@ -753,6 +755,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } def 
AffinePrefetchOp : Affine_Op<"prefetch", @@ -832,6 +835,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } class AffineStoreOpBase traits = []> : @@ -896,6 +900,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator, ReturnLike, @@ -921,6 +926,7 @@ ]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let hasVerifier = 1; } def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> { @@ -984,6 +990,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> { @@ -1048,6 +1055,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } #endif // AFFINE_OPS diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -158,7 +158,6 @@ // of strings and signed/unsigned integers (for now) as an artefact of // splitting the Standard dialect. let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result); - let verifier = [{ return ::verify(*this); }]; let builders = [ OpBuilder<(ins "Attribute":$value), @@ -175,6 +174,7 @@ let hasFolder = 1; let assemblyFormat = "attr-dict $value"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -814,7 +814,7 @@ }]; let hasFolder = 1; - let verifier = [{ return verifyExtOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -845,7 +845,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - let verifier = [{ return verifyExtOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -859,8 +859,7 @@ The destination type must to be strictly wider than the source type. 
When operating on vectors, casts elementwise. }]; - - let verifier = [{ return verifyExtOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -887,7 +886,7 @@ }]; let hasFolder = 1; - let verifier = [{ return verifyTruncateOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -904,7 +903,7 @@ }]; let hasFolder = 1; - let verifier = [{ return verifyTruncateOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -82,9 +82,9 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = [{ return ::verify(*this); }]; let skipDefaultBuilders = 1; + let hasVerifier = 1; let builders = [ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$dependencies, "ValueRange":$operands, @@ -110,10 +110,8 @@ }]; let arguments = (ins Variadic:$operands); - let assemblyFormat = "($operands^ `:` type($operands))? 
attr-dict"; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Async_AwaitOp : Async_Op<"await"> { @@ -137,6 +135,7 @@ let results = (outs Optional:$result); let skipDefaultBuilders = 1; + let hasVerifier = 1; let builders = [ OpBuilder<(ins "Value":$operand, @@ -155,8 +154,6 @@ type($operand), type($result) ) attr-dict }]; - - let verifier = [{ return ::verify(*this); }]; } def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> { diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -58,7 +58,6 @@ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; let hasFolder = 1; - let verifier = ?; let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -97,8 +96,6 @@ let arguments = (ins Arg:$memref); let results = (outs AnyTensor:$result); - // MemrefToTensor is fully verified by traits. - let verifier = ?; let builders = [ OpBuilder<(ins "Value":$memref), [{ @@ -184,8 +181,6 @@ let arguments = (ins AnyTensor:$tensor); let results = (outs AnyRankedOrUnrankedMemRef:$memref); - // This op is fully verified by traits. 
- let verifier = ?; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -26,7 +26,6 @@ let arguments = (ins Complex:$lhs, Complex:$rhs); let results = (outs Complex:$result); let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; - let verifier = ?; } // Base class for standard unary operations on complex numbers with a @@ -36,7 +35,6 @@ Complex_Op { let arguments = (ins Complex:$complex); let assemblyFormat = "$complex attr-dict `:` type($complex)"; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -102,7 +100,7 @@ let assemblyFormat = "$value attr-dict `:` type($complex)"; let hasFolder = 1; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; let extraClassDeclaration = [{ /// Returns true if a constant operation can be built with the given value diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -24,9 +24,7 @@ // Base class for EmitC dialect ops. 
class EmitC_Op traits = []> - : Op { - let verifier = "return ::verify(*this);"; -} + : Op; def EmitC_ApplyOp : EmitC_Op<"apply", []> { let summary = "Apply operation"; @@ -54,6 +52,7 @@ let assemblyFormat = [{ $applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results) }]; + let hasVerifier = 1; } def EmitC_CallOp : EmitC_Op<"call", []> { @@ -85,6 +84,7 @@ let assemblyFormat = [{ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) }]; + let hasVerifier = 1; } def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { @@ -113,6 +113,7 @@ let results = (outs AnyType); let hasFolder = 1; + let hasVerifier = 1; } def EmitC_IncludeOp @@ -144,7 +145,6 @@ ); let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = ?; } #endif // MLIR_DIALECT_EMITC_IR_EMITC diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -110,7 +110,6 @@ }]; let assemblyFormat = "attr-dict `:` type($result)"; - let verifier = [{ return success(); }]; } def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [NoSideEffect]>, @@ -126,7 +125,6 @@ }]; let assemblyFormat = "attr-dict `:` type($result)"; - let verifier = [{ return success(); }]; } def GPU_SubgroupSizeOp : GPU_Op<"subgroup_size", [NoSideEffect]>, @@ -142,7 +140,6 @@ }]; let assemblyFormat = "attr-dict `:` type($result)"; - let verifier = [{ return success(); }]; } def GPU_GPUFuncOp : GPU_Op<"func", [ @@ -298,7 +295,6 @@ LogicalResult verifyBody(); }]; - // let verifier = [{ return ::verifFuncOpy(*this); }]; let printer = [{ printGPUFuncOp(p, *this); }]; let parser = [{ return parseGPUFuncOp(parser, result); }]; } @@ -434,7 +430,6 @@ static StringRef getKernelAttrName() { return "kernel"; } }]; - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ custom(type($asyncToken), 
$asyncDependencies) $kernel @@ -443,6 +438,7 @@ (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)? custom($operands, type($operands)) attr-dict }]; + let hasVerifier = 1; } def GPU_LaunchOp : GPU_Op<"launch">, @@ -562,8 +558,8 @@ let parser = [{ return parseLaunchOp(parser, result); }]; let printer = [{ printLaunchOp(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def GPU_PrintfOp : GPU_Op<"printf", [MemoryEffects<[MemWrite]>]>, @@ -595,7 +591,7 @@ let builders = [OpBuilder<(ins), [{ // empty}]>]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def GPU_TerminatorOp : GPU_Op<"terminator", [HasParent<"LaunchOp">, @@ -682,9 +678,9 @@ in convergence. }]; let regions = (region AnyRegion:$body); - let verifier = [{ return ::verifyAllReduce(*this); }]; let assemblyFormat = [{ custom($op) $value $body attr-dict `:` functional-type(operands, results) }]; + let hasVerifier = 1; } def GPU_ShuffleOpXor : I32EnumAttrCase<"XOR", 0, "xor">; @@ -822,7 +818,6 @@ }]; let assemblyFormat = "$value attr-dict `:` type($value)"; - let verifier = [{ return success(); }]; } def GPU_WaitOp : GPU_Op<"wait", [GPU_AsyncOpInterface]> { @@ -971,8 +966,8 @@ custom(type($asyncToken), $asyncDependencies) $dst`,` $src `:` type($dst)`,` type($src) attr-dict }]; - let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasVerifier = 1; } def GPU_MemsetOp : GPU_Op<"memset", @@ -1006,8 +1001,6 @@ custom(type($asyncToken), $asyncDependencies) $dst`,` $value `:` type($dst)`,` type($value) attr-dict }]; - // MemsetOp is fully verified by traits. 
- let verifier = [{ return success(); }]; let hasFolder = 1; } @@ -1048,8 +1041,7 @@ let assemblyFormat = [{ $srcMemref`[`$indices`]` attr-dict `:` type($srcMemref) `->` type($res) }]; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix", @@ -1086,8 +1078,7 @@ let assemblyFormat = [{ $src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref) }]; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", @@ -1125,8 +1116,7 @@ let assemblyFormat = [{ $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB) `->` type($res) }]; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix", diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -351,9 +351,7 @@ constexpr static int kDynamicIndex = std::numeric_limits::min(); }]; let hasFolder = 1; - let verifier = [{ - return ::verify(*this); - }]; + let hasVerifier = 1; } def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes { @@ -386,7 +384,7 @@ CArg<"bool", "false">:$isNonTemporal)>]; let parser = [{ return parseLoadOp(parser, result); }]; let printer = [{ printLoadOp(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes { @@ -410,7 +408,7 @@ ]; let parser = [{ return parseStoreOp(parser, result); }]; let printer = [{ printStoreOp(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } // Casts. 
@@ -494,18 +492,18 @@ build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps, unwindOps, normal, unwind); }]>]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseInvokeOp(parser, result); }]; let printer = [{ printInvokeOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_LandingpadOp : LLVM_Op<"landingpad"> { let arguments = (ins UnitAttr:$cleanup, Variadic); let results = (outs LLVM_Type:$res); let builders = [LLVM_OneResultOpBuilder]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseLandingpadOp(parser, result); }]; let printer = [{ printLandingpadOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_CallOp : LLVM_Op<"call", @@ -562,9 +560,9 @@ build($_builder, $_state, results, StringAttr::get($_builder.getContext(), callee), operands); }]>]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseCallOp(parser, result); }]; let printer = [{ printCallOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position); @@ -575,9 +573,9 @@ let builders = [ OpBuilder<(ins "Value":$vector, "Value":$position, CArg<"ArrayRef", "{}">:$attrs)>]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseExtractElementOp(parser, result); }]; let printer = [{ printExtractElementOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> { let arguments = (ins LLVM_AnyAggregate:$container, ArrayAttr:$position); @@ -586,10 +584,10 @@ $res = builder.CreateExtractValue($container, extractPosition($position)); }]; let builders = [LLVM_OneResultOpBuilder]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseExtractValueOp(parser, result); }]; let printer = [{ printExtractValueOp(p, *this); }]; let hasFolder = 1; + let hasVerifier = 1; } def LLVM_InsertElementOp : 
LLVM_Op<"insertelement", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value, @@ -599,9 +597,9 @@ $res = builder.CreateInsertElement($vector, $value, $position); }]; let builders = [LLVM_OneResultOpBuilder]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseInsertElementOp(parser, result); }]; let printer = [{ printInsertElementOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> { let arguments = (ins LLVM_AnyAggregate:$container, LLVM_PrimitiveType:$value, @@ -616,9 +614,9 @@ [{ build($_builder, $_state, container.getType(), container, value, position); }]>]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseInsertValueOp(parser, result); }]; let printer = [{ printInsertValueOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$v1, LLVM_AnyVector:$v2, ArrayAttr:$mask); @@ -631,16 +629,9 @@ let builders = [ OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask, CArg<"ArrayRef", "{}">:$attrs)>]; - let verifier = [{ - auto type1 = getV1().getType(); - auto type2 = getV2().getType(); - if (::mlir::LLVM::getVectorElementType(type1) != - ::mlir::LLVM::getVectorElementType(type2)) - return emitOpError("expected matching LLVM IR Dialect element types"); - return success(); - }]; let parser = [{ return parseShuffleVectorOp(parser, result); }]; let printer = [{ printShuffleVectorOp(p, *this); }]; + let hasVerifier = 1; } // Misc operations. 
@@ -718,27 +709,15 @@ builder.CreateRetVoid(); }]; - let verifier = [{ - if (getNumOperands() > 1) - return emitOpError("expects at most 1 operand"); - return success(); - }]; - let parser = [{ return parseReturnOp(parser, result); }]; let printer = [{ printReturnOp(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def LLVM_ResumeOp : LLVM_TerminatorOp<"resume", []> { let arguments = (ins LLVM_Type:$value); string llvmBuilder = [{ builder.CreateResume($value); }]; - let verifier = [{ - if (!isa_and_nonnull(getValue().getDefiningOp())) - return emitOpError("expects landingpad value as operand"); - // No check for personality of function - landingpad op verifies it. - return success(); - }]; - let assemblyFormat = "$value attr-dict `:` type($value)"; + let hasVerifier = 1; } def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> { string llvmBuilder = [{ builder.CreateUnreachable(); }]; @@ -761,7 +740,6 @@ VariadicSuccessor:$caseDestinations ); - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ $value `:` type($value) `,` $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? 
@@ -769,6 +747,7 @@ $caseOperands, type($caseOperands)) `]` attr-dict }]; + let hasVerifier = 1; let builders = [ OpBuilder<(ins "Value":$value, @@ -924,7 +903,7 @@ }]; let assemblyFormat = "$global_name attr-dict `:` type($res)"; - let verifier = "return ::verify(*this);"; + let hasVerifier = 1; } def LLVM_MetadataOp : LLVM_Op<"metadata", [ @@ -1175,7 +1154,7 @@ let printer = "printGlobalOp(p, *this);"; let parser = "return parseGlobalOp(parser, result);"; - let verifier = "return ::verify(*this);"; + let hasVerifier = 1; } def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [ @@ -1205,8 +1184,8 @@ ``` }]; - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = "attr-dict"; + let hasVerifier = 1; } def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [ @@ -1234,8 +1213,8 @@ ``` }]; - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = "attr-dict"; + let hasVerifier = 1; } def LLVM_LLVMFuncOp : LLVM_Op<"func", [ @@ -1310,9 +1289,9 @@ LogicalResult verifyType(); }]; - let verifier = [{ return ::verify(*this); }]; let printer = [{ printLLVMFuncOp(p, *this); }]; let parser = [{ return parseLLVMFuncOp(parser, result); }]; + let hasVerifier = 1; } def LLVM_NullOp @@ -1402,8 +1381,8 @@ let results = (outs LLVM_Type:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)"; - let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasVerifier = 1; } // Operations that correspond to LLVM intrinsics. 
With MLIR operation set being @@ -1848,7 +1827,7 @@ }]; let parser = [{ return parseAtomicRMWOp(parser, result); }]; let printer = [{ printAtomicRMWOp(p, *this); }]; - let verifier = "return ::verify(*this);"; + let hasVerifier = 1; } def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnyInteger, LLVM_AnyPointer]>; @@ -1878,7 +1857,7 @@ }]; let parser = [{ return parseAtomicCmpXchgOp(parser, result); }]; let printer = [{ printAtomicCmpXchgOp(p, *this); }]; - let verifier = "return ::verify(*this);"; + let hasVerifier = 1; } def LLVM_AssumeOp : LLVM_Op<"intr.assume", []> { @@ -1901,7 +1880,7 @@ }]; let parser = [{ return parseFenceOp(parser, result); }]; let printer = [{ printFenceOp(p, *this); }]; - let verifier = "return ::verify(*this);"; + let hasVerifier = 1; } def AsmATT : LLVM_EnumAttrCase< diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -22,12 +22,16 @@ #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc" +namespace mlir { +namespace NVVM { /// Return the element type and number of elements associated with a wmma matrix /// of given chracteristics. This matches the logic in IntrinsicsNVVM.td /// WMMA_REGS structure. std::pair inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, mlir::MLIRContext *context); +} // namespace NVVM +} // namespace mlir ///// Ops ///// #define GET_ATTRDEF_CLASSES diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -131,22 +131,11 @@ $res = createIntrinsicCall(builder, intId, {$dst, $val, $offset, $mask_and_clamp}); }]; - let verifier = [{ - if (!(*this)->getAttrOfType("return_value_and_is_valid")) - return success(); - auto type = getType().dyn_cast(); - auto elementType = (type && type.getBody().size() == 2) - ? 
type.getBody()[1].dyn_cast() - : nullptr; - if (!elementType || elementType.getWidth() != 1) - return emitError("expected return type to be a two-element struct with " - "i1 as the second element"); - return success(); - }]; let assemblyFormat = [{ $kind $dst `,` $val `,` $offset `,` $mask_and_clamp attr-dict `:` type($val) `->` type($res) }]; + let hasVerifier = 1; } def NVVM_VoteBallotOp : @@ -183,12 +172,8 @@ } createIntrinsicCall(builder, id, {$dst, $src}); }]; - let verifier = [{ - if (size() != 4 && size() != 8 && size() != 16) - return emitError("expected byte size to be either 4, 8 or 16."); - return success(); - }]; let assemblyFormat = "$dst `,` $src `,` $size attr-dict"; + let hasVerifier = 1; } def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> { @@ -220,7 +205,7 @@ builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args); }]; let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)"; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } /// Helpers to instantiate different version of wmma intrinsics. @@ -538,7 +523,7 @@ }]; let assemblyFormat = "$ptr `,` $stride attr-dict `:` functional-type($ptr, $res)"; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">, @@ -593,7 +578,7 @@ }]; let assemblyFormat = "$ptr `,` $stride `,` $args attr-dict `:` type($ptr) `,` type($args)"; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } // Base class for all the variants of WMMA mmaOps that may be defined. 
@@ -647,7 +632,7 @@ }]; let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)"; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } #endif // NVVMIR_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -25,12 +25,10 @@ Op { // For every linalg op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -54,8 +52,6 @@ `:` type($result) }]; - let verifier = [{ return ::verify(*this); }]; - let extraClassDeclaration = [{ static StringRef getStaticSizesAttrName() { return "static_sizes"; @@ -127,6 +123,7 @@ ]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, @@ -144,6 +141,7 @@ ``` }]; let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; + let hasVerifier = 1; } def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [ @@ -426,6 +424,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def Linalg_IndexOp : Linalg_Op<"index", [NoSideEffect]>, @@ -469,6 +468,7 @@ }]; let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; + let hasVerifier = 1; } #endif // LINALG_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -101,10 +101,9 @@ OpBuilder<(ins "Value":$value, "Value":$output)> ]; - let verifier = [{ 
return ::verify(*this); }]; - let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -264,10 +263,9 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseGenericOp(parser, result); }]; - let verifier = [{ return ::verify(*this); }]; - let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -28,7 +28,6 @@ class MemRef_Op traits = []> : Op { let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -93,6 +92,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -115,6 +115,7 @@ let results = (outs); let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -162,6 +163,7 @@ memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> ``` }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -205,6 +207,7 @@ an alignment on any convenient boundary compatible with the type will be chosen. 
}]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -253,6 +256,7 @@ let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$bodyRegion); + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -279,11 +283,7 @@ let arguments = (ins Variadic:$results); let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>]; - let assemblyFormat = - [{ attr-dict ($results^ `:` type($results))? }]; - - // No custom verification needed. - let verifier = ?; + let assemblyFormat = "attr-dict ($results^ `:` type($results))?"; } //===----------------------------------------------------------------------===// @@ -355,7 +355,6 @@ let arguments = (ins AnyRankedOrUnrankedMemRef:$source); let results = (outs AnyRankedOrUnrankedMemRef:$dest); let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; - let verifier = "return impl::verifyCastOp(*this, areCastCompatible);"; let builders = [ OpBuilder<(ins "Value":$source, "Type":$destType), [{ impl::buildCastOp($_builder, $_state, source, destType); @@ -370,6 +369,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -408,7 +408,6 @@ let hasCanonicalizer = 1; let hasFolder = 1; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -434,7 +433,6 @@ let arguments = (ins Arg:$memref); let hasFolder = 1; - let verifier = ?; let assemblyFormat = "$memref attr-dict `:` type($memref)"; } @@ -488,6 +486,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -646,6 +645,7 @@ } }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -697,6 +697,7 @@ Value getNumElements() { return 
numElements(); } }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -757,6 +758,7 @@ return memref().getType().cast(); } }]; + let hasVerifier = 1; } def AtomicYieldOp : MemRef_Op<"atomic_yield", [ @@ -772,6 +774,7 @@ let arguments = (ins AnyType:$result); let assemblyFormat = "$result attr-dict `:` type($result)"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -797,9 +800,6 @@ let arguments = (ins FlatSymbolRefAttr:$name); let results = (outs AnyStaticShapeMemRef:$result); let assemblyFormat = "$name `:` type($result) attr-dict"; - - // `GetGlobalOp` is fully verified by its traits. - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -866,6 +866,7 @@ return !isExternal() && initial_value().getValue().isa(); } }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -939,6 +940,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)"; } @@ -982,6 +984,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1034,6 +1037,7 @@ let parser = ?; let printer = ?; + let hasVerifier = 1; let builders = [ // Build a ReinterpretCastOp with mixed static and dynamic entries. 
@@ -1096,7 +1100,6 @@ let arguments = (ins AnyRankedOrUnrankedMemRef:$memref); let results = (outs Index); - let verifier = ?; let hasFolder = 1; let assemblyFormat = "$memref attr-dict `:` type($memref)"; } @@ -1161,6 +1164,7 @@ let assemblyFormat = [{ $source `(` $shape `)` attr-dict `:` functional-type(operands, results) }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1226,6 +1230,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; } @@ -1265,6 +1270,7 @@ ``` }]; let extraClassDeclaration = commonExtraClassDeclaration; + let hasVerifier = 1; } def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> { @@ -1302,6 +1308,7 @@ ``` }]; let extraClassDeclaration = commonExtraClassDeclaration; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1369,6 +1376,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; let assemblyFormat = [{ $value `,` $memref `[` $indices `]` attr-dict `:` type($memref) @@ -1617,6 +1625,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1645,8 +1654,6 @@ let arguments = (ins AnyTensor:$tensor, Arg:$memref); - // TensorStoreOp is fully verified by traits. 
- let verifier = ?; let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)"; } @@ -1681,6 +1688,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1749,6 +1757,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1796,6 +1805,7 @@ } }]; let hasFolder = 1; + let hasVerifier = 1; } #endif // MEMREF_OPS diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -37,7 +37,6 @@ Op { let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -153,8 +152,6 @@ /// The i-th data operand passed. Value getDataOperand(unsigned i); }]; - - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -225,6 +222,7 @@ ( `attach` `(` $attachOperands^ `:` type($attachOperands) `)` )? $region attr-dict-with-keyword }]; + let hasVerifier = 1; } def OpenACC_TerminatorOp : OpenACC_Op<"terminator", [Terminator]> { @@ -237,8 +235,6 @@ to the enclosing op. 
}]; - let verifier = ?; - let assemblyFormat = "attr-dict"; } @@ -292,6 +288,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -342,6 +339,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -406,8 +404,7 @@ static StringRef getPrivateKeyword() { return "private"; } static StringRef getReductionKeyword() { return "reduction"; } }]; - - let verifier = [{ return ::verifyLoopOp(*this); }]; + let hasVerifier = 1; } // Yield operation for the acc.loop and acc.parallel operations. @@ -425,8 +422,6 @@ let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - let verifier = ?; - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } @@ -458,6 +453,7 @@ ( `device_num` `(` $deviceNumOperand^ `:` type($deviceNumOperand) `)` )? ( `if` `(` $ifCond^ `)` )? attr-dict-with-keyword }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -488,6 +484,7 @@ ( `device_num` `(` $deviceNumOperand^ `:` type($deviceNumOperand) `)` )? ( `if` `(` $ifCond^ `)` )? attr-dict-with-keyword }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -542,6 +539,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -575,6 +573,7 @@ ( `wait_devnum` `(` $waitDevnum^ `:` type($waitDevnum) `)` )? ( `if` `(` $ifCond^ `)` )? 
attr-dict-with-keyword }]; + let hasVerifier = 1; } #endif // OPENACC_OPS diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -128,7 +128,7 @@ ]; let parser = [{ return parseParallelOp(parser, result); }]; let printer = [{ return printParallelOp(p, *this); }]; - let verifier = [{ return ::verifyParallelOp(*this); }]; + let hasVerifier = 1; } def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> { @@ -217,7 +217,7 @@ let parser = [{ return parseSectionsOp(parser, result); }]; let printer = [{ return printSectionsOp(p, *this); }]; - let verifier = [{ return verifySectionsOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -336,7 +336,7 @@ }]; let parser = [{ return parseWsLoopOp(parser, result); }]; let printer = [{ return printWsLoopOp(p, *this); }]; - let verifier = [{ return ::verifyWsLoopOp(*this); }]; + let hasVerifier = 1; } def YieldOp : OpenMP_Op<"yield", @@ -457,8 +457,7 @@ let assemblyFormat = [{ $sym_name custom($hint) attr-dict }]; - - let verifier = "return verifyCriticalDeclareOp(*this);"; + let hasVerifier = 1; } @@ -476,8 +475,7 @@ let assemblyFormat = [{ (`(` $name^ `)`)? $region attr-dict }]; - - let verifier = "return ::verifyCriticalOp(*this);"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -540,8 +538,7 @@ ( `depend_vec` `(` $depend_vec_vars^ `:` type($depend_vec_vars) `)` )? attr-dict }]; - - let verifier = "return ::verifyOrderedOp(*this);"; + let hasVerifier = 1; } def OrderedRegionOp : OpenMP_Op<"ordered_region"> { @@ -561,8 +558,7 @@ let regions = (region AnyRegion:$region); let assemblyFormat = [{ ( `simd` $simd^ )? 
$region attr-dict}]; - - let verifier = "return ::verifyOrderedRegionOp(*this);"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -614,7 +610,7 @@ OptionalAttr:$memory_order); let parser = [{ return parseAtomicReadOp(parser, result); }]; let printer = [{ return printAtomicReadOp(p, *this); }]; - let verifier = [{ return verifyAtomicReadOp(*this); }]; + let hasVerifier = 1; } def AtomicWriteOp : OpenMP_Op<"atomic.write"> { @@ -643,7 +639,7 @@ OptionalAttr:$memory_order); let parser = [{ return parseAtomicWriteOp(parser, result); }]; let printer = [{ return printAtomicWriteOp(p, *this); }]; - let verifier = [{ return verifyAtomicWriteOp(*this); }]; + let hasVerifier = 1; } // TODO: autogenerate from OMP.td in future if possible. @@ -708,7 +704,7 @@ OptionalAttr:$memory_order); let parser = [{ return parseAtomicUpdateOp(parser, result); }]; let printer = [{ return printAtomicUpdateOp(p, *this); }]; - let verifier = [{ return verifyAtomicUpdateOp(*this); }]; + let hasVerifier = 1; } def AtomicCaptureOp : OpenMP_Op<"atomic.capture", @@ -752,7 +748,7 @@ let regions = (region SizedRegion<1>:$region); let parser = [{ return parseAtomicCaptureOp(parser, result); }]; let printer = [{ return printAtomicCaptureOp(p, *this); }]; - let verifier = [{ return verifyAtomicCaptureOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -789,7 +785,6 @@ let regions = (region AnyRegion:$initializerRegion, AnyRegion:$reductionRegion, AnyRegion:$atomicReductionRegion); - let verifier = "return ::verifyReductionDeclareOp(*this);"; let assemblyFormat = "$sym_name `:` $type attr-dict-with-keyword " "`init` $initializerRegion " @@ -804,6 +799,7 @@ return atomicReductionRegion().front().getArgument(0).getType(); } }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -826,7 +822,7 @@ let arguments= (ins 
AnyType:$operand, OpenMP_PointerLikeType:$accumulator); let assemblyFormat = "$operand `,` $accumulator attr-dict `:` type($accumulator)"; - let verifier = "return ::verifyReductionOp(*this);"; + let hasVerifier = 1; } #endif // OPENMP_OPS diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -26,7 +26,6 @@ : Op { let printer = [{ ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = [{ return ::verify(*this); }]; } //===----------------------------------------------------------------------===// @@ -66,6 +65,7 @@ params.empty() ? ArrayAttr() : $_builder.getArrayAttr(params)); }]>, ]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -118,6 +118,7 @@ $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -164,6 +165,7 @@ build($_builder, $_state, $_builder.getType(), Value(), attr); }]>, ]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -185,7 +187,6 @@ }]; let arguments = (ins PDL_Operation:$operation); let assemblyFormat = "$operation attr-dict"; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -224,6 +225,7 @@ build($_builder, $_state, $_builder.getType(), Value()); }]>, ]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -263,6 +265,7 @@ Value()); }]>, ]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -396,6 +399,7 @@ /// inference. 
bool hasTypeInference(); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -452,6 +456,7 @@ /// Returns the rewrite operation of this pattern. RewriteOp getRewriter(); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -492,6 +497,7 @@ $operation `with` (`(` $replValues^ `:` type($replValues) `)`)? ($replOperation^)? attr-dict }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -524,7 +530,6 @@ let arguments = (ins PDL_Operation:$parent, I32Attr:$index); let results = (outs PDL_Value:$val); let assemblyFormat = "$index `of` $parent attr-dict"; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -567,6 +572,7 @@ ($index^)? `of` $parent custom(ref($index), type($val)) attr-dict }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -629,6 +635,7 @@ ($body^)? 
attr-dict-with-keyword }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -657,6 +664,7 @@ let arguments = (ins OptionalAttr:$type); let results = (outs PDL_Type:$result); let assemblyFormat = "attr-dict (`:` $type^)?"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -685,6 +693,7 @@ let arguments = (ins OptionalAttr:$types); let results = (outs PDL_RangeOf:$result); let assemblyFormat = "attr-dict (`:` $types^)?"; + let hasVerifier = 1; } #endif // MLIR_DIALECT_PDL_IR_PDLOPS diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -75,19 +75,6 @@ PDLInterp_Op { let successors = (successor AnySuccessor:$defaultDest, VariadicSuccessor:$cases); - - let verifier = [{ - // Verify that the number of case destinations matches the number of case - // values. 
- size_t numDests = cases().size(); - size_t numValues = caseValues().size(); - if (numDests != numValues) { - return emitOpError("expected number of cases to match the number of case " - "values, got ") - << numDests << " but expected " << numValues; - } - return success(); - }]; } //===----------------------------------------------------------------------===// @@ -638,7 +625,7 @@ }]; let parser = [{ return ::parseForEachOp(parser, result); }]; let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1078,6 +1065,7 @@ build($_builder, $_state, attribute, $_builder.getArrayAttr(caseValues), defaultDest, dests); }]>]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1111,6 +1099,7 @@ build($_builder, $_state, operation, $_builder.getI32VectorAttr(counts), defaultDest, dests); }]>]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1148,6 +1137,7 @@ defaultDest, dests); }]>, ]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1181,6 +1171,7 @@ build($_builder, $_state, operation, $_builder.getI32VectorAttr(counts), defaultDest, dests); }]>]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1218,6 +1209,7 @@ let extraClassDeclaration = [{ auto getCaseTypes() { return caseValues().getAsValueRange(); } }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1259,6 +1251,7 @@ let extraClassDeclaration = [{ auto getCaseTypes() { return caseValues().getAsRange(); } }]; + let hasVerifier = 1; } #endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td 
b/mlir/include/mlir/Dialect/Quant/QuantOps.td --- a/mlir/include/mlir/Dialect/Quant/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/QuantOps.td @@ -103,7 +103,7 @@ StrAttr:$logical_kernel); let results = (outs Variadic:$outputs); let regions = (region SizedRegion<1>:$body); - let verifier = [{ return verifyRegionOp(*this); }]; + let hasVerifier = 1; } def quant_ReturnOp : quant_Op<"return", [Terminator]> { @@ -227,43 +227,7 @@ OptionalAttr:$axisStats, OptionalAttr:$axis); let results = (outs quant_RealValueType); - - let verifier = [{ - auto tensorArg = arg().getType().dyn_cast(); - if (!tensorArg) return emitOpError("arg needs to be tensor type."); - - // Verify layerStats attribute. - { - auto layerStatsType = layerStats().getType(); - if (!layerStatsType.getElementType().isa()) { - return emitOpError( - "layerStats must have a floating point element type"); - } - if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { - return emitOpError("layerStats must have shape [2]"); - } - } - // Verify axisStats (optional) attribute. 
- if (axisStats()) { - if (!axis()) return emitOpError("axis must be specified for axisStats"); - - auto shape = tensorArg.getShape(); - auto argSliceSize = std::accumulate(std::next(shape.begin(), - *axis()), shape.end(), 1, std::multiplies()); - - auto axisStatsType = axisStats()->getType(); - if (!axisStatsType.getElementType().isa()) { - return emitOpError("axisStats must have a floating point element type"); - } - if (axisStatsType.getRank() != 2 || - axisStatsType.getDimSize(1) != 2 || - axisStatsType.getDimSize(0) != argSliceSize) { - return emitOpError("axisStats must have shape [N,2] " - "where N = the slice size defined by the axis dim"); - } - } - return success(); - }]; + let hasVerifier = 1; } def quant_CoupledRefOp : quant_Op<"coupled_ref", [SameOperandsAndResultType]> { diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -29,12 +29,10 @@ Op { // For every standard op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -56,9 +54,6 @@ let assemblyFormat = [{ `(` $condition `)` attr-dict ($args^ `:` type($args))? }]; - - // Override the default verifier, everything is checked by traits. 
- let verifier = ?; } //===----------------------------------------------------------------------===// @@ -114,6 +109,7 @@ let hasCanonicalizer = 1; let hasFolder = 0; + let hasVerifier = 1; } def ForOp : SCF_Op<"for", @@ -312,6 +308,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def IfOp : SCF_Op<"if", @@ -404,6 +401,7 @@ }]; let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; } def ParallelOp : SCF_Op<"parallel", @@ -485,6 +483,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> { @@ -533,6 +532,7 @@ let arguments = (ins AnyType:$operand); let regions = (region SizedRegion<1>:$reductionOperator); + let hasVerifier = 1; } def ReduceReturnOp : @@ -551,6 +551,7 @@ let arguments = (ins AnyType:$result); let assemblyFormat = "$result attr-dict `:` type($result)"; + let hasVerifier = 1; } def WhileOp : SCF_Op<"while", @@ -683,6 +684,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator, @@ -706,10 +708,6 @@ let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }]; - - // Override default verifier (defined in SCF_Op), no custom verification - // needed. 
- let verifier = ?; } #endif // MLIR_DIALECT_SCF_SCFOPS diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -48,8 +48,6 @@ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; - let verifier = [{ return verifySizeOrIndexOp(*this); }]; - let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by // InferTypeOpInterface @@ -57,6 +55,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> { @@ -102,7 +101,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Shape_ConstShapeOp : Shape_Op<"const_shape", @@ -184,8 +183,8 @@ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -323,7 +322,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -377,7 +376,7 @@ }]; let hasFolder = 1; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + let hasVerifier = 1; } def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> { @@ -518,8 +517,8 @@ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -545,7 +544,7 @@ let assemblyFormat = 
"$shape attr-dict `:` type($shape) `->` type($result)"; let hasFolder = 1; - let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by // InferTypeOpInterface @@ -595,7 +594,7 @@ let builders = [OpBuilder<(ins "Value":$shape, "ValueRange":$initVals)>]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -614,9 +613,9 @@ let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)"; - let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns when two result types are compatible for this op; method used by @@ -724,8 +723,8 @@ [{ build($_builder, $_state, llvm::None); }]> ]; - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let hasVerifier = 1; } // TODO: Add Ops: if_static, if_ranked @@ -859,8 +858,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; - - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Shape_AssumingOp : Shape_Op<"assuming", [ @@ -882,7 +880,6 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; - let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }]; let extraClassDeclaration = [{ // Inline the region into the region containing the AssumingOp and delete @@ -898,6 +895,7 @@ ]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", @@ -959,7 +957,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> { diff 
--git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -21,7 +21,6 @@ class SparseTensor_Op traits = []> : Op { let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -50,6 +49,7 @@ ``` }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; + let hasVerifier = 1; } def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>, @@ -72,6 +72,7 @@ ``` }]; let assemblyFormat = "`[` $sizes `]` attr-dict `:` type($result)"; + let hasVerifier = 1; } def SparseTensor_ConvertOp : SparseTensor_Op<"convert", @@ -113,6 +114,7 @@ }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; let hasFolder = 1; + let hasVerifier = 1; } def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>, @@ -137,6 +139,7 @@ }]; let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)" " `to` type($result)"; + let hasVerifier = 1; } def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>, @@ -161,6 +164,7 @@ }]; let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)" " `to` type($result)"; + let hasVerifier = 1; } def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>, @@ -183,6 +187,7 @@ ``` }]; let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -217,6 +222,7 @@ }]; let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`" " type($tensor) `,` type($indices) `,` type($value)"; + let hasVerifier = 1; } def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>, @@ -258,6 +264,7 @@ }]; let assemblyFormat = 
"$tensor attr-dict `:` type($tensor) `to` type($values)" " `,` type($filled) `,` type($added) `,` type($count)"; + let hasVerifier = 1; } def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>, @@ -292,6 +299,7 @@ " $added `,` $count attr-dict `:` type($tensor) `,`" " type($indices) `,` type($values) `,` type($filled) `,`" " type($added) `,` type($count)"; + let hasVerifier = 1; } def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>, @@ -324,6 +332,7 @@ ``` }]; let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)"; + let hasVerifier = 1; } def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>, @@ -349,6 +358,7 @@ ``` }]; let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; + let hasVerifier = 1; } def SparseTensor_OutOp : SparseTensor_Op<"out", []>, @@ -369,6 +379,7 @@ ``` }]; let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)"; + let hasVerifier = 1; } #endif // SPARSETENSOR_OPS diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -35,12 +35,10 @@ Op { // For every standard op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -66,10 +64,6 @@ let arguments = (ins I1:$arg, StrAttr:$msg); let assemblyFormat = "$arg `,` $msg attr-dict"; - - // AssertOp is fully verified by its traits. - let verifier = ?; - let hasCanonicalizeMethod = 1; } @@ -107,9 +101,6 @@ $_state.addOperands(destOperands); }]>]; - // BranchOp is fully verified by traits. 
- let verifier = ?; - let extraClassDeclaration = [{ void setDest(Block *block); @@ -189,7 +180,6 @@ let assemblyFormat = [{ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) }]; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -250,7 +240,6 @@ CallInterfaceCallable getCallableForCallee() { return getCallee(); } }]; - let verifier = ?; let hasCanonicalizeMethod = 1; let assemblyFormat = @@ -311,9 +300,6 @@ falseOperands); }]>]; - // CondBranchOp is fully verified by traits. - let verifier = ?; - let extraClassDeclaration = [{ // These are the indices into the dests list. enum { trueIndex = 0, falseIndex = 1 }; @@ -434,6 +420,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -466,6 +453,7 @@ [{ build($_builder, $_state, llvm::None); }]>]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -516,6 +504,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -562,6 +551,7 @@ [{ build($_builder, $_state, aggregateType, element); }]>]; let hasFolder = 1; + let hasVerifier = 1; let assemblyFormat = "$input attr-dict `:` type($aggregate)"; } @@ -652,6 +642,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } #endif // STANDARD_OPS diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -20,7 +20,6 @@ class Tensor_Op traits = []> : Op { let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -59,7 +58,6 @@ 
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; let hasCanonicalizer = 1; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -111,6 +109,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -151,6 +150,7 @@ }]>]; let hasFolder = 1; + let hasVerifier = 1; } @@ -303,6 +303,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -339,9 +340,6 @@ let assemblyFormat = "$elements attr-dict `:` type($result)"; - // This op is fully verified by its traits. - let verifier = ?; - let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins "Type":$resultType, "ValueRange":$elements)>, @@ -394,6 +392,7 @@ ]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -445,6 +444,7 @@ }]>]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -564,6 +564,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -586,7 +587,6 @@ let arguments = (ins AnyTensor:$tensor); let results = (outs Index); - let verifier = ?; let hasFolder = 1; let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; } @@ -650,6 +650,7 @@ let assemblyFormat = [{ $source `(` $shape `)` attr-dict `:` functional-type(operands, results) }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -718,6 +719,7 @@ let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; } @@ -748,6 +750,7 @@ ``` }]; 
let extraClassDeclaration = commonExtraClassDeclaration; + let hasVerifier = 1; } def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { @@ -776,6 +779,7 @@ ``` }]; let extraClassDeclaration = commonExtraClassDeclaration; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -961,6 +965,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } @@ -984,7 +989,6 @@ // Dummy builder to appease code in templated ensureTerminator that // GenerateOp's auto-generated parser calls. let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - let verifier = ?; } #endif // TENSOR_OPS diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -82,8 +82,7 @@ ); let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; - - let verifier = [{ return verifyAveragePoolOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -116,8 +115,7 @@ ); let builders = [Tosa_ConvOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -149,8 +147,7 @@ ); let builders = [Tosa_ConvOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -183,8 +180,7 @@ ); let builders = [Tosa_ConvOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -212,8 +208,7 @@ ); let builders = [Tosa_FCOpQuantInfoBuilder]; - - let verifier = [{ return verifyConvOp(*this); }]; + let hasVerifier = 1; } 
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -29,12 +29,10 @@ Op { // For every vector op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -255,6 +253,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Vector_ReductionOp : @@ -290,6 +289,7 @@ return vector().getType().cast(); } }]; + let hasVerifier = 1; } def Vector_MultiDimReductionOp : @@ -373,6 +373,7 @@ let assemblyFormat = "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)"; let hasFolder = 1; + let hasVerifier = 1; } def Vector_BroadcastOp : @@ -420,6 +421,7 @@ let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)"; let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Vector_ShuffleOp : @@ -475,6 +477,7 @@ return vector().getType().cast(); } }]; + let hasVerifier = 1; } def Vector_ExtractElementOp : @@ -521,6 +524,7 @@ return vector().getType().cast(); } }]; + let hasVerifier = 1; } def Vector_ExtractOp : @@ -555,6 +559,7 @@ }]; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def Vector_ExtractMapOp : @@ -623,6 +628,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } def Vector_FMAOp : @@ -648,8 +654,6 @@ %3 = vector.fma %0, %1, %2: vector<8x16xf32> ``` }]; - // Fully specified by traits. 
- let verifier = ?; let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)"; let builders = [ OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc), @@ -706,7 +710,7 @@ return dest().getType().cast(); } }]; - + let hasVerifier = 1; } def Vector_InsertOp : @@ -749,6 +753,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def Vector_InsertMapOp : @@ -816,6 +821,7 @@ $vector `,` $dest `[` $ids `]` attr-dict `:` type($vector) `into` type($result) }]; + let hasVerifier = 1; } def Vector_InsertStridedSliceOp : @@ -873,6 +879,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; } def Vector_OuterProductOp : @@ -960,6 +967,7 @@ return CombiningKind::ADD; } }]; + let hasVerifier = 1; } // TODO: Add transformation which decomposes ReshapeOp into an optimized @@ -1081,6 +1089,7 @@ $vector `,` `[` $input_shape `]` `,` `[` $output_shape `]` `,` $fixed_vector_sizes attr-dict `:` type($vector) `to` type($result) }]; + let hasVerifier = 1; } def Vector_ExtractStridedSliceOp : @@ -1133,6 +1142,7 @@ }]; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)"; } @@ -1340,6 +1350,7 @@ ]; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def Vector_TransferWriteOp : @@ -1477,6 +1488,7 @@ ]; let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Vector_LoadOp : Vector_Op<"load"> { @@ -1552,6 +1564,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; let assemblyFormat = "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)"; @@ -1628,6 +1641,7 @@ }]; let hasFolder = 1; + let hasVerifier = 1; let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict " "`:` type($base) `,` type($valueToStore)"; @@ -1687,6 +1701,7 @@ "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)"; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def Vector_MaskedStoreOp : @@ -1740,6 
+1755,7 @@ "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)"; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } def Vector_GatherOp : @@ -1805,6 +1821,7 @@ "type($index_vec) `,` type($mask) `,` type($pass_thru) " "`into` type($result)"; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Vector_ScatterOp : @@ -1867,6 +1884,7 @@ "$mask `,` $valueToStore attr-dict `:` type($base) `,` " "type($index_vec) `,` type($mask) `,` type($valueToStore)"; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Vector_ExpandLoadOp : @@ -1925,6 +1943,7 @@ let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` " "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)"; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Vector_CompressStoreOp : @@ -1980,6 +1999,7 @@ "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` " "type($base) `,` type($mask) `,` type($valueToStore)"; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Vector_ShapeCastOp : @@ -2031,6 +2051,7 @@ let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; } def Vector_BitCastOp : @@ -2070,6 +2091,7 @@ }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; let hasFolder = 1; + let hasVerifier = 1; } def Vector_TypeCastOp : @@ -2116,6 +2138,7 @@ let assemblyFormat = [{ $memref attr-dict `:` type($memref) `to` type($result) }]; + let hasVerifier = 1; } def Vector_ConstantMaskOp : @@ -2157,6 +2180,7 @@ static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; } }]; let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)"; + let hasVerifier = 1; } def Vector_CreateMaskOp : @@ -2194,6 +2218,7 @@ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; let assemblyFormat = "$operands attr-dict `:` type(results)"; } @@ -2245,6 +2270,7 @@ }]; let hasCanonicalizer = 1; let hasFolder = 1; 
+ let hasVerifier = 1; } def Vector_PrintOp : @@ -2272,7 +2298,6 @@ newline). ``` }]; - let verifier = ?; let extraClassDeclaration = [{ Type getPrintType() { return source().getType(); @@ -2348,7 +2373,6 @@ lhs.getType().cast().getElementType())); }]>, ]; - let verifier = ?; let assemblyFormat = "$lhs `,` $rhs attr-dict " "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; } @@ -2393,7 +2417,6 @@ : (vector<16xf32>) -> vector<16xf32> ``` }]; - let verifier = ?; let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)"; } @@ -2426,7 +2449,6 @@ }]; let results = (outs Index:$res); let assemblyFormat = "attr-dict"; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -2485,6 +2507,7 @@ let assemblyFormat = "$kind `,` $source `,` $initial_value attr-dict `:` " "type($source) `,` type($initial_value) "; + let hasVerifier = 1; } #endif // VECTOR_OPS diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -76,7 +76,6 @@ with their respective bit set in writemask `k`) to `dst`, and pass through the remaining elements from `src`. }]; - let verifier = [{ return ::verify(*this); }]; let arguments = (ins VectorOfLengthAndType<[16, 8], [I1]>:$k, VectorOfLengthAndType<[16, 8], @@ -88,6 +87,7 @@ [F32, I32, F64, I64]>:$dst); let assemblyFormat = "$k `,` $a (`,` $src^)? 
attr-dict" " `:` type($dst) (`,` type($src)^)?"; + let hasVerifier = 1; } def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [ diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -157,7 +157,7 @@ }]; let parser = [{ return ::parseFuncOp(parser, result); }]; let printer = [{ return ::print(*this, p); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -221,7 +221,7 @@ return "builtin"; } }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; // We need to ensure the block inside the region is properly terminated; // the auto-generated builders do not guarantee that. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2470,12 +2470,10 @@ // For every such op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; code extraBaseClassDeclaration = [{ diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -358,22 +358,19 @@ } // namespace -namespace mlir { - /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. 
-LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) { +LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { NVVM::MMAFrag frag = convertOperand(type.getOperand()); NVVM::MMATypes eltType = getElementType(type); std::pair typeInfo = - inferMMAType(eltType, frag, type.getContext()); + NVVM::inferMMAType(eltType, frag, type.getContext()); return LLVM::LLVMStructType::getLiteral( type.getContext(), SmallVector(typeInfo.second, typeInfo.first)); } -void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns) { +void mlir::populateGpuWMMAToNVVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.insert(converter); } -} // namespace mlir diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -52,53 +52,55 @@ return success(); } -static LogicalResult verify(amx::TileZeroOp op) { - return verifyTileSize(op, op.getVectorType()); +LogicalResult amx::TileZeroOp::verify() { + return verifyTileSize(*this, getVectorType()); } -static LogicalResult verify(amx::TileLoadOp op) { - unsigned rank = op.getMemRefType().getRank(); - if (llvm::size(op.indices()) != rank) - return op.emitOpError("requires ") << rank << " indices"; - return verifyTileSize(op, op.getVectorType()); +LogicalResult amx::TileLoadOp::verify() { + unsigned rank = getMemRefType().getRank(); + if (indices().size() != rank) + return emitOpError("requires ") << rank << " indices"; + return verifyTileSize(*this, getVectorType()); } -static LogicalResult verify(amx::TileStoreOp op) { - unsigned rank = op.getMemRefType().getRank(); - if (llvm::size(op.indices()) != rank) - return op.emitOpError("requires ") << rank << " indices"; - return verifyTileSize(op, op.getVectorType()); +LogicalResult amx::TileStoreOp::verify() { + unsigned rank = getMemRefType().getRank(); + if 
(indices().size() != rank) + return emitOpError("requires ") << rank << " indices"; + return verifyTileSize(*this, getVectorType()); } -static LogicalResult verify(amx::TileMulFOp op) { - VectorType aType = op.getLhsVectorType(); - VectorType bType = op.getRhsVectorType(); - VectorType cType = op.getVectorType(); - if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) || - failed(verifyTileSize(op, cType)) || - failed(verifyMultShape(op, aType, bType, cType, 1))) +LogicalResult amx::TileMulFOp::verify() { + VectorType aType = getLhsVectorType(); + VectorType bType = getRhsVectorType(); + VectorType cType = getVectorType(); + if (failed(verifyTileSize(*this, aType)) || + failed(verifyTileSize(*this, bType)) || + failed(verifyTileSize(*this, cType)) || + failed(verifyMultShape(*this, aType, bType, cType, 1))) return failure(); Type ta = aType.getElementType(); Type tb = bType.getElementType(); Type tc = cType.getElementType(); if (!ta.isBF16() || !tb.isBF16() || !tc.isF32()) - return op.emitOpError("unsupported type combination"); + return emitOpError("unsupported type combination"); return success(); } -static LogicalResult verify(amx::TileMulIOp op) { - VectorType aType = op.getLhsVectorType(); - VectorType bType = op.getRhsVectorType(); - VectorType cType = op.getVectorType(); - if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) || - failed(verifyTileSize(op, cType)) || - failed(verifyMultShape(op, aType, bType, cType, 2))) +LogicalResult amx::TileMulIOp::verify() { + VectorType aType = getLhsVectorType(); + VectorType bType = getRhsVectorType(); + VectorType cType = getVectorType(); + if (failed(verifyTileSize(*this, aType)) || + failed(verifyTileSize(*this, bType)) || + failed(verifyTileSize(*this, cType)) || + failed(verifyMultShape(*this, aType, bType, cType, 2))) return failure(); Type ta = aType.getElementType(); Type tb = bType.getElementType(); Type tc = cType.getElementType(); if (!ta.isInteger(8) || 
!tb.isInteger(8) || !tc.isInteger(32)) - return op.emitOpError("unsupported type combination"); + return emitOpError("unsupported type combination"); return success(); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -524,18 +524,18 @@ p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"map"}); } -static LogicalResult verify(AffineApplyOp op) { +LogicalResult AffineApplyOp::verify() { // Check input and output dimensions match. - auto map = op.map(); + AffineMap affineMap = map(); // Verify that operand count matches affine map dimension and symbol count. - if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols()) - return op.emitOpError( + if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols()) + return emitOpError( "operand count and affine map dimension and symbol count must match"); // Verify that the map only produces one result. - if (map.getNumResults() != 1) - return op.emitOpError("mapping must produce one value"); + if (affineMap.getNumResults() != 1) + return emitOpError("mapping must produce one value"); return success(); } @@ -1306,41 +1306,38 @@ bodyBuilder); } -static LogicalResult verify(AffineForOp op) { +LogicalResult AffineForOp::verify() { // Check that the body defines as single block argument for the induction // variable. - auto *body = op.getBody(); + auto *body = getBody(); if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex()) - return op.emitOpError( - "expected body to have a single index argument for the " - "induction variable"); + return emitOpError("expected body to have a single index argument for the " + "induction variable"); // Verify that the bound operands are valid dimension/symbols. /// Lower bound. 
- if (op.getLowerBoundMap().getNumInputs() > 0) - if (failed( - verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), - op.getLowerBoundMap().getNumDims()))) + if (getLowerBoundMap().getNumInputs() > 0) + if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(), + getLowerBoundMap().getNumDims()))) return failure(); /// Upper bound. - if (op.getUpperBoundMap().getNumInputs() > 0) - if (failed( - verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), - op.getUpperBoundMap().getNumDims()))) + if (getUpperBoundMap().getNumInputs() > 0) + if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(), + getUpperBoundMap().getNumDims()))) return failure(); - unsigned opNumResults = op.getNumResults(); + unsigned opNumResults = getNumResults(); if (opNumResults == 0) return success(); // If ForOp defines values, check that the number and types of the defined // values match ForOp initial iter operands and backedge basic block // arguments. - if (op.getNumIterOperands() != opNumResults) - return op.emitOpError( + if (getNumIterOperands() != opNumResults) + return emitOpError( "mismatch between the number of loop-carried values and results"); - if (op.getNumRegionIterArgs() != opNumResults) - return op.emitOpError( + if (getNumRegionIterArgs() != opNumResults) + return emitOpError( "mismatch between the number of basic block args and results"); return success(); @@ -2063,23 +2060,22 @@ }; } // namespace -static LogicalResult verify(AffineIfOp op) { +LogicalResult AffineIfOp::verify() { // Verify that we have a condition attribute. + // FIXME: This should be specified in the arguments list in ODS. 
auto conditionAttr = - op->getAttrOfType(op.getConditionAttrName()); + (*this)->getAttrOfType(getConditionAttrName()); if (!conditionAttr) - return op.emitOpError( - "requires an integer set attribute named 'condition'"); + return emitOpError("requires an integer set attribute named 'condition'"); // Verify that there are enough operands for the condition. IntegerSet condition = conditionAttr.getValue(); - if (op.getNumOperands() != condition.getNumInputs()) - return op.emitOpError( - "operand count and condition integer set dimension and " - "symbol count must match"); + if (getNumOperands() != condition.getNumInputs()) + return emitOpError("operand count and condition integer set dimension and " + "symbol count must match"); // Verify that the operands are valid dimension/symbols. - if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(), + if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(), condition.getNumDims()))) return failure(); @@ -2325,16 +2321,16 @@ return success(); } -LogicalResult verify(AffineLoadOp op) { - auto memrefType = op.getMemRefType(); - if (op.getType() != memrefType.getElementType()) - return op.emitOpError("result type must match element type of memref"); +LogicalResult AffineLoadOp::verify() { + auto memrefType = getMemRefType(); + if (getType() != memrefType.getElementType()) + return emitOpError("result type must match element type of memref"); if (failed(verifyMemoryOpIndexing( - op.getOperation(), - op->getAttrOfType(op.getMapAttrName()), - op.getMapOperands(), memrefType, - /*numIndexOperands=*/op.getNumOperands() - 1))) + getOperation(), + (*this)->getAttrOfType(getMapAttrName()), + getMapOperands(), memrefType, + /*numIndexOperands=*/getNumOperands() - 1))) return failure(); return success(); @@ -2413,18 +2409,18 @@ p << " : " << op.getMemRefType(); } -LogicalResult verify(AffineStoreOp op) { +LogicalResult AffineStoreOp::verify() { // The value to store must have the same type as memref element type. 
- auto memrefType = op.getMemRefType(); - if (op.getValueToStore().getType() != memrefType.getElementType()) - return op.emitOpError( + auto memrefType = getMemRefType(); + if (getValueToStore().getType() != memrefType.getElementType()) + return emitOpError( "value to store must have the same type as memref element type"); if (failed(verifyMemoryOpIndexing( - op.getOperation(), - op->getAttrOfType(op.getMapAttrName()), - op.getMapOperands(), memrefType, - /*numIndexOperands=*/op.getNumOperands() - 2))) + getOperation(), + (*this)->getAttrOfType(getMapAttrName()), + getMapOperands(), memrefType, + /*numIndexOperands=*/getNumOperands() - 2))) return failure(); return success(); @@ -2672,6 +2668,8 @@ context); } +LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); } + //===----------------------------------------------------------------------===// // AffineMaxOp //===----------------------------------------------------------------------===// @@ -2691,6 +2689,8 @@ context); } +LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); } + //===----------------------------------------------------------------------===// // AffinePrefetchOp //===----------------------------------------------------------------------===// @@ -2764,24 +2764,24 @@ p << " : " << op.getMemRefType(); } -static LogicalResult verify(AffinePrefetchOp op) { - auto mapAttr = op->getAttrOfType(op.getMapAttrName()); +LogicalResult AffinePrefetchOp::verify() { + auto mapAttr = (*this)->getAttrOfType(getMapAttrName()); if (mapAttr) { AffineMap map = mapAttr.getValue(); - if (map.getNumResults() != op.getMemRefType().getRank()) - return op.emitOpError("affine.prefetch affine map num results must equal" - " memref rank"); - if (map.getNumInputs() + 1 != op.getNumOperands()) - return op.emitOpError("too few operands"); + if (map.getNumResults() != getMemRefType().getRank()) + return emitOpError("affine.prefetch affine map num results must equal" + " memref rank"); + if 
(map.getNumInputs() + 1 != getNumOperands()) + return emitOpError("too few operands"); } else { - if (op.getNumOperands() != 1) - return op.emitOpError("too few operands"); + if (getNumOperands() != 1) + return emitOpError("too few operands"); } - Region *scope = getAffineScope(op); - for (auto idx : op.getMapOperands()) { + Region *scope = getAffineScope(*this); + for (auto idx : getMapOperands()) { if (!isValidAffineIndexOperand(idx, scope)) - return op.emitOpError("index must be a dimension or symbol identifier"); + return emitOpError("index must be a dimension or symbol identifier"); } return success(); } @@ -3018,53 +3018,52 @@ stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps)); } -static LogicalResult verify(AffineParallelOp op) { - auto numDims = op.getNumDims(); - if (op.lowerBoundsGroups().getNumElements() != numDims || - op.upperBoundsGroups().getNumElements() != numDims || - op.steps().size() != numDims || - op.getBody()->getNumArguments() != numDims) { - return op.emitOpError() - << "the number of region arguments (" - << op.getBody()->getNumArguments() - << ") and the number of map groups for lower (" - << op.lowerBoundsGroups().getNumElements() << ") and upper bound (" - << op.upperBoundsGroups().getNumElements() - << "), and the number of steps (" << op.steps().size() - << ") must all match"; +LogicalResult AffineParallelOp::verify() { + auto numDims = getNumDims(); + if (lowerBoundsGroups().getNumElements() != numDims || + upperBoundsGroups().getNumElements() != numDims || + steps().size() != numDims || getBody()->getNumArguments() != numDims) { + return emitOpError() << "the number of region arguments (" + << getBody()->getNumArguments() + << ") and the number of map groups for lower (" + << lowerBoundsGroups().getNumElements() + << ") and upper bound (" + << upperBoundsGroups().getNumElements() + << "), and the number of steps (" << steps().size() + << ") must all match"; } unsigned expectedNumLBResults = 0; - for (APInt v : 
op.lowerBoundsGroups()) + for (APInt v : lowerBoundsGroups()) expectedNumLBResults += v.getZExtValue(); - if (expectedNumLBResults != op.lowerBoundsMap().getNumResults()) - return op.emitOpError() << "expected lower bounds map to have " - << expectedNumLBResults << " results"; + if (expectedNumLBResults != lowerBoundsMap().getNumResults()) + return emitOpError() << "expected lower bounds map to have " + << expectedNumLBResults << " results"; unsigned expectedNumUBResults = 0; - for (APInt v : op.upperBoundsGroups()) + for (APInt v : upperBoundsGroups()) expectedNumUBResults += v.getZExtValue(); - if (expectedNumUBResults != op.upperBoundsMap().getNumResults()) - return op.emitOpError() << "expected upper bounds map to have " - << expectedNumUBResults << " results"; + if (expectedNumUBResults != upperBoundsMap().getNumResults()) + return emitOpError() << "expected upper bounds map to have " + << expectedNumUBResults << " results"; - if (op.reductions().size() != op.getNumResults()) - return op.emitOpError("a reduction must be specified for each output"); + if (reductions().size() != getNumResults()) + return emitOpError("a reduction must be specified for each output"); // Verify reduction ops are all valid - for (Attribute attr : op.reductions()) { + for (Attribute attr : reductions()) { auto intAttr = attr.dyn_cast(); if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt())) - return op.emitOpError("invalid reduction attribute"); + return emitOpError("invalid reduction attribute"); } // Verify that the bound operands are valid dimension/symbols. /// Lower bounds. - if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(), - op.lowerBoundsMap().getNumDims()))) + if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(), + lowerBoundsMap().getNumDims()))) return failure(); /// Upper bounds. 
- if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(), - op.upperBoundsMap().getNumDims()))) + if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(), + upperBoundsMap().getNumDims()))) return failure(); return success(); } @@ -3412,20 +3411,19 @@ // AffineYieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(AffineYieldOp op) { - auto *parentOp = op->getParentOp(); +LogicalResult AffineYieldOp::verify() { + auto *parentOp = (*this)->getParentOp(); auto results = parentOp->getResults(); - auto operands = op.getOperands(); + auto operands = getOperands(); if (!isa(parentOp)) - return op.emitOpError() << "only terminates affine.if/for/parallel regions"; - if (parentOp->getNumResults() != op.getNumOperands()) - return op.emitOpError() << "parent of yield must have same number of " - "results as the yield operands"; + return emitOpError() << "only terminates affine.if/for/parallel regions"; + if (parentOp->getNumResults() != getNumOperands()) + return emitOpError() << "parent of yield must have same number of " + "results as the yield operands"; for (auto it : llvm::zip(results, operands)) { if (std::get<0>(it).getType() != std::get<1>(it).getType()) - return op.emitOpError() - << "types mismatch between yield op and its parent"; + return emitOpError() << "types mismatch between yield op and its parent"; } return success(); @@ -3516,17 +3514,16 @@ return success(); } -static LogicalResult verify(AffineVectorLoadOp op) { - MemRefType memrefType = op.getMemRefType(); +LogicalResult AffineVectorLoadOp::verify() { + MemRefType memrefType = getMemRefType(); if (failed(verifyMemoryOpIndexing( - op.getOperation(), - op->getAttrOfType(op.getMapAttrName()), - op.getMapOperands(), memrefType, - /*numIndexOperands=*/op.getNumOperands() - 1))) + getOperation(), + (*this)->getAttrOfType(getMapAttrName()), + getMapOperands(), memrefType, + /*numIndexOperands=*/getNumOperands() - 
1))) return failure(); - if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, - op.getVectorType()))) + if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType()))) return failure(); return success(); @@ -3599,17 +3596,15 @@ p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType(); } -static LogicalResult verify(AffineVectorStoreOp op) { - MemRefType memrefType = op.getMemRefType(); +LogicalResult AffineVectorStoreOp::verify() { + MemRefType memrefType = getMemRefType(); if (failed(verifyMemoryOpIndexing( - op.getOperation(), - op->getAttrOfType(op.getMapAttrName()), - op.getMapOperands(), memrefType, - /*numIndexOperands=*/op.getNumOperands() - 2))) + *this, (*this)->getAttrOfType(getMapAttrName()), + getMapOperands(), memrefType, + /*numIndexOperands=*/getNumOperands() - 2))) return failure(); - if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, - op.getVectorType()))) + if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType()))) return failure(); return success(); diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -107,19 +107,19 @@ /// TODO: disallow arith.constant to return anything other than signless integer /// or float like. -static LogicalResult verify(arith::ConstantOp op) { - auto type = op.getType(); +LogicalResult arith::ConstantOp::verify() { + auto type = getType(); // The value's type must match the return type. - if (op.getValue().getType() != type) { - return op.emitOpError() << "value type " << op.getValue().getType() - << " must match return type: " << type; + if (getValue().getType() != type) { + return emitOpError() << "value type " << getValue().getType() + << " must match return type: " << type; } // Integer values must be signless. 
if (type.isa() && !type.cast().isSignless()) - return op.emitOpError("integer return type must be signless"); + return emitOpError("integer return type must be signless"); // Any float or elements attribute are acceptable. - if (!op.getValue().isa()) { - return op.emitOpError( + if (!getValue().isa()) { + return emitOpError( "value must be an integer, float, or elements attribute"); } return success(); @@ -886,6 +886,10 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::ExtUIOp::verify() { + return verifyExtOp(*this); +} + //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// @@ -912,6 +916,10 @@ patterns.insert(context); } +LogicalResult arith::ExtSIOp::verify() { + return verifyExtOp(*this); +} + //===----------------------------------------------------------------------===// // ExtFOp //===----------------------------------------------------------------------===// @@ -920,6 +928,8 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::ExtFOp::verify() { return verifyExtOp(*this); } + //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// @@ -954,6 +964,10 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::TruncIOp::verify() { + return verifyTruncateOp(*this); +} + //===----------------------------------------------------------------------===// // TruncFOp //===----------------------------------------------------------------------===// @@ -983,6 +997,10 @@ return checkWidthChangeCast(inputs, outputs); } +LogicalResult arith::TruncFOp::verify() { + return verifyTruncateOp(*this); +} + //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// diff --git 
a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -33,17 +33,17 @@ // YieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(YieldOp op) { +LogicalResult YieldOp::verify() { // Get the underlying value types from async values returned from the // parent `async.execute` operation. - auto executeOp = op->getParentOfType(); + auto executeOp = (*this)->getParentOfType(); auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { return result.getType().cast().getValueType(); }); - if (op.getOperandTypes() != types) - return op.emitOpError("operand types do not match the types returned from " - "the parent ExecuteOp"); + if (getOperandTypes() != types) + return emitOpError("operand types do not match the types returned from " + "the parent ExecuteOp"); return success(); } @@ -228,16 +228,16 @@ return success(); } -static LogicalResult verify(ExecuteOp op) { +LogicalResult ExecuteOp::verify() { // Unwrap async.execute value operands types. - auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { + auto unwrappedTypes = llvm::map_range(operands(), [](Value operand) { return operand.getType().cast().getValueType(); }); // Verify that unwrapped argument types matches the body region arguments. 
- if (op.body().getArgumentTypes() != unwrappedTypes) - return op.emitOpError("async body region argument types do not match the " - "execute operation arguments types"); + if (body().getArgumentTypes() != unwrappedTypes) + return emitOpError("async body region argument types do not match the " + "execute operation arguments types"); return success(); } @@ -303,19 +303,19 @@ p << operandType; } -static LogicalResult verify(AwaitOp op) { - Type argType = op.operand().getType(); +LogicalResult AwaitOp::verify() { + Type argType = operand().getType(); // Awaiting on a token does not have any results. - if (argType.isa() && !op.getResultTypes().empty()) - return op.emitOpError("awaiting on a token must have empty result"); + if (argType.isa() && !getResultTypes().empty()) + return emitOpError("awaiting on a token must have empty result"); // Awaiting on a value unwraps the async value type. if (auto value = argType.dyn_cast()) { - if (*op.getResultType() != value.getValueType()) - return op.emitOpError() - << "result type " << *op.getResultType() - << " does not match async value type " << value.getValueType(); + if (*getResultType() != value.getValueType()) + return emitOpError() << "result type " << *getResultType() + << " does not match async value type " + << value.getValueType(); } return success(); diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -38,18 +38,18 @@ return false; } -static LogicalResult verify(ConstantOp op) { - ArrayAttr arrayAttr = op.getValue(); +LogicalResult ConstantOp::verify() { + ArrayAttr arrayAttr = getValue(); if (arrayAttr.size() != 2) { - return op.emitOpError( + return emitOpError( "requires 'value' to be a complex constant, represented as array of " "two values"); } - auto complexEltTy = op.getType().getElementType(); + auto complexEltTy = getType().getElementType(); if (complexEltTy 
!= arrayAttr[0].getType() || complexEltTy != arrayAttr[1].getType()) { - return op.emitOpError() + return emitOpError() << "requires attribute's element types (" << arrayAttr[0].getType() << ", " << arrayAttr[1].getType() << ") to match the element type of the op's return type (" diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -48,16 +48,16 @@ // ApplyOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ApplyOp op) { - StringRef applicableOperator = op.applicableOperator(); +LogicalResult ApplyOp::verify() { + StringRef applicableOperatorStr = applicableOperator(); // Applicable operator must not be empty. - if (applicableOperator.empty()) - return op.emitOpError("applicable operator must not be empty"); + if (applicableOperatorStr.empty()) + return emitOpError("applicable operator must not be empty"); // Only `*` and `&` are supported. - if (applicableOperator != "&" && applicableOperator != "*") - return op.emitOpError("applicable operator is illegal"); + if (applicableOperatorStr != "&" && applicableOperatorStr != "*") + return emitOpError("applicable operator is illegal"); return success(); } @@ -66,32 +66,32 @@ // CallOp //===----------------------------------------------------------------------===// -static LogicalResult verify(emitc::CallOp op) { +LogicalResult emitc::CallOp::verify() { // Callee must not be empty. - if (op.callee().empty()) - return op.emitOpError("callee must not be empty"); + if (callee().empty()) + return emitOpError("callee must not be empty"); - if (Optional argsAttr = op.args()) { + if (Optional argsAttr = args()) { for (Attribute arg : argsAttr.getValue()) { if (arg.getType().isa()) { int64_t index = arg.cast().getInt(); // Args with elements of type index must be in range // [0..operands.size). 
- if ((index < 0) || (index >= static_cast(op.getNumOperands()))) - return op.emitOpError("index argument is out of range"); + if ((index < 0) || (index >= static_cast(getNumOperands()))) + return emitOpError("index argument is out of range"); // Args with elements of type ArrayAttr must have a type. } else if (arg.isa() && arg.getType().isa()) { - return op.emitOpError("array argument has no type"); + return emitOpError("array argument has no type"); } } } - if (Optional templateArgsAttr = op.template_args()) { + if (Optional templateArgsAttr = template_args()) { for (Attribute tArg : templateArgsAttr.getValue()) { if (!tArg.isa() && !tArg.isa() && !tArg.isa() && !tArg.isa()) - return op.emitOpError("template argument has invalid type"); + return emitOpError("template argument has invalid type"); } } @@ -103,12 +103,12 @@ //===----------------------------------------------------------------------===// /// The constant op requires that the attribute's type matches the return type. -static LogicalResult verify(emitc::ConstantOp &op) { - Attribute value = op.value(); - Type type = op.getType(); +LogicalResult emitc::ConstantOp::verify() { + Attribute value = valueAttr(); + Type type = getType(); if (!value.getType().isa() && type != value.getType()) - return op.emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; + return emitOpError() << "requires attribute's type (" << value.getType() + << ") to match op's return type (" << type << ")"; return success(); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -270,46 +270,37 @@ return walkResult.wasInterrupted() ? 
failure() : success(); } -template -static LogicalResult verifyIndexOp(T op) { - auto dimension = op.dimension(); - if (dimension != "x" && dimension != "y" && dimension != "z") - return op.emitError("dimension \"") << dimension << "\" is invalid"; - return success(); -} - -static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) { - if (allReduce.body().empty() != allReduce.op().hasValue()) - return allReduce.emitError( - "expected either an op attribute or a non-empty body"); - if (!allReduce.body().empty()) { - if (allReduce.body().getNumArguments() != 2) - return allReduce.emitError("expected two region arguments"); - for (auto argument : allReduce.body().getArguments()) { - if (argument.getType() != allReduce.getType()) - return allReduce.emitError("incorrect region argument type"); +LogicalResult gpu::AllReduceOp::verify() { + if (body().empty() != op().hasValue()) + return emitError("expected either an op attribute or a non-empty body"); + if (!body().empty()) { + if (body().getNumArguments() != 2) + return emitError("expected two region arguments"); + for (auto argument : body().getArguments()) { + if (argument.getType() != getType()) + return emitError("incorrect region argument type"); } unsigned yieldCount = 0; - for (Block &block : allReduce.body()) { + for (Block &block : body()) { if (auto yield = dyn_cast(block.getTerminator())) { if (yield.getNumOperands() != 1) - return allReduce.emitError("expected one gpu.yield operand"); - if (yield.getOperand(0).getType() != allReduce.getType()) - return allReduce.emitError("incorrect gpu.yield type"); + return emitError("expected one gpu.yield operand"); + if (yield.getOperand(0).getType() != getType()) + return emitError("incorrect gpu.yield type"); ++yieldCount; } } if (yieldCount == 0) - return allReduce.emitError("expected gpu.yield op in region"); + return emitError("expected gpu.yield op in region"); } else { - gpu::AllReduceOperation opName = *allReduce.op(); + gpu::AllReduceOperation opName = 
*op(); if ((opName == gpu::AllReduceOperation::AND || opName == gpu::AllReduceOperation::OR || opName == gpu::AllReduceOperation::XOR) && - !allReduce.getType().isa()) { - return allReduce.emitError() - << '`' << gpu::stringifyAllReduceOperation(opName) << '`' - << " accumulator is only compatible with Integer type"; + !getType().isa()) { + return emitError() + << '`' << gpu::stringifyAllReduceOperation(opName) + << "` accumulator is only compatible with Integer type"; } } return success(); @@ -411,20 +402,20 @@ return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; } -static LogicalResult verify(LaunchOp op) { +LogicalResult LaunchOp::verify() { // Kernel launch takes kNumConfigOperands leading operands for grid/block // sizes and transforms them into kNumConfigRegionAttributes region arguments // for block/thread identifiers and grid/block sizes. - if (!op.body().empty()) { - if (op.body().getNumArguments() != - LaunchOp::kNumConfigOperands + op.getNumOperands() - - (op.dynamicSharedMemorySize() ? 1 : 0)) - return op.emitOpError("unexpected number of region arguments"); + if (!body().empty()) { + if (body().getNumArguments() != LaunchOp::kNumConfigOperands + + getNumOperands() - + (dynamicSharedMemorySize() ? 1 : 0)) + return emitOpError("unexpected number of region arguments"); } // Block terminators without successors are expected to exit the kernel region // and must be `gpu.terminator`. 
- for (Block &block : op.body()) { + for (Block &block : body()) { if (block.empty()) continue; if (block.back().getNumSuccessors() != 0) @@ -434,7 +425,7 @@ .emitError() .append("expected '", gpu::TerminatorOp::getOperationName(), "' or a terminator with successors") - .attachNote(op.getLoc()) + .attachNote(getLoc()) .append("in '", LaunchOp::getOperationName(), "' body region"); } } @@ -650,21 +641,21 @@ return KernelDim3{operands[3], operands[4], operands[5]}; } -static LogicalResult verify(LaunchFuncOp op) { - auto module = op->getParentOfType(); +LogicalResult LaunchFuncOp::verify() { + auto module = (*this)->getParentOfType(); if (!module) - return op.emitOpError("expected to belong to a module"); + return emitOpError("expected to belong to a module"); if (!module->getAttrOfType( GPUDialect::getContainerModuleAttrName())) - return op.emitOpError( - "expected the closest surrounding module to have the '" + - GPUDialect::getContainerModuleAttrName() + "' attribute"); + return emitOpError("expected the closest surrounding module to have the '" + + GPUDialect::getContainerModuleAttrName() + + "' attribute"); - auto kernelAttr = op->getAttrOfType(op.getKernelAttrName()); + auto kernelAttr = (*this)->getAttrOfType(getKernelAttrName()); if (!kernelAttr) - return op.emitOpError("symbol reference attribute '" + - op.getKernelAttrName() + "' must be specified"); + return emitOpError("symbol reference attribute '" + getKernelAttrName() + + "' must be specified"); return success(); } @@ -945,25 +936,25 @@ // ReturnOp //===----------------------------------------------------------------------===// -static LogicalResult verify(gpu::ReturnOp returnOp) { - GPUFuncOp function = returnOp->getParentOfType(); +LogicalResult gpu::ReturnOp::verify() { + GPUFuncOp function = (*this)->getParentOfType(); FunctionType funType = function.getType(); - if (funType.getNumResults() != returnOp.operands().size()) - return returnOp.emitOpError() + if (funType.getNumResults() != 
operands().size()) + return emitOpError() .append("expected ", funType.getNumResults(), " result operands") .attachNote(function.getLoc()) .append("return type declared here"); for (const auto &pair : llvm::enumerate( - llvm::zip(function.getType().getResults(), returnOp.operands()))) { + llvm::zip(function.getType().getResults(), operands()))) { Type type; Value operand; std::tie(type, operand) = pair.value(); if (type != operand.getType()) - return returnOp.emitOpError() << "unexpected type `" << operand.getType() - << "' for operand #" << pair.index(); + return emitOpError() << "unexpected type `" << operand.getType() + << "' for operand #" << pair.index(); } return success(); } @@ -1014,15 +1005,15 @@ // GPUMemcpyOp //===----------------------------------------------------------------------===// -static LogicalResult verify(MemcpyOp op) { - auto srcType = op.src().getType(); - auto dstType = op.dst().getType(); +LogicalResult MemcpyOp::verify() { + auto srcType = src().getType(); + auto dstType = dst().getType(); if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType)) - return op.emitOpError("arguments have incompatible element type"); + return emitOpError("arguments have incompatible element type"); if (failed(verifyCompatibleShape(srcType, dstType))) - return op.emitOpError("arguments have incompatible shape"); + return emitOpError("arguments have incompatible shape"); return success(); } @@ -1056,26 +1047,26 @@ // GPU_SubgroupMmaLoadMatrixOp //===----------------------------------------------------------------------===// -static LogicalResult verify(SubgroupMmaLoadMatrixOp op) { - auto srcType = op.srcMemref().getType(); - auto resType = op.res().getType(); +LogicalResult SubgroupMmaLoadMatrixOp::verify() { + auto srcType = srcMemref().getType(); + auto resType = res().getType(); auto resMatrixType = resType.cast(); auto operand = resMatrixType.getOperand(); auto srcMemrefType = srcType.cast(); auto srcMemSpace = 
srcMemrefType.getMemorySpaceAsInt(); if (!srcMemrefType.getLayout().isIdentity()) - return op.emitError("expected identity layout map for source memref"); + return emitError("expected identity layout map for source memref"); if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace && srcMemSpace != kGlobalMemorySpace) - return op.emitError( + return emitError( "source memorySpace kGenericMemorySpace, kSharedMemorySpace or " "kGlobalMemorySpace only allowed"); if (!operand.equals("AOp") && !operand.equals("BOp") && !operand.equals("COp")) - return op.emitError("only AOp, BOp and COp can be loaded"); + return emitError("only AOp, BOp and COp can be loaded"); return success(); } @@ -1084,23 +1075,22 @@ // GPU_SubgroupMmaStoreMatrixOp //===----------------------------------------------------------------------===// -static LogicalResult verify(SubgroupMmaStoreMatrixOp op) { - auto srcType = op.src().getType(); - auto dstType = op.dstMemref().getType(); +LogicalResult SubgroupMmaStoreMatrixOp::verify() { + auto srcType = src().getType(); + auto dstType = dstMemref().getType(); auto srcMatrixType = srcType.cast(); auto dstMemrefType = dstType.cast(); auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt(); if (!dstMemrefType.getLayout().isIdentity()) - return op.emitError("expected identity layout map for destination memref"); + return emitError("expected identity layout map for destination memref"); if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace && dstMemSpace != kGlobalMemorySpace) - return op.emitError( - "destination memorySpace of kGenericMemorySpace, " - "kGlobalMemorySpace or kSharedMemorySpace only allowed"); + return emitError("destination memorySpace of kGenericMemorySpace, " + "kGlobalMemorySpace or kSharedMemorySpace only allowed"); if (!srcMatrixType.getOperand().equals("COp")) - return op.emitError( + return emitError( "expected the operand matrix being stored to have 'COp' operand type"); return success(); @@ 
-1110,21 +1100,17 @@ // GPU_SubgroupMmaComputeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(SubgroupMmaComputeOp op) { +LogicalResult SubgroupMmaComputeOp::verify() { enum OperandMap { A, B, C }; SmallVector opTypes; - - auto populateOpInfo = [&opTypes, &op]() { - opTypes.push_back(op.opA().getType().cast()); - opTypes.push_back(op.opB().getType().cast()); - opTypes.push_back(op.opC().getType().cast()); - }; - populateOpInfo(); + opTypes.push_back(opA().getType().cast()); + opTypes.push_back(opB().getType().cast()); + opTypes.push_back(opC().getType().cast()); if (!opTypes[A].getOperand().equals("AOp") || !opTypes[B].getOperand().equals("BOp") || !opTypes[C].getOperand().equals("COp")) - return op.emitError("operands must be in the order AOp, BOp, COp"); + return emitError("operands must be in the order AOp, BOp, COp"); ArrayRef aShape, bShape, cShape; aShape = opTypes[A].getShape(); @@ -1133,7 +1119,7 @@ if (aShape[1] != bShape[0] || aShape[0] != cShape[0] || bShape[1] != cShape[1]) - return op.emitError("operand shapes do not satisfy matmul constraints"); + return emitError("operand shapes do not satisfy matmul constraints"); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -334,18 +334,17 @@ p.printNewline(); } -static LogicalResult verify(SwitchOp op) { - if ((!op.getCaseValues() && !op.getCaseDestinations().empty()) || - (op.getCaseValues() && - op.getCaseValues()->size() != - static_cast(op.getCaseDestinations().size()))) - return op.emitOpError("expects number of case values to match number of " - "case destinations"); - if (op.getBranchWeights() && - op.getBranchWeights()->size() != op.getNumSuccessors()) - return op.emitError("expects number of branch weights to match number of " - "successors: ") - << 
op.getBranchWeights()->size() << " vs " << op.getNumSuccessors(); +LogicalResult SwitchOp::verify() { + if ((!getCaseValues() && !getCaseDestinations().empty()) || + (getCaseValues() && + getCaseValues()->size() != + static_cast(getCaseDestinations().size()))) + return emitOpError("expects number of case values to match number of " + "case destinations"); + if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) + return emitError("expects number of branch weights to match number of " + "successors: ") + << getBranchWeights()->size() << " vs " << getNumSuccessors(); return success(); } @@ -518,11 +517,11 @@ }); } -LogicalResult verify(LLVM::GEPOp gepOp) { +LogicalResult LLVM::GEPOp::verify() { SmallVector indices; SmallVector structSizes; - findKnownStructIndices(gepOp.getBase().getType(), indices, &structSizes); - DenseIntElementsAttr structIndices = gepOp.getStructIndices(); + findKnownStructIndices(getBase().getType(), indices, &structSizes); + DenseIntElementsAttr structIndices = getStructIndices(); for (unsigned i : llvm::seq(0, indices.size())) { unsigned index = indices[i]; // GEP may not be indexing as deep as some structs nested in the type. 
@@ -531,11 +530,11 @@ int32_t staticIndex = structIndices.getValues()[index]; if (staticIndex == LLVM::GEPOp::kDynamicIndex) - return gepOp.emitOpError() << "expected index " << index - << " indexing a struct to be constant"; + return emitOpError() << "expected index " << index + << " indexing a struct to be constant"; if (staticIndex < 0 || static_cast(staticIndex) >= structSizes[i]) - return gepOp.emitOpError() - << "index " << index << " indexing a struct is out of bounds"; + return emitOpError() << "index " << index + << " indexing a struct is out of bounds"; } return success(); } @@ -613,9 +612,7 @@ return success(); } -static LogicalResult verify(LoadOp op) { - return verifyMemoryOpMetadata(op.getOperation()); -} +LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, Value addr, unsigned alignment, bool isVolatile, @@ -675,9 +672,7 @@ // Builder, printer and parser for LLVM::StoreOp. //===----------------------------------------------------------------------===// -static LogicalResult verify(StoreOp op) { - return verifyMemoryOpMetadata(op.getOperation()); -} +LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, Value addr, unsigned alignment, bool isVolatile, @@ -739,19 +734,18 @@ : getUnwindDestOperandsMutable(); } -static LogicalResult verify(InvokeOp op) { - if (op.getNumResults() > 1) - return op.emitOpError("must have 0 or 1 result"); +LogicalResult InvokeOp::verify() { + if (getNumResults() > 1) + return emitOpError("must have 0 or 1 result"); - Block *unwindDest = op.getUnwindDest(); + Block *unwindDest = getUnwindDest(); if (unwindDest->empty()) - return op.emitError( - "must have at least one operation in unwind destination"); + return emitError("must have at least one operation in unwind destination"); // In unwind destination, first operation must be LandingpadOp 
if (!isa(unwindDest->front())) - return op.emitError("first operation in unwind destination should be a " - "llvm.landingpad operation"); + return emitError("first operation in unwind destination should be a " + "llvm.landingpad operation"); return success(); } @@ -880,20 +874,20 @@ /// Verifying/Printing/Parsing for LLVM::LandingpadOp. ///===----------------------------------------------------------------------===// -static LogicalResult verify(LandingpadOp op) { +LogicalResult LandingpadOp::verify() { Value value; - if (LLVMFuncOp func = op->getParentOfType()) { + if (LLVMFuncOp func = (*this)->getParentOfType()) { if (!func.getPersonality().hasValue()) - return op.emitError( + return emitError( "llvm.landingpad needs to be in a function with a personality"); } - if (!op.getCleanup() && op.getOperands().empty()) - return op.emitError("landingpad instruction expects at least one clause or " - "cleanup attribute"); + if (!getCleanup() && getOperands().empty()) + return emitError("landingpad instruction expects at least one clause or " + "cleanup attribute"); - for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) { - value = op.getOperand(idx); + for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { + value = getOperand(idx); bool isFilter = value.getType().isa(); if (isFilter) { // FIXME: Verify filter clauses when arrays are appropriately handled @@ -903,8 +897,7 @@ if (auto bcOp = value.getDefiningOp()) { if (auto addrOp = bcOp.getArg().getDefiningOp()) continue; - return op.emitError("constant clauses expected") - .attachNote(bcOp.getLoc()) + return emitError("constant clauses expected").attachNote(bcOp.getLoc()) << "global addresses expected as operand to " "bitcast used in clauses for landingpad"; } @@ -913,7 +906,7 @@ continue; if (value.getDefiningOp()) continue; - return op.emitError("clause #") + return emitError("clause #") << idx << " is not a known constant - null, addressof, bitcast"; } } @@ -970,9 +963,9 @@ // 
Verifying/Printing/parsing for LLVM::CallOp. //===----------------------------------------------------------------------===// -static LogicalResult verify(CallOp &op) { - if (op.getNumResults() > 1) - return op.emitOpError("must have 0 or 1 result"); +LogicalResult CallOp::verify() { + if (getNumResults() > 1) + return emitOpError("must have 0 or 1 result"); // Type for the callee, we'll get it differently depending if it is a direct // or indirect call. @@ -981,75 +974,73 @@ bool isIndirect = false; // If this is an indirect call, the callee attribute is missing. - FlatSymbolRefAttr calleeName = op.getCalleeAttr(); + FlatSymbolRefAttr calleeName = getCalleeAttr(); if (!calleeName) { isIndirect = true; - if (!op.getNumOperands()) - return op.emitOpError( + if (!getNumOperands()) + return emitOpError( "must have either a `callee` attribute or at least an operand"); - auto ptrType = op.getOperand(0).getType().dyn_cast(); + auto ptrType = getOperand(0).getType().dyn_cast(); if (!ptrType) - return op.emitOpError("indirect call expects a pointer as callee: ") + return emitOpError("indirect call expects a pointer as callee: ") << ptrType; fnType = ptrType.getElementType(); } else { Operation *callee = - SymbolTable::lookupNearestSymbolFrom(op, calleeName.getAttr()); + SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); if (!callee) - return op.emitOpError() + return emitOpError() << "'" << calleeName.getValue() << "' does not reference a symbol in the current scope"; auto fn = dyn_cast(callee); if (!fn) - return op.emitOpError() << "'" << calleeName.getValue() - << "' does not reference a valid LLVM function"; + return emitOpError() << "'" << calleeName.getValue() + << "' does not reference a valid LLVM function"; fnType = fn.getType(); } LLVMFunctionType funcType = fnType.dyn_cast(); if (!funcType) - return op.emitOpError("callee does not have a functional type: ") << fnType; + return emitOpError("callee does not have a functional type: ") << fnType; // 
Verify that the operand and result types match the callee. if (!funcType.isVarArg() && - funcType.getNumParams() != (op.getNumOperands() - isIndirect)) - return op.emitOpError() - << "incorrect number of operands (" - << (op.getNumOperands() - isIndirect) - << ") for callee (expecting: " << funcType.getNumParams() << ")"; - - if (funcType.getNumParams() > (op.getNumOperands() - isIndirect)) - return op.emitOpError() << "incorrect number of operands (" - << (op.getNumOperands() - isIndirect) - << ") for varargs callee (expecting at least: " - << funcType.getNumParams() << ")"; + funcType.getNumParams() != (getNumOperands() - isIndirect)) + return emitOpError() << "incorrect number of operands (" + << (getNumOperands() - isIndirect) + << ") for callee (expecting: " + << funcType.getNumParams() << ")"; + + if (funcType.getNumParams() > (getNumOperands() - isIndirect)) + return emitOpError() << "incorrect number of operands (" + << (getNumOperands() - isIndirect) + << ") for varargs callee (expecting at least: " + << funcType.getNumParams() << ")"; for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) - if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i)) - return op.emitOpError() << "operand type mismatch for operand " << i - << ": " << op.getOperand(i + isIndirect).getType() - << " != " << funcType.getParamType(i); + if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) + return emitOpError() << "operand type mismatch for operand " << i << ": " + << getOperand(i + isIndirect).getType() + << " != " << funcType.getParamType(i); - if (op.getNumResults() == 0 && + if (getNumResults() == 0 && !funcType.getReturnType().isa()) - return op.emitOpError() << "expected function call to produce a value"; + return emitOpError() << "expected function call to produce a value"; - if (op.getNumResults() != 0 && + if (getNumResults() != 0 && funcType.getReturnType().isa()) - return op.emitOpError() + return emitOpError() << "calling 
function with void result must not produce values"; - if (op.getNumResults() > 1) - return op.emitOpError() + if (getNumResults() > 1) + return emitOpError() << "expected LLVM function call to produce 0 or 1 result"; - if (op.getNumResults() && - op.getResult(0).getType() != funcType.getReturnType()) - return op.emitOpError() - << "result type mismatch: " << op.getResult(0).getType() - << " != " << funcType.getReturnType(); + if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) + return emitOpError() << "result type mismatch: " << getResult(0).getType() + << " != " << funcType.getReturnType(); return success(); } @@ -1200,17 +1191,17 @@ return success(); } -static LogicalResult verify(ExtractElementOp op) { - Type vectorType = op.getVector().getType(); +LogicalResult ExtractElementOp::verify() { + Type vectorType = getVector().getType(); if (!LLVM::isCompatibleVectorType(vectorType)) - return op->emitOpError("expected LLVM dialect-compatible vector type for " - "operand #1, got") + return emitOpError("expected LLVM dialect-compatible vector type for " + "operand #1, got") << vectorType; Type valueType = LLVM::getVectorElementType(vectorType); - if (valueType != op.getRes().getType()) - return op.emitOpError() << "Type mismatch: extracting from " << vectorType - << " should produce " << valueType - << " but this op returns " << op.getRes().getType(); + if (valueType != getRes().getType()) + return emitOpError() << "Type mismatch: extracting from " << vectorType + << " should produce " << valueType + << " but this op returns " << getRes().getType(); return success(); } @@ -1367,17 +1358,17 @@ return {}; } -static LogicalResult verify(ExtractValueOp op) { - Type valueType = getInsertExtractValueElementType(op.getContainer().getType(), - op.getPositionAttr(), op); +LogicalResult ExtractValueOp::verify() { + Type valueType = getInsertExtractValueElementType(getContainer().getType(), + getPositionAttr(), *this); if (!valueType) return failure(); - 
if (op.getRes().getType() != valueType) - return op.emitOpError() - << "Type mismatch: extracting from " << op.getContainer().getType() - << " should produce " << valueType << " but this op returns " - << op.getRes().getType(); + if (getRes().getType() != valueType) + return emitOpError() << "Type mismatch: extracting from " + << getContainer().getType() << " should produce " + << valueType << " but this op returns " + << getRes().getType(); return success(); } @@ -1423,14 +1414,15 @@ return success(); } -static LogicalResult verify(InsertElementOp op) { - Type valueType = LLVM::getVectorElementType(op.getVector().getType()); - if (valueType != op.getValue().getType()) - return op.emitOpError() - << "Type mismatch: cannot insert " << op.getValue().getType() - << " into " << op.getVector().getType(); +LogicalResult InsertElementOp::verify() { + Type valueType = LLVM::getVectorElementType(getVector().getType()); + if (valueType != getValue().getType()) + return emitOpError() << "Type mismatch: cannot insert " + << getValue().getType() << " into " + << getVector().getType(); return success(); } + //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::InsertValueOp. 
//===----------------------------------------------------------------------===// @@ -1473,16 +1465,16 @@ return success(); } -static LogicalResult verify(InsertValueOp op) { - Type valueType = getInsertExtractValueElementType(op.getContainer().getType(), - op.getPositionAttr(), op); +LogicalResult InsertValueOp::verify() { + Type valueType = getInsertExtractValueElementType(getContainer().getType(), + getPositionAttr(), *this); if (!valueType) return failure(); - if (op.getValue().getType() != valueType) - return op.emitOpError() - << "Type mismatch: cannot insert " << op.getValue().getType() - << " into " << op.getContainer().getType(); + if (getValue().getType() != valueType) + return emitOpError() << "Type mismatch: cannot insert " + << getValue().getType() << " into " + << getContainer().getType(); return success(); } @@ -1519,28 +1511,28 @@ return success(); } -static LogicalResult verify(ReturnOp op) { - if (op->getNumOperands() > 1) - return op->emitOpError("expected at most 1 operand"); +LogicalResult ReturnOp::verify() { + if (getNumOperands() > 1) + return emitOpError("expected at most 1 operand"); - if (auto parent = op->getParentOfType()) { + if (auto parent = (*this)->getParentOfType()) { Type expectedType = parent.getType().getReturnType(); if (expectedType.isa()) { - if (op->getNumOperands() == 0) + if (getNumOperands() == 0) return success(); - InFlightDiagnostic diag = op->emitOpError("expected no operands"); + InFlightDiagnostic diag = emitOpError("expected no operands"); diag.attachNote(parent->getLoc()) << "when returning from function"; return diag; } - if (op->getNumOperands() == 0) { + if (getNumOperands() == 0) { if (expectedType.isa()) return success(); - InFlightDiagnostic diag = op->emitOpError("expected 1 operand"); + InFlightDiagnostic diag = emitOpError("expected 1 operand"); diag.attachNote(parent->getLoc()) << "when returning from function"; return diag; } - if (expectedType != op->getOperand(0).getType()) { - InFlightDiagnostic diag 
= op->emitOpError("mismatching result types"); + if (expectedType != getOperand(0).getType()) { + InFlightDiagnostic diag = emitOpError("mismatching result types"); diag.attachNote(parent->getLoc()) << "when returning from function"; return diag; } @@ -1548,6 +1540,17 @@ return success(); } +//===----------------------------------------------------------------------===// +// ResumeOp +//===----------------------------------------------------------------------===// + +LogicalResult ResumeOp::verify() { + if (!getValue().getDefiningOp()) + return emitOpError("expects landingpad value as operand"); + // No check for personality of function - landingpad op verifies it. + return success(); +} + //===----------------------------------------------------------------------===// // Verifier for LLVM::AddressOfOp. //===----------------------------------------------------------------------===// @@ -1572,22 +1575,22 @@ getGlobalName()); } -static LogicalResult verify(AddressOfOp op) { - auto global = op.getGlobal(); - auto function = op.getFunction(); +LogicalResult AddressOfOp::verify() { + auto global = getGlobal(); + auto function = getFunction(); if (!global && !function) - return op.emitOpError( + return emitOpError( "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); if (global && LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) != - op.getResult().getType()) - return op.emitOpError( + getResult().getType()) + return emitOpError( "the type must be a pointer to the type of the referenced global"); - if (function && LLVM::LLVMPointerType::get(function.getType()) != - op.getResult().getType()) - return op.emitOpError( + if (function && + LLVM::LLVMPointerType::get(function.getType()) != getResult().getType()) + return emitOpError( "the type must be a pointer to the type of the referenced function"); return success(); @@ -1791,60 +1794,60 @@ return false; } -static LogicalResult verify(GlobalOp op) { - if 
(!LLVMPointerType::isValidElementType(op.getType())) - return op.emitOpError( +LogicalResult GlobalOp::verify() { + if (!LLVMPointerType::isValidElementType(getType())) + return emitOpError( "expects type to be a valid element type for an LLVM pointer"); - if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp())) - return op.emitOpError("must appear at the module level"); + if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) + return emitOpError("must appear at the module level"); - if (auto strAttr = op.getValueOrNull().dyn_cast_or_null()) { - auto type = op.getType().dyn_cast(); + if (auto strAttr = getValueOrNull().dyn_cast_or_null()) { + auto type = getType().dyn_cast(); IntegerType elementType = type ? type.getElementType().dyn_cast() : nullptr; if (!elementType || elementType.getWidth() != 8 || type.getNumElements() != strAttr.getValue().size()) - return op.emitOpError( + return emitOpError( "requires an i8 array type of the length equal to that of the string " "attribute"); } - if (Block *b = op.getInitializerBlock()) { + if (Block *b = getInitializerBlock()) { ReturnOp ret = cast(b->getTerminator()); if (ret.operand_type_begin() == ret.operand_type_end()) - return op.emitOpError("initializer region cannot return void"); - if (*ret.operand_type_begin() != op.getType()) - return op.emitOpError("initializer region type ") + return emitOpError("initializer region cannot return void"); + if (*ret.operand_type_begin() != getType()) + return emitOpError("initializer region type ") << *ret.operand_type_begin() << " does not match global type " - << op.getType(); + << getType(); - if (op.getValueOrNull()) - return op.emitOpError("cannot have both initializer value and region"); + if (getValueOrNull()) + return emitOpError("cannot have both initializer value and region"); } - if (op.getLinkage() == Linkage::Common) { - if (Attribute value = op.getValueOrNull()) { + if (getLinkage() == Linkage::Common) { + if (Attribute value = 
getValueOrNull()) { if (!isZeroAttribute(value)) { - return op.emitOpError() + return emitOpError() << "expected zero value for '" << stringifyLinkage(Linkage::Common) << "' linkage"; } } } - if (op.getLinkage() == Linkage::Appending) { - if (!op.getType().isa()) { - return op.emitOpError() - << "expected array type for '" - << stringifyLinkage(Linkage::Appending) << "' linkage"; + if (getLinkage() == Linkage::Appending) { + if (!getType().isa()) { + return emitOpError() << "expected array type for '" + << stringifyLinkage(Linkage::Appending) + << "' linkage"; } } - Optional alignAttr = op.getAlignment(); + Optional alignAttr = getAlignment(); if (alignAttr.hasValue()) { uint64_t value = alignAttr.getValue(); if (!llvm::isPowerOf2_64(value)) - return op->emitError() << "alignment attribute is not a power of 2"; + return emitError() << "alignment attribute is not a power of 2"; } return success(); @@ -1864,9 +1867,9 @@ return success(); } -static LogicalResult verify(GlobalCtorsOp op) { - if (op.getCtors().size() != op.getPriorities().size()) - return op.emitError( +LogicalResult GlobalCtorsOp::verify() { + if (getCtors().size() != getPriorities().size()) + return emitError( "mismatch between the number of ctors and the number of priorities"); return success(); } @@ -1885,9 +1888,9 @@ return success(); } -static LogicalResult verify(GlobalDtorsOp op) { - if (op.getDtors().size() != op.getPriorities().size()) - return op.emitError( +LogicalResult GlobalDtorsOp::verify() { + if (getDtors().size() != getPriorities().size()) + return emitError( "mismatch between the number of dtors and the number of priorities"); return success(); } @@ -1940,6 +1943,14 @@ return success(); } +LogicalResult ShuffleVectorOp::verify() { + Type type1 = getV1().getType(); + Type type2 = getV2().getType(); + if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) + return emitOpError("expected matching LLVM IR Dialect element types"); + return success(); +} + 
//===----------------------------------------------------------------------===// // Implementations for LLVM::LLVMFuncOp. //===----------------------------------------------------------------------===// @@ -2117,42 +2128,43 @@ // - external functions have 'external' or 'extern_weak' linkage; // - vararg is (currently) only supported for external functions; // - entry block arguments are of LLVM types and match the function signature. -static LogicalResult verify(LLVMFuncOp op) { - if (op.getLinkage() == LLVM::Linkage::Common) - return op.emitOpError() - << "functions cannot have '" - << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; +LogicalResult LLVMFuncOp::verify() { + if (getLinkage() == LLVM::Linkage::Common) + return emitOpError() << "functions cannot have '" + << stringifyLinkage(LLVM::Linkage::Common) + << "' linkage"; // Check to see if this function has a void return with a result attribute to // it. It isn't clear what semantics we would assign to that. - if (op.getType().getReturnType().isa() && - !op.getResultAttrs(0).empty()) { - return op.emitOpError() + if (getType().getReturnType().isa() && + !getResultAttrs(0).empty()) { + return emitOpError() << "cannot attach result attributes to functions with a void return"; } - if (op.isExternal()) { - if (op.getLinkage() != LLVM::Linkage::External && - op.getLinkage() != LLVM::Linkage::ExternWeak) - return op.emitOpError() - << "external functions must have '" - << stringifyLinkage(LLVM::Linkage::External) << "' or '" - << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage"; + if (isExternal()) { + if (getLinkage() != LLVM::Linkage::External && + getLinkage() != LLVM::Linkage::ExternWeak) + return emitOpError() << "external functions must have '" + << stringifyLinkage(LLVM::Linkage::External) + << "' or '" + << stringifyLinkage(LLVM::Linkage::ExternWeak) + << "' linkage"; return success(); } - if (op.isVarArg()) - return op.emitOpError("only external functions can be variadic"); + if 
(isVarArg()) + return emitOpError("only external functions can be variadic"); - unsigned numArguments = op.getType().getNumParams(); - Block &entryBlock = op.front(); + unsigned numArguments = getType().getNumParams(); + Block &entryBlock = front(); for (unsigned i = 0; i < numArguments; ++i) { Type argType = entryBlock.getArgument(i).getType(); if (!isCompatibleType(argType)) - return op.emitOpError("entry block argument #") + return emitOpError("entry block argument #") << i << " is not of LLVM type"; - if (op.getType().getParamType(i) != argType) - return op.emitOpError("the type of entry block argument #") + if (getType().getParamType(i) != argType) + return emitOpError("the type of entry block argument #") << i << " does not match the function signature"; } @@ -2163,42 +2175,42 @@ // Verification for LLVM::ConstantOp. //===----------------------------------------------------------------------===// -static LogicalResult verify(LLVM::ConstantOp op) { - if (StringAttr sAttr = op.getValue().dyn_cast()) { - auto arrayType = op.getType().dyn_cast(); +LogicalResult LLVM::ConstantOp::verify() { + if (StringAttr sAttr = getValue().dyn_cast()) { + auto arrayType = getType().dyn_cast(); if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || !arrayType.getElementType().isInteger(8)) { - return op->emitOpError() - << "expected array type of " << sAttr.getValue().size() - << " i8 elements for the string constant"; + return emitOpError() << "expected array type of " + << sAttr.getValue().size() + << " i8 elements for the string constant"; } return success(); } - if (auto structType = op.getType().dyn_cast()) { + if (auto structType = getType().dyn_cast()) { if (structType.getBody().size() != 2 || structType.getBody()[0] != structType.getBody()[1]) { - return op.emitError() << "expected struct type with two elements of the " - "same type, the type of a complex constant"; + return emitError() << "expected struct type with two elements of the " + "same type, 
the type of a complex constant"; } - auto arrayAttr = op.getValue().dyn_cast(); + auto arrayAttr = getValue().dyn_cast(); if (!arrayAttr || arrayAttr.size() != 2 || arrayAttr[0].getType() != arrayAttr[1].getType()) { - return op.emitOpError() << "expected array attribute with two elements, " - "representing a complex constant"; + return emitOpError() << "expected array attribute with two elements, " + "representing a complex constant"; } Type elementType = structType.getBody()[0]; if (!elementType .isa()) { - return op.emitError() + return emitError() << "expected struct element types to be floating point type or " "integer type"; } return success(); } - if (!op.getValue().isa()) - return op.emitOpError() + if (!getValue().isa()) + return emitOpError() << "only supports integer, float, string or elements attributes"; return success(); } @@ -2294,42 +2306,40 @@ return success(); } -static LogicalResult verify(AtomicRMWOp op) { - auto ptrType = op.getPtr().getType().cast(); - auto valType = op.getVal().getType(); +LogicalResult AtomicRMWOp::verify() { + auto ptrType = getPtr().getType().cast(); + auto valType = getVal().getType(); if (valType != ptrType.getElementType()) - return op.emitOpError("expected LLVM IR element type for operand #0 to " - "match type for operand #1"); - auto resType = op.getRes().getType(); + return emitOpError("expected LLVM IR element type for operand #0 to " + "match type for operand #1"); + auto resType = getRes().getType(); if (resType != valType) - return op.emitOpError( + return emitOpError( "expected LLVM IR result type to match type for operand #1"); - if (op.getBinOp() == AtomicBinOp::fadd || - op.getBinOp() == AtomicBinOp::fsub) { + if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) - return op.emitOpError("expected LLVM IR floating point type"); - } else if (op.getBinOp() == AtomicBinOp::xchg) { + return emitOpError("expected LLVM IR floating point 
type"); + } else if (getBinOp() == AtomicBinOp::xchg) { auto intType = valType.dyn_cast(); unsigned intBitWidth = intType ? intType.getWidth() : 0; if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && !valType.isa() && !valType.isa() && !valType.isa() && !valType.isa()) - return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); + return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); } else { auto intType = valType.dyn_cast(); unsigned intBitWidth = intType ? intType.getWidth() : 0; if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64) - return op.emitOpError("expected LLVM IR integer type"); + return emitOpError("expected LLVM IR integer type"); } - if (static_cast(op.getOrdering()) < + if (static_cast(getOrdering()) < static_cast(AtomicOrdering::monotonic)) - return op.emitOpError() - << "expected at least '" - << stringifyAtomicOrdering(AtomicOrdering::monotonic) - << "' ordering"; + return emitOpError() << "expected at least '" + << stringifyAtomicOrdering(AtomicOrdering::monotonic) + << "' ordering"; return success(); } @@ -2375,28 +2385,28 @@ return success(); } -static LogicalResult verify(AtomicCmpXchgOp op) { - auto ptrType = op.getPtr().getType().cast(); +LogicalResult AtomicCmpXchgOp::verify() { + auto ptrType = getPtr().getType().cast(); if (!ptrType) - return op.emitOpError("expected LLVM IR pointer type for operand #0"); - auto cmpType = op.getCmp().getType(); - auto valType = op.getVal().getType(); + return emitOpError("expected LLVM IR pointer type for operand #0"); + auto cmpType = getCmp().getType(); + auto valType = getVal().getType(); if (cmpType != ptrType.getElementType() || cmpType != valType) - return op.emitOpError("expected LLVM IR element type for operand #0 to " - "match type for all other operands"); + return emitOpError("expected LLVM IR element type for operand #0 to " + "match type for all other operands"); auto intType = valType.dyn_cast(); 
unsigned intBitWidth = intType ? intType.getWidth() : 0; if (!valType.isa() && intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && !valType.isa() && !valType.isa() && !valType.isa() && !valType.isa()) - return op.emitOpError("unexpected LLVM IR type"); - if (op.getSuccessOrdering() < AtomicOrdering::monotonic || - op.getFailureOrdering() < AtomicOrdering::monotonic) - return op.emitOpError("ordering must be at least 'monotonic'"); - if (op.getFailureOrdering() == AtomicOrdering::release || - op.getFailureOrdering() == AtomicOrdering::acq_rel) - return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); + return emitOpError("unexpected LLVM IR type"); + if (getSuccessOrdering() < AtomicOrdering::monotonic || + getFailureOrdering() < AtomicOrdering::monotonic) + return emitOpError("ordering must be at least 'monotonic'"); + if (getFailureOrdering() == AtomicOrdering::release || + getFailureOrdering() == AtomicOrdering::acq_rel) + return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); return success(); } @@ -2432,12 +2442,12 @@ p << stringifyAtomicOrdering(op.getOrdering()); } -static LogicalResult verify(FenceOp &op) { - if (op.getOrdering() == AtomicOrdering::not_atomic || - op.getOrdering() == AtomicOrdering::unordered || - op.getOrdering() == AtomicOrdering::monotonic) - return op.emitOpError("can be given only acquire, release, acq_rel, " - "and seq_cst orderings"); +LogicalResult FenceOp::verify() { + if (getOrdering() == AtomicOrdering::not_atomic || + getOrdering() == AtomicOrdering::unordered || + getOrdering() == AtomicOrdering::monotonic) + return emitOpError("can be given only acquire, release, acq_rel, " + "and seq_cst orderings"); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -62,8 +62,14 @@ parser.getNameLoc(), 
result.operands)); } -static LogicalResult verify(MmaOp op) { - MLIRContext *context = op.getContext(); +LogicalResult CpAsyncOp::verify() { + if (size() != 4 && size() != 8 && size() != 16) + return emitError("expected byte size to be either 4, 8 or 16."); + return success(); +} + +LogicalResult MmaOp::verify() { + MLIRContext *context = getContext(); auto f16Ty = Float16Type::get(context); auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2); auto f32Ty = Float32Type::get(context); @@ -72,44 +78,55 @@ auto f32x8StructTy = LLVM::LLVMStructType::getLiteral( context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); - SmallVector operandTypes(op.getOperandTypes().begin(), - op.getOperandTypes().end()); + auto operandTypes = getOperandTypes(); if (operandTypes != SmallVector(8, f16x2Ty) && - operandTypes != SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, - f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty}) { - return op.emitOpError( - "expected operands to be 4 s followed by either " - "4 s or 8 floats"); + operandTypes != ArrayRef{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty, + f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty}) { + return emitOpError("expected operands to be 4 s followed by either " + "4 s or 8 floats"); } - if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) { - return op.emitOpError("expected result type to be a struct of either 4 " - "s or 8 floats"); + if (getType() != f32x8StructTy && getType() != f16x2x4StructTy) { + return emitOpError("expected result type to be a struct of either 4 " + "s or 8 floats"); } - auto alayout = op->getAttrOfType("alayout"); - auto blayout = op->getAttrOfType("blayout"); + auto alayout = (*this)->getAttrOfType("alayout"); + auto blayout = (*this)->getAttrOfType("blayout"); if (!(alayout && blayout) || !(alayout.getValue() == "row" || alayout.getValue() == "col") || !(blayout.getValue() == "row" || blayout.getValue() == "col")) { - return op.emitOpError( - "alayout and blayout 
attributes must be set to either " - "\"row\" or \"col\""); + return emitOpError("alayout and blayout attributes must be set to either " + "\"row\" or \"col\""); } - if (operandTypes == SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, - f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty} && - op.getType() == f32x8StructTy && alayout.getValue() == "row" && + if (operandTypes == ArrayRef{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty, + f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty} && + getType() == f32x8StructTy && alayout.getValue() == "row" && blayout.getValue() == "col") { return success(); } - return op.emitOpError("unimplemented mma.sync variant"); + return emitOpError("unimplemented mma.sync variant"); +} + +LogicalResult ShflOp::verify() { + if (!(*this)->getAttrOfType("return_value_and_is_valid")) + return success(); + auto type = getType().dyn_cast(); + auto elementType = (type && type.getBody().size() == 2) + ? type.getBody()[1].dyn_cast() + : nullptr; + if (!elementType || elementType.getWidth() != 1) + return emitError("expected return type to be a two-element struct with " + "i1 as the second element"); + return success(); } -std::pair -inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) { +std::pair NVVM::inferMMAType(NVVM::MMATypes type, + NVVM::MMAFrag frag, + MLIRContext *context) { unsigned numberElements = 0; Type elementType; OpBuilder builder(context); @@ -131,76 +148,72 @@ return std::make_pair(elementType, numberElements); } -static LogicalResult verify(NVVM::WMMALoadOp op) { +LogicalResult NVVM::WMMALoadOp::verify() { unsigned addressSpace = - op.ptr().getType().cast().getAddressSpace(); + ptr().getType().cast().getAddressSpace(); if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) - return op.emitOpError("expected source pointer in memory " - "space 0, 1, 3"); + return emitOpError("expected source pointer in memory " + "space 0, 1, 3"); - if (NVVM::WMMALoadOp::getIntrinsicID(op.m(), op.n(), op.k(), 
op.layout(), - op.eltype(), op.frag()) == 0) - return op.emitOpError() << "invalid attribute combination"; + if (NVVM::WMMALoadOp::getIntrinsicID(m(), n(), k(), layout(), eltype(), + frag()) == 0) + return emitOpError() << "invalid attribute combination"; std::pair typeInfo = - inferMMAType(op.eltype(), op.frag(), op.getContext()); + inferMMAType(eltype(), frag(), getContext()); Type dstType = LLVM::LLVMStructType::getLiteral( - op.getContext(), SmallVector(typeInfo.second, typeInfo.first)); - if (op.getType() != dstType) - return op.emitOpError("expected destination type is a structure of ") + getContext(), SmallVector(typeInfo.second, typeInfo.first)); + if (getType() != dstType) + return emitOpError("expected destination type is a structure of ") << typeInfo.second << " elements of type " << typeInfo.first; return success(); } -static LogicalResult verify(NVVM::WMMAStoreOp op) { +LogicalResult NVVM::WMMAStoreOp::verify() { unsigned addressSpace = - op.ptr().getType().cast().getAddressSpace(); + ptr().getType().cast().getAddressSpace(); if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) - return op.emitOpError("expected operands to be a source pointer in memory " - "space 0, 1, 3"); + return emitOpError("expected operands to be a source pointer in memory " + "space 0, 1, 3"); - if (NVVM::WMMAStoreOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(), - op.eltype()) == 0) - return op.emitOpError() << "invalid attribute combination"; + if (NVVM::WMMAStoreOp::getIntrinsicID(m(), n(), k(), layout(), eltype()) == 0) + return emitOpError() << "invalid attribute combination"; std::pair typeInfo = - inferMMAType(op.eltype(), NVVM::MMAFrag::c, op.getContext()); - if (op.args().size() != typeInfo.second) - return op.emitOpError() - << "expected " << typeInfo.second << " data operands"; - if (llvm::any_of(op.args(), [&typeInfo](Value operands) { + inferMMAType(eltype(), NVVM::MMAFrag::c, getContext()); + if (args().size() != typeInfo.second) + return 
emitOpError() << "expected " << typeInfo.second << " data operands"; + if (llvm::any_of(args(), [&typeInfo](Value operands) { return operands.getType() != typeInfo.first; })) - return op.emitOpError() - << "expected data operands of type " << typeInfo.first; + return emitOpError() << "expected data operands of type " << typeInfo.first; return success(); } -static LogicalResult verify(NVVM::WMMAMmaOp op) { - if (NVVM::WMMAMmaOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layoutA(), - op.layoutB(), op.eltypeA(), - op.eltypeB()) == 0) - return op.emitOpError() << "invalid attribute combination"; +LogicalResult NVVM::WMMAMmaOp::verify() { + if (NVVM::WMMAMmaOp::getIntrinsicID(m(), n(), k(), layoutA(), layoutB(), + eltypeA(), eltypeB()) == 0) + return emitOpError() << "invalid attribute combination"; std::pair typeInfoA = - inferMMAType(op.eltypeA(), NVVM::MMAFrag::a, op.getContext()); + inferMMAType(eltypeA(), NVVM::MMAFrag::a, getContext()); std::pair typeInfoB = - inferMMAType(op.eltypeA(), NVVM::MMAFrag::b, op.getContext()); + inferMMAType(eltypeA(), NVVM::MMAFrag::b, getContext()); std::pair typeInfoC = - inferMMAType(op.eltypeB(), NVVM::MMAFrag::c, op.getContext()); + inferMMAType(eltypeB(), NVVM::MMAFrag::c, getContext()); SmallVector arguments; arguments.append(typeInfoA.second, typeInfoA.first); arguments.append(typeInfoB.second, typeInfoB.first); arguments.append(typeInfoC.second, typeInfoC.first); unsigned numArgs = arguments.size(); - if (op.args().size() != numArgs) - return op.emitOpError() << "expected " << numArgs << " arguments"; + if (args().size() != numArgs) + return emitOpError() << "expected " << numArgs << " arguments"; for (unsigned i = 0; i < numArgs; i++) { - if (op.args()[i].getType() != arguments[i]) - return op.emitOpError() - << "expected argument " << i << " to be of type " << arguments[i]; + if (args()[i].getType() != arguments[i]) + return emitOpError() << "expected argument " << i << " to be of type " + << arguments[i]; } Type dstType = 
LLVM::LLVMStructType::getLiteral( - op.getContext(), SmallVector(typeInfoC.second, typeInfoC.first)); - if (op.getType() != dstType) - return op.emitOpError("expected destination type is a structure of ") + getContext(), SmallVector(typeInfoC.second, typeInfoC.first)); + if (getType() != dstType) + return emitOpError("expected destination type is a structure of ") << typeInfoC.second << " elements of type " << typeInfoC.first; return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -400,11 +400,11 @@ /// FillOp region is elided when printing. void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} -static LogicalResult verify(FillOp op) { - OpOperand *output = op.getOutputOperand(0); - Type fillType = op.value().getType(); +LogicalResult FillOp::verify() { + OpOperand *output = getOutputOperand(0); + Type fillType = value().getType(); if (getElementTypeOrSelf(output->get()) != fillType) - return op.emitOpError("expects fill type to match view elemental type"); + return emitOpError("expects fill type to match view elemental type"); return success(); } @@ -635,7 +635,7 @@ return success(); } -static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); } +LogicalResult GenericOp::verify() { return verifyGenericOp(*this); } namespace { // Deduplicate redundant args of a linalg generic op. 
@@ -811,25 +811,24 @@ result.addAttributes(attrs); } -static LogicalResult verify(InitTensorOp op) { - RankedTensorType resultType = op.getType(); +LogicalResult InitTensorOp::verify() { + RankedTensorType resultType = getType(); SmallVector staticSizes = llvm::to_vector<4>(llvm::map_range( - op.static_sizes().cast(), + static_sizes().cast(), [](Attribute a) -> int64_t { return a.cast().getInt(); })); - if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(), - op.static_sizes(), op.sizes(), - ShapedType::isDynamic))) + if (failed(verifyListOfOperandsOrIntegers( + *this, "sizes", resultType.getRank(), static_sizes(), sizes(), + ShapedType::isDynamic))) return failure(); - if (op.static_sizes().size() != static_cast(resultType.getRank())) - return op->emitError("expected ") - << resultType.getRank() << " sizes values"; + if (static_sizes().size() != static_cast(resultType.getRank())) + return emitError("expected ") << resultType.getRank() << " sizes values"; Type expectedType = InitTensorOp::inferResultType( staticSizes, resultType.getElementType(), resultType.getEncoding()); if (resultType != expectedType) { - return op.emitError("specified type ") + return emitError("specified type ") << resultType << " does not match the inferred type " << expectedType; } @@ -1030,13 +1029,13 @@ return success(); } -static LogicalResult verify(linalg::YieldOp op) { - auto *parentOp = op->getParentOp(); +LogicalResult linalg::YieldOp::verify() { + auto *parentOp = (*this)->getParentOp(); if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) - return op.emitOpError("expected single non-empty parent region"); + return emitOpError("expected single non-empty parent region"); if (auto linalgOp = dyn_cast(parentOp)) - return verifyYield(op, cast(parentOp)); + return verifyYield(*this, cast(parentOp)); if (auto tiledLoopOp = dyn_cast(parentOp)) { // Check if output args with tensor types match results types. 
@@ -1044,25 +1043,25 @@ llvm::copy_if( tiledLoopOp.outputs(), std::back_inserter(tensorOuts), [&](Value out) { return out.getType().isa(); }); - if (tensorOuts.size() != op.values().size()) - return op.emitOpError("expected number of tensor output args = ") - << tensorOuts.size() << " to match the number of yield operands = " - << op.values().size(); + if (tensorOuts.size() != values().size()) + return emitOpError("expected number of tensor output args = ") + << tensorOuts.size() + << " to match the number of yield operands = " << values().size(); TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts)); for (auto &item : - llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) { + llvm::enumerate(llvm::zip(tensorTypes, getOperandTypes()))) { Type outType, resultType; unsigned index = item.index(); std::tie(outType, resultType) = item.value(); if (outType != resultType) - return op.emitOpError("expected yield operand ") + return emitOpError("expected yield operand ") << index << " with type = " << resultType << " to match output arg type = " << outType; } return success(); } - return op.emitOpError("expected parent op with LinalgOp interface"); + return emitOpError("expected parent op with LinalgOp interface"); } //===----------------------------------------------------------------------===// @@ -1316,37 +1315,37 @@ return !region().isAncestor(value.getParentRegion()); } -static LogicalResult verify(TiledLoopOp op) { +LogicalResult TiledLoopOp::verify() { // Check if iterator types are provided for every loop dimension. 
- if (op.iterator_types().size() != op.getNumLoops()) - return op.emitOpError("expected iterator types array attribute size = ") - << op.iterator_types().size() - << " to match the number of loops = " << op.getNumLoops(); + if (iterator_types().size() != getNumLoops()) + return emitOpError("expected iterator types array attribute size = ") + << iterator_types().size() + << " to match the number of loops = " << getNumLoops(); // Check if types of input arguments match region args types. for (auto &item : - llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) { + llvm::enumerate(llvm::zip(inputs(), getRegionInputArgs()))) { Value input, inputRegionArg; unsigned index = item.index(); std::tie(input, inputRegionArg) = item.value(); if (input.getType() != inputRegionArg.getType()) - return op.emitOpError("expected input arg ") + return emitOpError("expected input arg ") << index << " with type = " << input.getType() - << " to match region arg " << index + op.getNumLoops() + << " to match region arg " << index + getNumLoops() << " type = " << inputRegionArg.getType(); } // Check if types of input arguments match region args types. 
for (auto &item : - llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) { + llvm::enumerate(llvm::zip(outputs(), getRegionOutputArgs()))) { Value output, outputRegionArg; unsigned index = item.index(); std::tie(output, outputRegionArg) = item.value(); if (output.getType() != outputRegionArg.getType()) - return op.emitOpError("expected output arg ") + return emitOpError("expected output arg ") << index << " with type = " << output.getType() << " to match region arg " - << index + op.getNumLoops() + op.inputs().size() + << index + getNumLoops() + inputs().size() << " type = " << outputRegionArg.getType(); } return success(); @@ -1667,13 +1666,13 @@ // IndexOp //===----------------------------------------------------------------------===// -static LogicalResult verify(IndexOp op) { - auto linalgOp = dyn_cast(op->getParentOp()); +LogicalResult IndexOp::verify() { + auto linalgOp = dyn_cast((*this)->getParentOp()); if (!linalgOp) - return op.emitOpError("expected parent op with LinalgOp interface"); - if (linalgOp.getNumLoops() <= op.dim()) - return op.emitOpError("expected dim (") - << op.dim() << ") to be lower than the number of loops (" + return emitOpError("expected parent op with LinalgOp interface"); + if (linalgOp.getNumLoops() <= dim()) + return emitOpError("expected dim (") + << dim() << ") to be lower than the number of loops (" << linalgOp.getNumLoops() << ") of the enclosing LinalgOp"; return success(); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -67,6 +67,10 @@ return NoneType::get(type.getContext()); } +LogicalResult memref::CastOp::verify() { + return impl::verifyCastOp(*this, areCastCompatible); +} + //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// @@ 
-95,15 +99,15 @@ return success(); } -static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } +LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); } -static LogicalResult verify(AllocaOp op) { +LogicalResult AllocaOp::verify() { // An alloca op needs to have an ancestor with an allocation scope trait. - if (!op->getParentWithTrait()) - return op.emitOpError( + if (!(*this)->getParentWithTrait()) + return emitOpError( "requires an ancestor op with AutomaticAllocationScope trait"); - return verifyAllocLikeOp(op); + return verifyAllocLikeOp(*this); } namespace { @@ -246,11 +250,8 @@ return success(); } -static LogicalResult verify(AllocaScopeOp op) { - if (failed(RegionBranchOpInterface::verifyTypes(op))) - return failure(); - - return success(); +LogicalResult AllocaScopeOp::verify() { + return RegionBranchOpInterface::verifyTypes(*this); } void AllocaScopeOp::getSuccessorRegions( @@ -268,10 +269,9 @@ // AssumeAlignmentOp //===----------------------------------------------------------------------===// -static LogicalResult verify(AssumeAlignmentOp op) { - unsigned alignment = op.alignment(); - if (!llvm::isPowerOf2_32(alignment)) - return op.emitOpError("alignment must be power of 2"); +LogicalResult AssumeAlignmentOp::verify() { + if (!llvm::isPowerOf2_32(alignment())) + return emitOpError("alignment must be power of 2"); return success(); } @@ -556,17 +556,17 @@ return {}; } -static LogicalResult verify(DimOp op) { +LogicalResult DimOp::verify() { // Assume unknown index to be in range. - Optional index = op.getConstantIndex(); + Optional index = getConstantIndex(); if (!index.hasValue()) return success(); // Check that constant index is not knowingly out of range. 
- auto type = op.source().getType(); + auto type = source().getType(); if (auto memrefType = type.dyn_cast()) { if (index.getValue() >= memrefType.getRank()) - return op.emitOpError("index is out of range"); + return emitOpError("index is out of range"); } else if (type.isa()) { // Assume index to be in range. } else { @@ -866,67 +866,66 @@ return success(); } -static LogicalResult verify(DmaStartOp op) { - unsigned numOperands = op.getNumOperands(); +LogicalResult DmaStartOp::verify() { + unsigned numOperands = getNumOperands(); // Mandatory non-variadic operands are: src memref, dst memref, tag memref and // the number of elements. if (numOperands < 4) - return op.emitOpError("expected at least 4 operands"); + return emitOpError("expected at least 4 operands"); // Check types of operands. The order of these calls is important: the later // calls rely on some type properties to compute the operand position. // 1. Source memref. - if (!op.getSrcMemRef().getType().isa()) - return op.emitOpError("expected source to be of memref type"); - if (numOperands < op.getSrcMemRefRank() + 4) - return op.emitOpError() - << "expected at least " << op.getSrcMemRefRank() + 4 << " operands"; - if (!op.getSrcIndices().empty() && - !llvm::all_of(op.getSrcIndices().getTypes(), + if (!getSrcMemRef().getType().isa()) + return emitOpError("expected source to be of memref type"); + if (numOperands < getSrcMemRefRank() + 4) + return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 + << " operands"; + if (!getSrcIndices().empty() && + !llvm::all_of(getSrcIndices().getTypes(), [](Type t) { return t.isIndex(); })) - return op.emitOpError("expected source indices to be of index type"); + return emitOpError("expected source indices to be of index type"); // 2. Destination memref. 
- if (!op.getDstMemRef().getType().isa()) - return op.emitOpError("expected destination to be of memref type"); - unsigned numExpectedOperands = - op.getSrcMemRefRank() + op.getDstMemRefRank() + 4; + if (!getDstMemRef().getType().isa()) + return emitOpError("expected destination to be of memref type"); + unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; if (numOperands < numExpectedOperands) - return op.emitOpError() - << "expected at least " << numExpectedOperands << " operands"; - if (!op.getDstIndices().empty() && - !llvm::all_of(op.getDstIndices().getTypes(), + return emitOpError() << "expected at least " << numExpectedOperands + << " operands"; + if (!getDstIndices().empty() && + !llvm::all_of(getDstIndices().getTypes(), [](Type t) { return t.isIndex(); })) - return op.emitOpError("expected destination indices to be of index type"); + return emitOpError("expected destination indices to be of index type"); // 3. Number of elements. - if (!op.getNumElements().getType().isIndex()) - return op.emitOpError("expected num elements to be of index type"); + if (!getNumElements().getType().isIndex()) + return emitOpError("expected num elements to be of index type"); // 4. Tag memref. 
- if (!op.getTagMemRef().getType().isa()) - return op.emitOpError("expected tag to be of memref type"); - numExpectedOperands += op.getTagMemRefRank(); + if (!getTagMemRef().getType().isa()) + return emitOpError("expected tag to be of memref type"); + numExpectedOperands += getTagMemRefRank(); if (numOperands < numExpectedOperands) - return op.emitOpError() - << "expected at least " << numExpectedOperands << " operands"; - if (!op.getTagIndices().empty() && - !llvm::all_of(op.getTagIndices().getTypes(), + return emitOpError() << "expected at least " << numExpectedOperands + << " operands"; + if (!getTagIndices().empty() && + !llvm::all_of(getTagIndices().getTypes(), [](Type t) { return t.isIndex(); })) - return op.emitOpError("expected tag indices to be of index type"); + return emitOpError("expected tag indices to be of index type"); // Optional stride-related operands must be either both present or both // absent. if (numOperands != numExpectedOperands && numOperands != numExpectedOperands + 2) - return op.emitOpError("incorrect number of operands"); + return emitOpError("incorrect number of operands"); // 5. Strides. - if (op.isStrided()) { - if (!op.getStride().getType().isIndex() || - !op.getNumElementsPerStride().getType().isIndex()) - return op.emitOpError( + if (isStrided()) { + if (!getStride().getType().isIndex() || + !getNumElementsPerStride().getType().isIndex()) + return emitOpError( "expected stride and num elements per stride to be of type index"); } @@ -949,14 +948,14 @@ return foldMemRefCast(*this); } -static LogicalResult verify(DmaWaitOp op) { +LogicalResult DmaWaitOp::verify() { // Check that the number of tag indices matches the tagMemRef rank. 
- unsigned numTagIndices = op.tagIndices().size(); - unsigned tagMemRefRank = op.getTagMemRefRank(); + unsigned numTagIndices = tagIndices().size(); + unsigned tagMemRefRank = getTagMemRefRank(); if (numTagIndices != tagMemRefRank) - return op.emitOpError() << "expected tagIndices to have the same number of " - "elements as the tagMemRef rank, expected " - << tagMemRefRank << ", but got " << numTagIndices; + return emitOpError() << "expected tagIndices to have the same number of " + "elements as the tagMemRef rank, expected " + << tagMemRefRank << ", but got " << numTagIndices; return success(); } @@ -979,14 +978,13 @@ } } -static LogicalResult verify(GenericAtomicRMWOp op) { - auto &body = op.getRegion(); +LogicalResult GenericAtomicRMWOp::verify() { + auto &body = getRegion(); if (body.getNumArguments() != 1) - return op.emitOpError("expected single number of entry block arguments"); + return emitOpError("expected single number of entry block arguments"); - if (op.getResult().getType() != body.getArgument(0).getType()) - return op.emitOpError( - "expected block argument of the same type result type"); + if (getResult().getType() != body.getArgument(0).getType()) + return emitOpError("expected block argument of the same type result type"); bool hasSideEffects = body.walk([&](Operation *nestedOp) { @@ -1034,12 +1032,12 @@ // AtomicYieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(AtomicYieldOp op) { - Type parentType = op->getParentOp()->getResultTypes().front(); - Type resultType = op.result().getType(); +LogicalResult AtomicYieldOp::verify() { + Type parentType = (*this)->getParentOp()->getResultTypes().front(); + Type resultType = result().getType(); if (parentType != resultType) - return op.emitOpError() << "types mismatch between yield op: " << resultType - << " and its parent: " << parentType; + return emitOpError() << "types mismatch between yield op: " << resultType + << " and its parent: " 
<< parentType; return success(); } @@ -1090,19 +1088,19 @@ return success(); } -static LogicalResult verify(GlobalOp op) { - auto memrefType = op.type().dyn_cast(); +LogicalResult GlobalOp::verify() { + auto memrefType = type().dyn_cast(); if (!memrefType || !memrefType.hasStaticShape()) - return op.emitOpError("type should be static shaped memref, but got ") - << op.type(); + return emitOpError("type should be static shaped memref, but got ") + << type(); // Verify that the initial value, if present, is either a unit attribute or // an elements attribute. - if (op.initial_value().hasValue()) { - Attribute initValue = op.initial_value().getValue(); + if (initial_value().hasValue()) { + Attribute initValue = initial_value().getValue(); if (!initValue.isa() && !initValue.isa()) - return op.emitOpError("initial value should be a unit or elements " - "attribute, but got ") + return emitOpError("initial value should be a unit or elements " + "attribute, but got ") << initValue; // Check that the type of the initial value is compatible with the type of @@ -1111,17 +1109,17 @@ Type initType = initValue.getType(); Type tensorType = getTensorTypeFromMemRefType(memrefType); if (initType != tensorType) - return op.emitOpError("initial value expected to be of type ") + return emitOpError("initial value expected to be of type ") << tensorType << ", but was of type " << initType; } } - if (Optional alignAttr = op.alignment()) { + if (Optional alignAttr = alignment()) { uint64_t alignment = alignAttr.getValue(); if (!llvm::isPowerOf2_64(alignment)) - return op->emitError() << "alignment attribute value " << alignment - << " is not a power of 2"; + return emitError() << "alignment attribute value " << alignment + << " is not a power of 2"; } // TODO: verify visibility for declarations. 
@@ -1154,9 +1152,9 @@ // LoadOp //===----------------------------------------------------------------------===// -static LogicalResult verify(LoadOp op) { - if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) - return op.emitOpError("incorrect number of indices for load"); +LogicalResult LoadOp::verify() { + if (getNumOperands() != 1 + getMemRefType().getRank()) + return emitOpError("incorrect number of indices for load"); return success(); } @@ -1224,9 +1222,9 @@ return success(); } -static LogicalResult verify(PrefetchOp op) { - if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) - return op.emitOpError("too few indices"); +LogicalResult PrefetchOp::verify() { + if (getNumOperands() != 1 + getMemRefType().getRank()) + return emitOpError("too few indices"); return success(); } @@ -1306,26 +1304,25 @@ // TODO: ponder whether we want to allow missing trailing sizes/strides that are // completed automatically, like we have for subview and extract_slice. -static LogicalResult verify(ReinterpretCastOp op) { +LogicalResult ReinterpretCastOp::verify() { // The source and result memrefs should be in the same memory space. - auto srcType = op.source().getType().cast(); - auto resultType = op.getType().cast(); + auto srcType = source().getType().cast(); + auto resultType = getType().cast(); if (srcType.getMemorySpace() != resultType.getMemorySpace()) - return op.emitError("different memory spaces specified for source type ") + return emitError("different memory spaces specified for source type ") << srcType << " and result memref type " << resultType; if (srcType.getElementType() != resultType.getElementType()) - return op.emitError("different element types specified for source type ") + return emitError("different element types specified for source type ") << srcType << " and result memref type " << resultType; // Match sizes in result memref type and in static_sizes attribute. 
- for (auto &en : - llvm::enumerate(llvm::zip(resultType.getShape(), - extractFromI64ArrayAttr(op.static_sizes())))) { + for (auto &en : llvm::enumerate(llvm::zip( + resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) { int64_t resultSize = std::get<0>(en.value()); int64_t expectedSize = std::get<1>(en.value()); if (!ShapedType::isDynamic(resultSize) && !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) - return op.emitError("expected result type with size = ") + return emitError("expected result type with size = ") << expectedSize << " instead of " << resultSize << " in dim = " << en.index(); } @@ -1336,27 +1333,26 @@ int64_t resultOffset; SmallVector resultStrides; if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) - return op.emitError( - "expected result type to have strided layout but found ") + return emitError("expected result type to have strided layout but found ") << resultType; // Match offset in result memref type and in static_offsets attribute. - int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front(); + int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front(); if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && !ShapedType::isDynamicStrideOrOffset(expectedOffset) && resultOffset != expectedOffset) - return op.emitError("expected result type with offset = ") + return emitError("expected result type with offset = ") << resultOffset << " instead of " << expectedOffset; // Match strides in result memref type and in static_strides attribute. 
for (auto &en : llvm::enumerate(llvm::zip( - resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { + resultStrides, extractFromI64ArrayAttr(static_strides())))) { int64_t resultStride = std::get<0>(en.value()); int64_t expectedStride = std::get<1>(en.value()); if (!ShapedType::isDynamicStrideOrOffset(resultStride) && !ShapedType::isDynamicStrideOrOffset(expectedStride) && resultStride != expectedStride) - return op.emitError("expected result type with stride = ") + return emitError("expected result type with stride = ") << expectedStride << " instead of " << resultStride << " in dim = " << en.index(); } @@ -1532,8 +1528,8 @@ return success(); } -static LogicalResult verify(ExpandShapeOp op) { - return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); +LogicalResult ExpandShapeOp::verify() { + return verifyReshapeOp(*this, getResultType(), getSrcType()); } void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -1542,8 +1538,8 @@ CollapseMixedReshapeOps>(context); } -static LogicalResult verify(CollapseShapeOp op) { - return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); +LogicalResult CollapseShapeOp::verify() { + return verifyReshapeOp(*this, getSrcType(), getResultType()); } struct CollapseShapeOpMemRefCastFolder @@ -1593,32 +1589,30 @@ // ReshapeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReshapeOp op) { - Type operandType = op.source().getType(); - Type resultType = op.result().getType(); +LogicalResult ReshapeOp::verify() { + Type operandType = source().getType(); + Type resultType = result().getType(); Type operandElementType = operandType.cast().getElementType(); Type resultElementType = resultType.cast().getElementType(); if (operandElementType != resultElementType) - return op.emitOpError("element types of source and destination memref " - "types should be the same"); + return emitOpError("element types of source and destination memref 
" + "types should be the same"); if (auto operandMemRefType = operandType.dyn_cast()) if (!operandMemRefType.getLayout().isIdentity()) - return op.emitOpError( - "source memref type should have identity affine map"); + return emitOpError("source memref type should have identity affine map"); - int64_t shapeSize = op.shape().getType().cast().getDimSize(0); + int64_t shapeSize = shape().getType().cast().getDimSize(0); auto resultMemRefType = resultType.dyn_cast(); if (resultMemRefType) { if (!resultMemRefType.getLayout().isIdentity()) - return op.emitOpError( - "result memref type should have identity affine map"); + return emitOpError("result memref type should have identity affine map"); if (shapeSize == ShapedType::kDynamicSize) - return op.emitOpError("cannot use shape operand with dynamic length to " - "reshape to statically-ranked memref type"); + return emitOpError("cannot use shape operand with dynamic length to " + "reshape to statically-ranked memref type"); if (shapeSize != resultMemRefType.getRank()) - return op.emitOpError( + return emitOpError( "length of shape operand differs from the result's memref rank"); } return success(); @@ -1628,9 +1622,9 @@ // StoreOp //===----------------------------------------------------------------------===// -static LogicalResult verify(StoreOp op) { - if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) - return op.emitOpError("store index operand count not equal to memref rank"); +LogicalResult StoreOp::verify() { + if (getNumOperands() != 2 + getMemRefType().getRank()) + return emitOpError("store index operand count not equal to memref rank"); return success(); } @@ -1951,29 +1945,29 @@ } /// Verifier for SubViewOp. 
-static LogicalResult verify(SubViewOp op) { - MemRefType baseType = op.getSourceType(); - MemRefType subViewType = op.getType(); +LogicalResult SubViewOp::verify() { + MemRefType baseType = getSourceType(); + MemRefType subViewType = getType(); // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != subViewType.getMemorySpace()) - return op.emitError("different memory spaces specified for base memref " - "type ") + return emitError("different memory spaces specified for base memref " + "type ") << baseType << " and subview memref type " << subViewType; // Verify that the base memref type has a strided layout map. if (!isStrided(baseType)) - return op.emitError("base type ") << baseType << " is not strided"; + return emitError("base type ") << baseType << " is not strided"; // Verify result type against inferred type. auto expectedType = SubViewOp::inferResultType( - baseType, extractFromI64ArrayAttr(op.static_offsets()), - extractFromI64ArrayAttr(op.static_sizes()), - extractFromI64ArrayAttr(op.static_strides())); + baseType, extractFromI64ArrayAttr(static_offsets()), + extractFromI64ArrayAttr(static_sizes()), + extractFromI64ArrayAttr(static_strides())); auto result = isRankReducedMemRefType(expectedType.cast(), - subViewType, op.getMixedSizes()); - return produceSubViewErrorMsg(result, op, expectedType); + subViewType, getMixedSizes()); + return produceSubViewErrorMsg(result, *this, expectedType); } raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { @@ -2278,18 +2272,17 @@ return success(); } -static LogicalResult verify(TransposeOp op) { - if (!op.permutation().isPermutation()) - return op.emitOpError("expected a permutation map"); - if (op.permutation().getNumDims() != op.getShapedType().getRank()) - return op.emitOpError( - "expected a permutation map of same rank as the input"); +LogicalResult TransposeOp::verify() { + if (!permutation().isPermutation()) + return emitOpError("expected a 
permutation map"); + if (permutation().getNumDims() != getShapedType().getRank()) + return emitOpError("expected a permutation map of same rank as the input"); - auto srcType = op.in().getType().cast(); - auto dstType = op.getType().cast(); - auto transposedType = inferTransposeResultType(srcType, op.permutation()); + auto srcType = in().getType().cast(); + auto dstType = getType().cast(); + auto transposedType = inferTransposeResultType(srcType, permutation()); if (dstType != transposedType) - return op.emitOpError("output type ") + return emitOpError("output type ") << dstType << " does not match transposed input type " << srcType << ", " << transposedType; return success(); @@ -2338,29 +2331,28 @@ p << " : " << op.getOperand(0).getType() << " to " << op.getType(); } -static LogicalResult verify(ViewOp op) { - auto baseType = op.getOperand(0).getType().cast(); - auto viewType = op.getType(); +LogicalResult ViewOp::verify() { + auto baseType = getOperand(0).getType().cast(); + auto viewType = getType(); // The base memref should have identity layout map (or none). if (!baseType.getLayout().isIdentity()) - return op.emitError("unsupported map for base memref type ") << baseType; + return emitError("unsupported map for base memref type ") << baseType; // The result memref should have identity layout map (or none). if (!viewType.getLayout().isIdentity()) - return op.emitError("unsupported map for result memref type ") << viewType; + return emitError("unsupported map for result memref type ") << viewType; // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != viewType.getMemorySpace()) - return op.emitError("different memory spaces specified for base memref " - "type ") + return emitError("different memory spaces specified for base memref " + "type ") << baseType << " and view memref type " << viewType; // Verify that we have the correct number of sizes for the result type. 
unsigned numDynamicDims = viewType.getNumDynamicDims(); - if (op.sizes().size() != numDynamicDims) - return op.emitError("incorrect number of size operands for type ") - << viewType; + if (sizes().size() != numDynamicDims) + return emitError("incorrect number of size operands for type ") << viewType; return success(); } @@ -2467,19 +2459,19 @@ // AtomicRMWOp //===----------------------------------------------------------------------===// -static LogicalResult verify(AtomicRMWOp op) { - if (op.getMemRefType().getRank() != op.getNumOperands() - 2) - return op.emitOpError( +LogicalResult AtomicRMWOp::verify() { + if (getMemRefType().getRank() != getNumOperands() - 2) + return emitOpError( "expects the number of subscripts to be equal to memref rank"); - switch (op.kind()) { + switch (kind()) { case arith::AtomicRMWKind::addf: case arith::AtomicRMWKind::maxf: case arith::AtomicRMWKind::minf: case arith::AtomicRMWKind::mulf: - if (!op.value().getType().isa()) - return op.emitOpError() - << "with kind '" << arith::stringifyAtomicRMWKind(op.kind()) - << "' expects a floating-point type"; + if (!value().getType().isa()) + return emitOpError() << "with kind '" + << arith::stringifyAtomicRMWKind(kind()) + << "' expects a floating-point type"; break; case arith::AtomicRMWKind::addi: case arith::AtomicRMWKind::maxs: @@ -2489,10 +2481,10 @@ case arith::AtomicRMWKind::muli: case arith::AtomicRMWKind::ori: case arith::AtomicRMWKind::andi: - if (!op.value().getType().isa()) - return op.emitOpError() - << "with kind '" << arith::stringifyAtomicRMWKind(op.kind()) - << "' expects an integer type"; + if (!value().getType().isa()) + return emitOpError() << "with kind '" + << arith::stringifyAtomicRMWKind(kind()) + << "' expects an integer type"; break; default: break; diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -668,28 +668,22 @@ 
LoopOp::getOperandSegmentSizeAttr()}); } -static LogicalResult verifyLoopOp(acc::LoopOp loopOp) { +LogicalResult acc::LoopOp::verify() { // auto, independent and seq attribute are mutually exclusive. - if ((loopOp.auto_() && (loopOp.independent() || loopOp.seq())) || - (loopOp.independent() && loopOp.seq())) { - loopOp.emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " + + if ((auto_() && (independent() || seq())) || (independent() && seq())) { + return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " + acc::LoopOp::getIndependentAttrName() + ", " + acc::LoopOp::getSeqAttrName() + " can be present at the same time"); - return failure(); } // Gang, worker and vector are incompatible with seq. - if (loopOp.seq() && loopOp.exec_mapping() != OpenACCExecMapping::NONE) { - loopOp.emitError("gang, worker or vector cannot appear with the seq attr"); - return failure(); - } + if (seq() && exec_mapping() != OpenACCExecMapping::NONE) + return emitError("gang, worker or vector cannot appear with the seq attr"); // Check non-empty body(). - if (loopOp.region().empty()) { - loopOp.emitError("expected non-empty body."); - return failure(); - } + if (region().empty()) + return emitError("expected non-empty body."); return success(); } @@ -698,13 +692,13 @@ // DataOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::DataOp dataOp) { +LogicalResult acc::DataOp::verify() { // 2.6.5. Data Construct restriction // At least one copy, copyin, copyout, create, no_create, present, deviceptr, // attach, or default clause must appear on a data construct. 
- if (dataOp.getOperands().empty() && !dataOp.defaultAttr()) - return dataOp.emitError("at least one operand or the default attribute " - "must appear on the data operation"); + if (getOperands().empty() && !defaultAttr()) + return emitError("at least one operand or the default attribute " + "must appear on the data operation"); return success(); } @@ -726,28 +720,28 @@ // ExitDataOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::ExitDataOp op) { +LogicalResult acc::ExitDataOp::verify() { // 2.6.6. Data Exit Directive restriction // At least one copyout, delete, or detach clause must appear on an exit data // directive. - if (op.copyoutOperands().empty() && op.deleteOperands().empty() && - op.detachOperands().empty()) - return op.emitError( + if (copyoutOperands().empty() && deleteOperands().empty() && + detachOperands().empty()) + return emitError( "at least one operand in copyout, delete or detach must appear on the " "exit data operation"); // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. - if (op.asyncOperand() && op.async()) - return op.emitError("async attribute cannot appear with asyncOperand"); + if (asyncOperand() && async()) + return emitError("async attribute cannot appear with asyncOperand"); // The wait attribute represent the wait clause without values. Therefore the // attribute and operands cannot appear at the same time. 
- if (!op.waitOperands().empty() && op.wait()) - return op.emitError("wait attribute cannot appear with waitOperands"); + if (!waitOperands().empty() && wait()) + return emitError("wait attribute cannot appear with waitOperands"); - if (op.waitDevnum() && op.waitOperands().empty()) - return op.emitError("wait_devnum cannot appear without waitOperands"); + if (waitDevnum() && waitOperands().empty()) + return emitError("wait_devnum cannot appear without waitOperands"); return success(); } @@ -773,28 +767,28 @@ // EnterDataOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::EnterDataOp op) { +LogicalResult acc::EnterDataOp::verify() { // 2.6.6. Data Enter Directive restriction // At least one copyin, create, or attach clause must appear on an enter data // directive. - if (op.copyinOperands().empty() && op.createOperands().empty() && - op.createZeroOperands().empty() && op.attachOperands().empty()) - return op.emitError( + if (copyinOperands().empty() && createOperands().empty() && + createZeroOperands().empty() && attachOperands().empty()) + return emitError( "at least one operand in copyin, create, " "create_zero or attach must appear on the enter data operation"); // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. - if (op.asyncOperand() && op.async()) - return op.emitError("async attribute cannot appear with asyncOperand"); + if (asyncOperand() && async()) + return emitError("async attribute cannot appear with asyncOperand"); // The wait attribute represent the wait clause without values. Therefore the // attribute and operands cannot appear at the same time. 
- if (!op.waitOperands().empty() && op.wait()) - return op.emitError("wait attribute cannot appear with waitOperands"); + if (!waitOperands().empty() && wait()) + return emitError("wait attribute cannot appear with waitOperands"); - if (op.waitDevnum() && op.waitOperands().empty()) - return op.emitError("wait_devnum cannot appear without waitOperands"); + if (waitDevnum() && waitOperands().empty()) + return emitError("wait_devnum cannot appear without waitOperands"); return success(); } @@ -820,12 +814,11 @@ // InitOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::InitOp initOp) { - Operation *currOp = initOp; - while ((currOp = currOp->getParentOp())) { +LogicalResult acc::InitOp::verify() { + Operation *currOp = *this; + while ((currOp = currOp->getParentOp())) if (isComputeOperation(currOp)) - return initOp.emitOpError("cannot be nested in a compute operation"); - } + return emitOpError("cannot be nested in a compute operation"); return success(); } @@ -833,12 +826,11 @@ // ShutdownOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::ShutdownOp op) { - Operation *currOp = op; - while ((currOp = currOp->getParentOp())) { +LogicalResult acc::ShutdownOp::verify() { + Operation *currOp = *this; + while ((currOp = currOp->getParentOp())) if (isComputeOperation(currOp)) - return op.emitOpError("cannot be nested in a compute operation"); - } + return emitOpError("cannot be nested in a compute operation"); return success(); } @@ -846,25 +838,24 @@ // UpdateOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::UpdateOp updateOp) { +LogicalResult acc::UpdateOp::verify() { // At least one of host or device should have a value. 
- if (updateOp.hostOperands().empty() && updateOp.deviceOperands().empty()) - return updateOp.emitError("at least one value must be present in" - " hostOperands or deviceOperands"); + if (hostOperands().empty() && deviceOperands().empty()) + return emitError( + "at least one value must be present in hostOperands or deviceOperands"); // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. - if (updateOp.asyncOperand() && updateOp.async()) - return updateOp.emitError("async attribute cannot appear with " - " asyncOperand"); + if (asyncOperand() && async()) + return emitError("async attribute cannot appear with asyncOperand"); // The wait attribute represent the wait clause without values. Therefore the // attribute and operands cannot appear at the same time. - if (!updateOp.waitOperands().empty() && updateOp.wait()) - return updateOp.emitError("wait attribute cannot appear with waitOperands"); + if (!waitOperands().empty() && wait()) + return emitError("wait attribute cannot appear with waitOperands"); - if (updateOp.waitDevnum() && updateOp.waitOperands().empty()) - return updateOp.emitError("wait_devnum cannot appear without waitOperands"); + if (waitDevnum() && waitOperands().empty()) + return emitError("wait_devnum cannot appear without waitOperands"); return success(); } @@ -890,14 +881,14 @@ // WaitOp //===----------------------------------------------------------------------===// -static LogicalResult verify(acc::WaitOp waitOp) { +LogicalResult acc::WaitOp::verify() { // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. 
- if (waitOp.asyncOperand() && waitOp.async()) - return waitOp.emitError("async attribute cannot appear with asyncOperand"); + if (asyncOperand() && async()) + return emitError("async attribute cannot appear with asyncOperand"); - if (waitOp.waitDevnum() && waitOp.waitOperands().empty()) - return waitOp.emitError("wait_devnum cannot appear without waitOperands"); + if (waitDevnum() && waitOperands().empty()) + return emitError("wait_devnum cannot appear without waitOperands"); return success(); } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -165,9 +165,9 @@ } } -static LogicalResult verifyParallelOp(ParallelOp op) { - if (op.allocate_vars().size() != op.allocators_vars().size()) - return op.emitError( +LogicalResult ParallelOp::verify() { + if (allocate_vars().size() != allocators_vars().size()) + return emitError( "expected equal sizes for allocate and allocator variables"); return success(); } @@ -1072,31 +1072,31 @@ p.printRegion(op.region()); } -static LogicalResult verifySectionsOp(SectionsOp op) { - +LogicalResult SectionsOp::verify() { // A list item may not appear in more than one clause on the same directive, // except that it may be specified in both firstprivate and lastprivate // clauses. 
- for (auto var : op.private_vars()) { - if (llvm::is_contained(op.firstprivate_vars(), var)) - return op.emitOpError() + for (auto var : private_vars()) { + if (llvm::is_contained(firstprivate_vars(), var)) + return emitOpError() << "operand used in both private and firstprivate clauses"; - if (llvm::is_contained(op.lastprivate_vars(), var)) - return op.emitOpError() + if (llvm::is_contained(lastprivate_vars(), var)) + return emitOpError() << "operand used in both private and lastprivate clauses"; } - if (op.allocate_vars().size() != op.allocators_vars().size()) - return op.emitError( + if (allocate_vars().size() != allocators_vars().size()) + return emitError( "expected equal sizes for allocate and allocator variables"); - for (auto &inst : *op.region().begin()) { - if (!(isa(inst) || isa(inst))) - op.emitOpError() - << "expected omp.section op or terminator op inside region"; + for (auto &inst : *region().begin()) { + if (!(isa(inst) || isa(inst))) { + return emitOpError() + << "expected omp.section op or terminator op inside region"; + } } - return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); + return verifyReductionVarList(*this, reductions(), reduction_vars()); } /// Parses an OpenMP Workshare Loop operation @@ -1224,65 +1224,65 @@ printer.printRegion(region); } -static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { - if (op.initializerRegion().empty()) - return op.emitOpError() << "expects non-empty initializer region"; - Block &initializerEntryBlock = op.initializerRegion().front(); +LogicalResult ReductionDeclareOp::verify() { + if (initializerRegion().empty()) + return emitOpError() << "expects non-empty initializer region"; + Block &initializerEntryBlock = initializerRegion().front(); if (initializerEntryBlock.getNumArguments() != 1 || - initializerEntryBlock.getArgument(0).getType() != op.type()) { - return op.emitOpError() << "expects initializer region with one argument " - "of the reduction type"; + 
initializerEntryBlock.getArgument(0).getType() != type()) { + return emitOpError() << "expects initializer region with one argument " + "of the reduction type"; } - for (YieldOp yieldOp : op.initializerRegion().getOps()) { + for (YieldOp yieldOp : initializerRegion().getOps()) { if (yieldOp.results().size() != 1 || - yieldOp.results().getTypes()[0] != op.type()) - return op.emitOpError() << "expects initializer region to yield a value " - "of the reduction type"; + yieldOp.results().getTypes()[0] != type()) + return emitOpError() << "expects initializer region to yield a value " + "of the reduction type"; } - if (op.reductionRegion().empty()) - return op.emitOpError() << "expects non-empty reduction region"; - Block &reductionEntryBlock = op.reductionRegion().front(); + if (reductionRegion().empty()) + return emitOpError() << "expects non-empty reduction region"; + Block &reductionEntryBlock = reductionRegion().front(); if (reductionEntryBlock.getNumArguments() != 2 || reductionEntryBlock.getArgumentTypes()[0] != reductionEntryBlock.getArgumentTypes()[1] || - reductionEntryBlock.getArgumentTypes()[0] != op.type()) - return op.emitOpError() << "expects reduction region with two arguments of " - "the reduction type"; - for (YieldOp yieldOp : op.reductionRegion().getOps()) { + reductionEntryBlock.getArgumentTypes()[0] != type()) + return emitOpError() << "expects reduction region with two arguments of " + "the reduction type"; + for (YieldOp yieldOp : reductionRegion().getOps()) { if (yieldOp.results().size() != 1 || - yieldOp.results().getTypes()[0] != op.type()) - return op.emitOpError() << "expects reduction region to yield a value " - "of the reduction type"; + yieldOp.results().getTypes()[0] != type()) + return emitOpError() << "expects reduction region to yield a value " + "of the reduction type"; } - if (op.atomicReductionRegion().empty()) + if (atomicReductionRegion().empty()) return success(); - Block &atomicReductionEntryBlock = 
op.atomicReductionRegion().front(); + Block &atomicReductionEntryBlock = atomicReductionRegion().front(); if (atomicReductionEntryBlock.getNumArguments() != 2 || atomicReductionEntryBlock.getArgumentTypes()[0] != atomicReductionEntryBlock.getArgumentTypes()[1]) - return op.emitOpError() << "expects atomic reduction region with two " - "arguments of the same type"; + return emitOpError() << "expects atomic reduction region with two " + "arguments of the same type"; auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] .dyn_cast(); - if (!ptrType || ptrType.getElementType() != op.type()) - return op.emitOpError() << "expects atomic reduction region arguments to " - "be accumulators containing the reduction type"; + if (!ptrType || ptrType.getElementType() != type()) + return emitOpError() << "expects atomic reduction region arguments to " + "be accumulators containing the reduction type"; return success(); } -static LogicalResult verifyReductionOp(ReductionOp op) { +LogicalResult ReductionOp::verify() { // TODO: generalize this to an op interface when there is more than one op // that supports reductions. 
- auto container = op->getParentOfType(); + auto container = (*this)->getParentOfType(); for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) - if (container.reduction_vars()[i] == op.accumulator()) + if (container.reduction_vars()[i] == accumulator()) return success(); - return op.emitOpError() << "the accumulator is not used by the parent"; + return emitOpError() << "the accumulator is not used by the parent"; } //===----------------------------------------------------------------------===// @@ -1368,27 +1368,26 @@ } } -static LogicalResult verifyWsLoopOp(WsLoopOp op) { - return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); +LogicalResult WsLoopOp::verify() { + return verifyReductionVarList(*this, reductions(), reduction_vars()); } //===----------------------------------------------------------------------===// // Verifier for critical construct (2.17.1) //===----------------------------------------------------------------------===// -static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) { - return verifySynchronizationHint(op, op.hint()); +LogicalResult CriticalDeclareOp::verify() { + return verifySynchronizationHint(*this, hint()); } -static LogicalResult verifyCriticalOp(CriticalOp op) { - - if (op.nameAttr()) { - auto symbolRef = op.nameAttr().cast(); - auto decl = - SymbolTable::lookupNearestSymbolFrom(op, symbolRef); +LogicalResult CriticalOp::verify() { + if (nameAttr()) { + SymbolRefAttr symbolRef = nameAttr(); + auto decl = SymbolTable::lookupNearestSymbolFrom( + *this, symbolRef); if (!decl) { - return op.emitOpError() << "expected symbol reference " << symbolRef - << " to point to a critical declaration"; + return emitOpError() << "expected symbol reference " << symbolRef + << " to point to a critical declaration"; } } @@ -1399,34 +1398,34 @@ // Verifier for ordered construct //===----------------------------------------------------------------------===// -static LogicalResult verifyOrderedOp(OrderedOp 
op) { - auto container = op->getParentOfType(); +LogicalResult OrderedOp::verify() { + auto container = (*this)->getParentOfType(); if (!container || !container.ordered_valAttr() || container.ordered_valAttr().getInt() == 0) - return op.emitOpError() << "ordered depend directive must be closely " - << "nested inside a worksharing-loop with ordered " - << "clause with parameter present"; + return emitOpError() << "ordered depend directive must be closely " + << "nested inside a worksharing-loop with ordered " + << "clause with parameter present"; if (container.ordered_valAttr().getInt() != - (int64_t)op.num_loops_val().getValue()) - return op.emitOpError() << "number of variables in depend clause does not " - << "match number of iteration variables in the " - << "doacross loop"; + (int64_t)num_loops_val().getValue()) + return emitOpError() << "number of variables in depend clause does not " + << "match number of iteration variables in the " + << "doacross loop"; return success(); } -static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) { +LogicalResult OrderedRegionOp::verify() { // TODO: The code generation for ordered simd directive is not supported yet. 
- if (op.simd()) + if (simd()) return failure(); - if (auto container = op->getParentOfType()) { + if (auto container = (*this)->getParentOfType()) { if (!container.ordered_valAttr() || container.ordered_valAttr().getInt() != 0) - return op.emitOpError() << "ordered region must be closely nested inside " - << "a worksharing-loop region with an ordered " - << "clause without parameter present"; + return emitOpError() << "ordered region must be closely nested inside " + << "a worksharing-loop region with an ordered " + << "clause without parameter present"; } return success(); @@ -1468,18 +1467,18 @@ } /// Verifier for AtomicReadOp -static LogicalResult verifyAtomicReadOp(AtomicReadOp op) { - if (auto mo = op.memory_order()) { +LogicalResult AtomicReadOp::verify() { + if (auto mo = memory_order()) { if (*mo == ClauseMemoryOrderKind::acq_rel || *mo == ClauseMemoryOrderKind::release) { - return op.emitError( + return emitError( "memory-order must not be acq_rel or release for atomic reads"); } } - if (op.x() == op.v()) - return op.emitError( + if (x() == v()) + return emitError( "read and write must not be to the same location for atomic reads"); - return verifySynchronizationHint(op, op.hint()); + return verifySynchronizationHint(*this, hint()); } //===----------------------------------------------------------------------===// @@ -1521,15 +1520,15 @@ } /// Verifier for AtomicWriteOp -static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) { - if (auto mo = op.memory_order()) { +LogicalResult AtomicWriteOp::verify() { + if (auto mo = memory_order()) { if (*mo == ClauseMemoryOrderKind::acq_rel || *mo == ClauseMemoryOrderKind::acquire) { - return op.emitError( + return emitError( "memory-order must not be acq_rel or acquire for atomic writes"); } } - return verifySynchronizationHint(op, op.hint()); + return verifySynchronizationHint(*this, hint()); } //===----------------------------------------------------------------------===// @@ -1601,11 +1600,11 @@ } /// Verifier 
for AtomicUpdateOp -static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) { - if (auto mo = op.memory_order()) { +LogicalResult AtomicUpdateOp::verify() { + if (auto mo = memory_order()) { if (*mo == ClauseMemoryOrderKind::acq_rel || *mo == ClauseMemoryOrderKind::acquire) { - return op.emitError( + return emitError( "memory-order must not be acq_rel or acquire for atomic updates"); } } @@ -1637,10 +1636,10 @@ } /// Verifier for AtomicCaptureOp -static LogicalResult verifyAtomicCaptureOp(AtomicCaptureOp op) { - Block::OpListType &ops = op.region().front().getOperations(); +LogicalResult AtomicCaptureOp::verify() { + Block::OpListType &ops = region().front().getOperations(); if (ops.size() != 3) - return emitError(op.getLoc()) + return emitError() << "expected three operations in omp.atomic.capture region (one " "terminator, and two atomic ops)"; auto &firstOp = ops.front(); @@ -1654,21 +1653,21 @@ if (!((firstUpdateStmt && secondReadStmt) || (firstReadStmt && secondUpdateStmt) || (firstReadStmt && secondWriteStmt))) - return emitError(ops.front().getLoc()) + return ops.front().emitError() << "invalid sequence of operations in the capture region"; if (firstUpdateStmt && secondReadStmt && firstUpdateStmt.x() != secondReadStmt.x()) - return emitError(firstUpdateStmt.getLoc()) + return firstUpdateStmt.emitError() << "updated variable in omp.atomic.update must be captured in " "second operation"; if (firstReadStmt && secondUpdateStmt && firstReadStmt.x() != secondUpdateStmt.x()) - return emitError(firstReadStmt.getLoc()) + return firstReadStmt.emitError() << "captured variable in omp.atomic.read must be updated in second " "operation"; if (firstReadStmt && secondWriteStmt && firstReadStmt.x() != secondWriteStmt.address()) - return emitError(firstReadStmt.getLoc()) + return firstReadStmt.emitError() << "captured variable in omp.atomic.read must be updated in " "second operation"; return success(); diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp 
b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -90,9 +90,9 @@ // pdl::ApplyNativeConstraintOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ApplyNativeConstraintOp op) { - if (op.getNumOperands() == 0) - return op.emitOpError("expected at least one argument"); +LogicalResult ApplyNativeConstraintOp::verify() { + if (getNumOperands() == 0) + return emitOpError("expected at least one argument"); return success(); } @@ -100,9 +100,9 @@ // pdl::ApplyNativeRewriteOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ApplyNativeRewriteOp op) { - if (op.getNumOperands() == 0 && op.getNumResults() == 0) - return op.emitOpError("expected at least one argument or result"); +LogicalResult ApplyNativeRewriteOp::verify() { + if (getNumOperands() == 0 && getNumResults() == 0) + return emitOpError("expected at least one argument or result"); return success(); } @@ -110,18 +110,18 @@ // pdl::AttributeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(AttributeOp op) { - Value attrType = op.type(); - Optional attrValue = op.value(); +LogicalResult AttributeOp::verify() { + Value attrType = type(); + Optional attrValue = value(); if (!attrValue) { - if (isa(op->getParentOp())) - return op.emitOpError("expected constant value when specified within a " - "`pdl.rewrite`"); - return verifyHasBindingUse(op); + if (isa((*this)->getParentOp())) + return emitOpError( + "expected constant value when specified within a `pdl.rewrite`"); + return verifyHasBindingUse(*this); } if (attrType) - return op.emitOpError("expected only one of [`type`, `value`] to be set"); + return emitOpError("expected only one of [`type`, `value`] to be set"); return success(); } @@ -129,13 +129,13 @@ // pdl::OperandOp 
//===----------------------------------------------------------------------===// -static LogicalResult verify(OperandOp op) { return verifyHasBindingUse(op); } +LogicalResult OperandOp::verify() { return verifyHasBindingUse(*this); } //===----------------------------------------------------------------------===// // pdl::OperandsOp //===----------------------------------------------------------------------===// -static LogicalResult verify(OperandsOp op) { return verifyHasBindingUse(op); } +LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); } //===----------------------------------------------------------------------===// // pdl::OperationOp @@ -230,15 +230,15 @@ return success(); } -static LogicalResult verify(OperationOp op) { - bool isWithinRewrite = isa(op->getParentOp()); - if (isWithinRewrite && !op.name()) - return op.emitOpError("must have an operation name when nested within " - "a `pdl.rewrite`"); - ArrayAttr attributeNames = op.attributeNames(); - auto attributeValues = op.attributes(); +LogicalResult OperationOp::verify() { + bool isWithinRewrite = isa((*this)->getParentOp()); + if (isWithinRewrite && !name()) + return emitOpError("must have an operation name when nested within " + "a `pdl.rewrite`"); + ArrayAttr attributeNames = attributeNamesAttr(); + auto attributeValues = attributes(); if (attributeNames.size() != attributeValues.size()) { - return op.emitOpError() + return emitOpError() << "expected the same number of attribute values and attribute " "names, got " << attributeNames.size() << " names and " << attributeValues.size() @@ -247,12 +247,12 @@ // If the operation is within a rewrite body and doesn't have type inference, // ensure that the result types can be resolved. 
- if (isWithinRewrite && !op.hasTypeInference()) { - if (failed(verifyResultTypesAreInferrable(op, op.types()))) + if (isWithinRewrite && !hasTypeInference()) { + if (failed(verifyResultTypesAreInferrable(*this, types()))) return failure(); } - return verifyHasBindingUse(op); + return verifyHasBindingUse(*this); } bool OperationOp::hasTypeInference() { @@ -269,12 +269,12 @@ // pdl::PatternOp //===----------------------------------------------------------------------===// -static LogicalResult verify(PatternOp pattern) { - Region &body = pattern.body(); +LogicalResult PatternOp::verify() { + Region &body = getBodyRegion(); Operation *term = body.front().getTerminator(); auto rewriteOp = dyn_cast(term); if (!rewriteOp) { - return pattern.emitOpError("expected body to terminate with `pdl.rewrite`") + return emitOpError("expected body to terminate with `pdl.rewrite`") .attachNote(term->getLoc()) .append("see terminator defined here"); } @@ -283,8 +283,7 @@ // dialect. WalkResult result = body.walk([&](Operation *op) -> WalkResult { if (!isa_and_nonnull(op->getDialect())) { - pattern - .emitOpError("expected only `pdl` operations within the pattern body") + emitOpError("expected only `pdl` operations within the pattern body") .attachNote(op->getLoc()) .append("see non-`pdl` operation defined here"); return WalkResult::interrupt(); @@ -296,8 +295,7 @@ // Check that there is at least one operation. if (body.front().getOps().empty()) - return pattern.emitOpError( - "the pattern must contain at least one `pdl.operation`"); + return emitOpError("the pattern must contain at least one `pdl.operation`"); // Determine if the operations within the pdl.pattern form a connected // component. This is determined by starting the search from the first @@ -333,8 +331,7 @@ first = false; } else if (!visited.count(&op)) { // For the subsequent operations, check if already visited. 
- return pattern - .emitOpError("the operations must form a connected component") + return emitOpError("the operations must form a connected component") .attachNote(op.getLoc()) .append("see a disconnected value / operation here"); } @@ -364,10 +361,10 @@ // pdl::ReplaceOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReplaceOp op) { - if (op.replOperation() && !op.replValues().empty()) - return op.emitOpError() << "expected no replacement values to be provided" - " when the replacement operation is present"; +LogicalResult ReplaceOp::verify() { + if (replOperation() && !replValues().empty()) + return emitOpError() << "expected no replacement values to be provided" + " when the replacement operation is present"; return success(); } @@ -392,11 +389,11 @@ p << " -> " << resultType; } -static LogicalResult verify(ResultsOp op) { - if (!op.index() && op.getType().isa()) { - return op.emitOpError() << "expected `pdl.range` result type when " - "no index is specified, but got: " - << op.getType(); +LogicalResult ResultsOp::verify() { + if (!index() && getType().isa()) { + return emitOpError() << "expected `pdl.range` result type when " + "no index is specified, but got: " + << getType(); } return success(); } @@ -405,13 +402,13 @@ // pdl::RewriteOp //===----------------------------------------------------------------------===// -static LogicalResult verify(RewriteOp op) { - Region &rewriteRegion = op.body(); +LogicalResult RewriteOp::verify() { + Region &rewriteRegion = body(); // Handle the case where the rewrite is external. - if (op.name()) { + if (name()) { if (!rewriteRegion.empty()) { - return op.emitOpError() + return emitOpError() << "expected rewrite region to be empty when rewrite is external"; } return success(); @@ -419,18 +416,18 @@ // Otherwise, check that the rewrite region only contains a single block. 
if (rewriteRegion.empty()) { - return op.emitOpError() << "expected rewrite region to be non-empty if " - "external name is not specified"; + return emitOpError() << "expected rewrite region to be non-empty if " + "external name is not specified"; } // Check that no additional arguments were provided. - if (!op.externalArgs().empty()) { - return op.emitOpError() << "expected no external arguments when the " - "rewrite is specified inline"; + if (!externalArgs().empty()) { + return emitOpError() << "expected no external arguments when the " + "rewrite is specified inline"; } - if (op.externalConstParams()) { - return op.emitOpError() << "expected no external constant parameters when " - "the rewrite is specified inline"; + if (externalConstParams()) { + return emitOpError() << "expected no external constant parameters when " + "the rewrite is specified inline"; } return success(); @@ -445,9 +442,9 @@ // pdl::TypeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(TypeOp op) { - if (!op.typeAttr()) - return verifyHasBindingUse(op); +LogicalResult TypeOp::verify() { + if (!typeAttr()) + return verifyHasBindingUse(*this); return success(); } @@ -455,9 +452,9 @@ // pdl::TypesOp //===----------------------------------------------------------------------===// -static LogicalResult verify(TypesOp op) { - if (!op.typesAttr()) - return verifyHasBindingUse(op); +LogicalResult TypesOp::verify() { + if (!typesAttr()) + return verifyHasBindingUse(*this); return success(); } diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -27,6 +27,21 @@ >(); } +template +static LogicalResult verifySwitchOp(OpT op) { + // Verify that the number of case destinations matches the number of case + // values. 
+ size_t numDests = op.cases().size(); + size_t numValues = op.caseValues().size(); + if (numDests != numValues) { + return op.emitOpError( + "expected number of cases to match the number of case " + "values, got ") + << numDests << " but expected " << numValues; + } + return success(); +} + //===----------------------------------------------------------------------===// // pdl_interp::CreateOperationOp //===----------------------------------------------------------------------===// @@ -131,17 +146,17 @@ p.printSuccessor(op.successor()); } -static LogicalResult verify(ForEachOp op) { +LogicalResult ForEachOp::verify() { // Verify that the operation has exactly one argument. - if (op.region().getNumArguments() != 1) - return op.emitOpError("requires exactly one argument"); + if (region().getNumArguments() != 1) + return emitOpError("requires exactly one argument"); // Verify that the loop variable and the operand (value range) // have compatible types. - BlockArgument arg = op.getLoopVariable(); + BlockArgument arg = getLoopVariable(); Type rangeType = pdl::RangeType::get(arg.getType()); - if (rangeType != op.values().getType()) - return op.emitOpError("operand must be a range of loop variable type"); + if (rangeType != values().getType()) + return emitOpError("operand must be a range of loop variable type"); return success(); } @@ -156,6 +171,42 @@ return type.isa() ? 
pdl::RangeType::get(valueTy) : valueTy; } +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchAttributeOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); } + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchOperandCountOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); } + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchOperationNameOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); } + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchResultCountOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); } + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchTypeOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); } + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchTypesOp +//===----------------------------------------------------------------------===// + +LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); } + //===----------------------------------------------------------------------===// // TableGen Auto-Generated Op and Interface Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp 
b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -65,29 +65,67 @@ return false; } -static LogicalResult verifyRegionOp(QuantizeRegionOp op) { +LogicalResult QuantizeRegionOp::verify() { // There are specifications for both inputs and outputs. - if (op.getNumOperands() != op.input_specs().size() || - op.getNumResults() != op.output_specs().size()) - return op.emitOpError( + if (getNumOperands() != input_specs().size() || + getNumResults() != output_specs().size()) + return emitOpError( "has unmatched operands/results number and spec attributes number"); // Verify that quantization specifications are valid. - for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) { + for (auto input : llvm::zip(getOperandTypes(), input_specs())) { Type inputType = std::get<0>(input); Attribute inputSpec = std::get<1>(input); if (!isValidQuantizationSpec(inputSpec, inputType)) { - return op.emitOpError() << "has incompatible specification " << inputSpec - << " and input type " << inputType; + return emitOpError() << "has incompatible specification " << inputSpec + << " and input type " << inputType; } } - for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) { + for (auto result : llvm::zip(getResultTypes(), output_specs())) { Type outputType = std::get<0>(result); Attribute outputSpec = std::get<1>(result); if (!isValidQuantizationSpec(outputSpec, outputType)) { - return op.emitOpError() << "has incompatible specification " << outputSpec - << " and output type " << outputType; + return emitOpError() << "has incompatible specification " << outputSpec + << " and output type " << outputType; + } + } + return success(); +} + +LogicalResult StatisticsOp::verify() { + auto tensorArg = arg().getType().dyn_cast(); + if (!tensorArg) + return emitOpError("arg needs to be tensor type."); + + // Verify layerStats attribute. 
+ { + auto layerStatsType = layerStats().getType(); + if (!layerStatsType.getElementType().isa()) { + return emitOpError("layerStats must have a floating point element type"); + } + if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { + return emitOpError("layerStats must have shape [2]"); + } + } + // Verify axisStats (optional) attribute. + if (axisStats()) { + if (!axis()) + return emitOpError("axis must be specified for axisStats"); + + auto shape = tensorArg.getShape(); + auto argSliceSize = + std::accumulate(std::next(shape.begin(), *axis()), shape.end(), 1, + std::multiplies()); + + auto axisStatsType = axisStats()->getType(); + if (!axisStatsType.getElementType().isa()) { + return emitOpError("axisStats must have a floating point element type"); + } + if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 || + axisStatsType.getDimSize(0) != argSliceSize) { + return emitOpError("axisStats must have shape [N,2] " + "where N = the slice size defined by the axis dim"); } } return success(); diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -125,11 +125,11 @@ p.printOptionalAttrDict(op->getAttrs()); } -static LogicalResult verify(ExecuteRegionOp op) { - if (op.getRegion().empty()) - return op.emitOpError("region needs to have at least one block"); - if (op.getRegion().front().getNumArguments() > 0) - return op.emitOpError("region cannot have any arguments"); +LogicalResult ExecuteRegionOp::verify() { + if (getRegion().empty()) + return emitOpError("region needs to have at least one block"); + if (getRegion().front().getNumArguments() > 0) + return emitOpError("region cannot have any arguments"); return success(); } @@ -276,47 +276,47 @@ } } -static LogicalResult verify(ForOp op) { - if (auto cst = op.getStep().getDefiningOp()) +LogicalResult ForOp::verify() { + if (auto cst = getStep().getDefiningOp()) if (cst.value() <= 0) - return 
op.emitOpError("constant step operand must be positive"); + return emitOpError("constant step operand must be positive"); // Check that the body defines as single block argument for the induction // variable. - auto *body = op.getBody(); + auto *body = getBody(); if (!body->getArgument(0).getType().isIndex()) - return op.emitOpError( + return emitOpError( "expected body first argument to be an index argument for " "the induction variable"); - auto opNumResults = op.getNumResults(); + auto opNumResults = getNumResults(); if (opNumResults == 0) return success(); // If ForOp defines values, check that the number and types of // the defined values match ForOp initial iter operands and backedge // basic block arguments. - if (op.getNumIterOperands() != opNumResults) - return op.emitOpError( + if (getNumIterOperands() != opNumResults) + return emitOpError( "mismatch in number of loop-carried values and defined values"); - if (op.getNumRegionIterArgs() != opNumResults) - return op.emitOpError( + if (getNumRegionIterArgs() != opNumResults) + return emitOpError( "mismatch in number of basic block args and defined values"); - auto iterOperands = op.getIterOperands(); - auto iterArgs = op.getRegionIterArgs(); - auto opResults = op.getResults(); + auto iterOperands = getIterOperands(); + auto iterArgs = getRegionIterArgs(); + auto opResults = getResults(); unsigned i = 0; for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { if (std::get<0>(e).getType() != std::get<2>(e).getType()) - return op.emitOpError() << "types mismatch between " << i - << "th iter operand and defined value"; + return emitOpError() << "types mismatch between " << i + << "th iter operand and defined value"; if (std::get<1>(e).getType() != std::get<2>(e).getType()) - return op.emitOpError() << "types mismatch between " << i - << "th iter region arg and defined value"; + return emitOpError() << "types mismatch between " << i + << "th iter region arg and defined value"; i++; } - return 
RegionBranchOpInterface::verifyTypes(op); + return RegionBranchOpInterface::verifyTypes(*this); } /// Prints the initialization list in the form of @@ -1062,11 +1062,11 @@ build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder); } -static LogicalResult verify(IfOp op) { - if (op.getNumResults() != 0 && op.getElseRegion().empty()) - return op.emitOpError("must have an else block if defining values"); +LogicalResult IfOp::verify() { + if (getNumResults() != 0 && getElseRegion().empty()) + return emitOpError("must have an else block if defining values"); - return RegionBranchOpInterface::verifyTypes(op); + return RegionBranchOpInterface::verifyTypes(*this); } static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) { @@ -1723,32 +1723,31 @@ wrapper); } -static LogicalResult verify(ParallelOp op) { +LogicalResult ParallelOp::verify() { // Check that there is at least one value in lowerBound, upperBound and step. // It is sufficient to test only step, because it is ensured already that the // number of elements in lowerBound, upperBound and step are the same. - Operation::operand_range stepValues = op.getStep(); + Operation::operand_range stepValues = getStep(); if (stepValues.empty()) - return op.emitOpError( + return emitOpError( "needs at least one tuple element for lowerBound, upperBound and step"); // Check whether all constant step values are positive. for (Value stepValue : stepValues) if (auto cst = stepValue.getDefiningOp()) if (cst.value() <= 0) - return op.emitOpError("constant step operand must be positive"); + return emitOpError("constant step operand must be positive"); // Check that the body defines the same number of block arguments as the // number of tuple elements in step. 
- Block *body = op.getBody(); + Block *body = getBody(); if (body->getNumArguments() != stepValues.size()) - return op.emitOpError() - << "expects the same number of induction variables: " - << body->getNumArguments() - << " as bound and step values: " << stepValues.size(); + return emitOpError() << "expects the same number of induction variables: " + << body->getNumArguments() + << " as bound and step values: " << stepValues.size(); for (auto arg : body->getArguments()) if (!arg.getType().isIndex()) - return op.emitOpError( + return emitOpError( "expects arguments for the induction variable to be of index type"); // Check that the yield has no results @@ -1759,20 +1758,20 @@ // Check that the number of results is the same as the number of ReduceOps. SmallVector reductions(body->getOps()); - auto resultsSize = op.getResults().size(); + auto resultsSize = getResults().size(); auto reductionsSize = reductions.size(); - auto initValsSize = op.getInitVals().size(); + auto initValsSize = getInitVals().size(); if (resultsSize != reductionsSize) - return op.emitOpError() - << "expects number of results: " << resultsSize - << " to be the same as number of reductions: " << reductionsSize; + return emitOpError() << "expects number of results: " << resultsSize + << " to be the same as number of reductions: " + << reductionsSize; if (resultsSize != initValsSize) - return op.emitOpError() - << "expects number of results: " << resultsSize - << " to be the same as number of initial values: " << initValsSize; + return emitOpError() << "expects number of results: " << resultsSize + << " to be the same as number of initial values: " + << initValsSize; // Check that the types of the results and reductions are the same. 
- for (auto resultAndReduce : llvm::zip(op.getResults(), reductions)) { + for (auto resultAndReduce : llvm::zip(getResults(), reductions)) { auto resultType = std::get<0>(resultAndReduce).getType(); auto reduceOp = std::get<1>(resultAndReduce); auto reduceType = reduceOp.getOperand().getType(); @@ -2075,23 +2074,23 @@ body->getArgument(1)); } -static LogicalResult verify(ReduceOp op) { +LogicalResult ReduceOp::verify() { // The region of a ReduceOp has two arguments of the same type as its operand. - auto type = op.getOperand().getType(); - Block &block = op.getReductionOperator().front(); + auto type = getOperand().getType(); + Block &block = getReductionOperator().front(); if (block.empty()) - return op.emitOpError("the block inside reduce should not be empty"); + return emitOpError("the block inside reduce should not be empty"); if (block.getNumArguments() != 2 || llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) { return arg.getType() != type; })) - return op.emitOpError() - << "expects two arguments to reduce block of type " << type; + return emitOpError() << "expects two arguments to reduce block of type " + << type; // Check that the block is terminated by a ReduceReturnOp. if (!isa(block.getTerminator())) - return op.emitOpError("the block inside reduce should be terminated with a " - "'scf.reduce.return' op"); + return emitOpError("the block inside reduce should be terminated with a " + "'scf.reduce.return' op"); return success(); } @@ -2127,14 +2126,14 @@ // ReduceReturnOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReduceReturnOp op) { +LogicalResult ReduceReturnOp::verify() { // The type of the return value should be the same type as the type of the // operand of the enclosing ReduceOp. 
- auto reduceOp = cast(op->getParentOp()); + auto reduceOp = cast((*this)->getParentOp()); Type reduceType = reduceOp.getOperand().getType(); - if (reduceType != op.getResult().getType()) - return op.emitOpError() << "needs to have type " << reduceType - << " (the type of the enclosing ReduceOp)"; + if (reduceType != getResult().getType()) + return emitOpError() << "needs to have type " << reduceType + << " (the type of the enclosing ReduceOp)"; return success(); } @@ -2278,18 +2277,18 @@ return nullptr; } -static LogicalResult verify(scf::WhileOp op) { - if (failed(RegionBranchOpInterface::verifyTypes(op))) +LogicalResult scf::WhileOp::verify() { + if (failed(RegionBranchOpInterface::verifyTypes(*this))) return failure(); auto beforeTerminator = verifyAndGetTerminator( - op, op.getBefore(), + *this, getBefore(), "expects the 'before' region to terminate with 'scf.condition'"); if (!beforeTerminator) return failure(); auto afterTerminator = verifyAndGetTerminator( - op, op.getAfter(), + *this, getAfter(), "expects the 'after' region to terminate with 'scf.yield'"); return success(afterTerminator != nullptr); } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -419,6 +419,10 @@ result.addTypes(assumingTypes); } +LogicalResult AssumingOp::verify() { + return RegionBranchOpInterface::verifyTypes(*this); +} + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -449,6 +453,8 @@ operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); } +LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // AssumingAllOp //===----------------------------------------------------------------------===// @@ -529,10 +535,10 @@ return 
BoolAttr::get(getContext(), true); } -static LogicalResult verify(AssumingAllOp op) { +LogicalResult AssumingAllOp::verify() { // Ensure that AssumingAllOp contains at least one operand - if (op.getNumOperands() == 0) - return op.emitOpError("no operands specified"); + if (getNumOperands() == 0) + return emitOpError("no operands specified"); return success(); } @@ -575,8 +581,8 @@ return builder.getIndexTensorAttr(resultShape); } -static LogicalResult verify(BroadcastOp op) { - return verifyShapeOrExtentTensorOp(op); +LogicalResult BroadcastOp::verify() { + return verifyShapeOrExtentTensorOp(*this); } namespace { @@ -912,10 +918,10 @@ return nullptr; } -static LogicalResult verify(CstrBroadcastableOp op) { +LogicalResult CstrBroadcastableOp::verify() { // Ensure that AssumingAllOp contains at least one operand - if (op.getNumOperands() < 2) - return op.emitOpError("required at least 2 input shapes"); + if (getNumOperands() < 2) + return emitOpError("required at least 2 input shapes"); return success(); } @@ -1016,6 +1022,8 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // ShapeEqOp //===----------------------------------------------------------------------===// @@ -1172,6 +1180,8 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // IsBroadcastableOp //===----------------------------------------------------------------------===// @@ -1298,6 +1308,8 @@ return eachHasOnlyOneOfTypes(l, r); } +LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // NumElementsOp //===----------------------------------------------------------------------===// @@ -1333,6 +1345,10 @@ return 
eachHasOnlyOneOfTypes(l, r); } +LogicalResult shape::NumElementsOp::verify() { + return verifySizeOrIndexOp(*this); +} + //===----------------------------------------------------------------------===// // MaxOp //===----------------------------------------------------------------------===// @@ -1429,6 +1445,9 @@ // SizeType is compatible with IndexType. return eachHasOnlyOneOfTypes(l, r); } + +LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); } + //===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// @@ -1535,6 +1554,10 @@ return false; } +LogicalResult shape::ShapeOfOp::verify() { + return verifyShapeOrExtentTensorOp(*this); +} + //===----------------------------------------------------------------------===// // SizeToIndexOp //===----------------------------------------------------------------------===// @@ -1556,18 +1579,17 @@ // YieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(shape::YieldOp op) { - auto *parentOp = op->getParentOp(); +LogicalResult shape::YieldOp::verify() { + auto *parentOp = (*this)->getParentOp(); auto results = parentOp->getResults(); - auto operands = op.getOperands(); + auto operands = getOperands(); - if (parentOp->getNumResults() != op.getNumOperands()) - return op.emitOpError() << "number of operands does not match number of " - "results of its parent"; + if (parentOp->getNumResults() != getNumOperands()) + return emitOpError() << "number of operands does not match number of " + "results of its parent"; for (auto e : llvm::zip(results, operands)) if (std::get<0>(e).getType() != std::get<1>(e).getType()) - return op.emitOpError() - << "types mismatch between yield op and its parent"; + return emitOpError() << "types mismatch between yield op and its parent"; return success(); } @@ -1639,41 +1661,42 @@ } } -static LogicalResult 
verify(ReduceOp op) { +LogicalResult ReduceOp::verify() { // Verify block arg types. - Block &block = op.getRegion().front(); + Block &block = getRegion().front(); // The block takes index, extent, and aggregated values as arguments. - auto blockArgsCount = op.getInitVals().size() + 2; + auto blockArgsCount = getInitVals().size() + 2; if (block.getNumArguments() != blockArgsCount) - return op.emitOpError() << "ReduceOp body is expected to have " - << blockArgsCount << " arguments"; + return emitOpError() << "ReduceOp body is expected to have " + << blockArgsCount << " arguments"; // The first block argument is the index and must always be of type `index`. if (!block.getArgument(0).getType().isa()) - return op.emitOpError( + return emitOpError( "argument 0 of ReduceOp body is expected to be of IndexType"); // The second block argument is the extent and must be of type `size` or // `index`, depending on whether the reduce operation is applied to a shape or // to an extent tensor. Type extentTy = block.getArgument(1).getType(); - if (op.getShape().getType().isa()) { + if (getShape().getType().isa()) { if (!extentTy.isa()) - return op.emitOpError("argument 1 of ReduceOp body is expected to be of " - "SizeType if the ReduceOp operates on a ShapeType"); + return emitOpError("argument 1 of ReduceOp body is expected to be of " + "SizeType if the ReduceOp operates on a ShapeType"); } else { if (!extentTy.isa()) - return op.emitOpError( + return emitOpError( "argument 1 of ReduceOp body is expected to be of IndexType if the " "ReduceOp operates on an extent tensor"); } - for (const auto &type : llvm::enumerate(op.getInitVals())) + for (const auto &type : llvm::enumerate(getInitVals())) if (block.getArgument(type.index() + 2).getType() != type.value().getType()) - return op.emitOpError() - << "type mismatch between argument " << type.index() + 2 - << " of ReduceOp body and initial value " << type.index(); + return emitOpError() << "type mismatch between argument " + << 
type.index() + 2 + << " of ReduceOp body and initial value " + << type.index(); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -209,53 +209,51 @@ return failure(); } -static LogicalResult verify(NewOp op) { - if (!getSparseTensorEncoding(op.result().getType())) - return op.emitError("expected a sparse tensor result"); +LogicalResult NewOp::verify() { + if (!getSparseTensorEncoding(result().getType())) + return emitError("expected a sparse tensor result"); return success(); } -static LogicalResult verify(InitOp op) { - if (!getSparseTensorEncoding(op.result().getType())) - return op.emitError("expected a sparse tensor result"); - RankedTensorType ttp = op.getType().cast(); +LogicalResult InitOp::verify() { + if (!getSparseTensorEncoding(result().getType())) + return emitError("expected a sparse tensor result"); + RankedTensorType ttp = getType().cast(); unsigned rank = ttp.getRank(); - if (rank != op.sizes().size()) - return op.emitError("unexpected mismatch between tensor rank and sizes: ") - << rank << " vs. " << op.sizes().size(); + if (rank != sizes().size()) + return emitError("unexpected mismatch between tensor rank and sizes: ") + << rank << " vs. 
" << sizes().size(); auto shape = ttp.getShape(); for (unsigned i = 0; i < rank; i++) { if (shape[i] == ShapedType::kDynamicSize) continue; IntegerAttr constantAttr; - if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) || + if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) || constantAttr.getInt() != shape[i]) { - return op.emitError("unexpected mismatch with static dimension size ") + return emitError("unexpected mismatch with static dimension size ") << shape[i]; } } return success(); } -static LogicalResult verify(ConvertOp op) { - if (auto tp1 = op.source().getType().dyn_cast()) { - if (auto tp2 = op.dest().getType().dyn_cast()) { +LogicalResult ConvertOp::verify() { + if (auto tp1 = source().getType().dyn_cast()) { + if (auto tp2 = dest().getType().dyn_cast()) { if (tp1.getRank() != tp2.getRank()) - return op.emitError("unexpected conversion mismatch in rank"); + return emitError("unexpected conversion mismatch in rank"); auto shape1 = tp1.getShape(); auto shape2 = tp2.getShape(); // Accept size matches between the source and the destination type // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 
- for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) { + for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize) - return op.emitError("unexpected conversion mismatch in dimension ") - << d; - } + return emitError("unexpected conversion mismatch in dimension ") << d; return success(); } } - return op.emitError("unexpected type in convert"); + return emitError("unexpected type in convert"); } OpFoldResult ConvertOp::fold(ArrayRef operands) { @@ -264,35 +262,35 @@ return {}; } -static LogicalResult verify(ToPointersOp op) { - if (auto e = getSparseTensorEncoding(op.tensor().getType())) { - if (failed(isInBounds(op.dim(), op.tensor()))) - return op.emitError("requested pointers dimension out of bounds"); - if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) - return op.emitError("unexpected type for pointers"); +LogicalResult ToPointersOp::verify() { + if (auto e = getSparseTensorEncoding(tensor().getType())) { + if (failed(isInBounds(dim(), tensor()))) + return emitError("requested pointers dimension out of bounds"); + if (failed(isMatchingWidth(result(), e.getPointerBitWidth()))) + return emitError("unexpected type for pointers"); return success(); } - return op.emitError("expected a sparse tensor to get pointers"); + return emitError("expected a sparse tensor to get pointers"); } -static LogicalResult verify(ToIndicesOp op) { - if (auto e = getSparseTensorEncoding(op.tensor().getType())) { - if (failed(isInBounds(op.dim(), op.tensor()))) - return op.emitError("requested indices dimension out of bounds"); - if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) - return op.emitError("unexpected type for indices"); +LogicalResult ToIndicesOp::verify() { + if (auto e = getSparseTensorEncoding(tensor().getType())) { + if (failed(isInBounds(dim(), tensor()))) + return emitError("requested indices dimension out of bounds"); + if (failed(isMatchingWidth(result(), 
e.getIndexBitWidth()))) + return emitError("unexpected type for indices"); return success(); } - return op.emitError("expected a sparse tensor to get indices"); + return emitError("expected a sparse tensor to get indices"); } -static LogicalResult verify(ToValuesOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor to get values"); - RankedTensorType ttp = op.tensor().getType().cast(); - MemRefType mtp = op.result().getType().cast(); +LogicalResult ToValuesOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor to get values"); + RankedTensorType ttp = tensor().getType().cast(); + MemRefType mtp = result().getType().cast(); if (ttp.getElementType() != mtp.getElementType()) - return op.emitError("unexpected mismatch in element types"); + return emitError("unexpected mismatch in element types"); return success(); } @@ -300,39 +298,39 @@ // TensorDialect Management Operations. //===----------------------------------------------------------------------===// -static LogicalResult verify(LexInsertOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor for insertion"); +LogicalResult LexInsertOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor for insertion"); return success(); } -static LogicalResult verify(ExpandOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor for expansion"); +LogicalResult ExpandOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor for expansion"); return success(); } -static LogicalResult verify(CompressOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor for compression"); +LogicalResult CompressOp::verify() { + if 
(!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor for compression"); return success(); } -static LogicalResult verify(LoadOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor to materialize"); +LogicalResult LoadOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor to materialize"); return success(); } -static LogicalResult verify(ReleaseOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor to release"); +LogicalResult ReleaseOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor to release"); return success(); } -static LogicalResult verify(OutOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor for output"); +LogicalResult OutOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor for output"); return success(); } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -612,31 +612,31 @@ /// The constant op requires an attribute, and furthermore requires that it /// matches the return type. 
-static LogicalResult verify(ConstantOp &op) { - auto value = op.getValue(); +LogicalResult ConstantOp::verify() { + auto value = getValue(); if (!value) - return op.emitOpError("requires a 'value' attribute"); + return emitOpError("requires a 'value' attribute"); - Type type = op.getType(); + Type type = getType(); if (!value.getType().isa() && type != value.getType()) - return op.emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; + return emitOpError() << "requires attribute's type (" << value.getType() + << ") to match op's return type (" << type << ")"; if (type.isa()) { auto fnAttr = value.dyn_cast(); if (!fnAttr) - return op.emitOpError("requires 'value' to be a function reference"); + return emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. - auto fn = - op->getParentOfType().lookupSymbol(fnAttr.getValue()); + auto fn = (*this)->getParentOfType().lookupSymbol( + fnAttr.getValue()); if (!fn) - return op.emitOpError() - << "reference to undefined function '" << fnAttr.getValue() << "'"; + return emitOpError() << "reference to undefined function '" + << fnAttr.getValue() << "'"; // Check that the referenced function has the correct type. 
if (fn.getType() != type) - return op.emitOpError("reference to function with mismatched type"); + return emitOpError("reference to function with mismatched type"); return success(); } @@ -644,7 +644,7 @@ if (type.isa() && value.isa()) return success(); - return op.emitOpError("unsupported 'value' attribute: ") << value; + return emitOpError("unsupported 'value' attribute: ") << value; } OpFoldResult ConstantOp::fold(ArrayRef operands) { @@ -676,23 +676,23 @@ // ReturnOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReturnOp op) { - auto function = cast(op->getParentOp()); +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getType().getResults(); - if (op.getNumOperands() != results.size()) - return op.emitOpError("has ") - << op.getNumOperands() << " operands, but enclosing function (@" + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) - if (op.getOperand(i).getType() != results[i]) - return op.emitError() - << "type of return operand " << i << " (" - << op.getOperand(i).getType() - << ") doesn't match function result type (" << results[i] << ")" - << " in function @" << function.getName(); + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); return success(); } @@ -843,24 +843,23 @@ parser.getNameLoc(), result.operands); } -static LogicalResult verify(SelectOp op) { - Type conditionType = op.getCondition().getType(); +LogicalResult SelectOp::verify() { + Type conditionType = 
getCondition().getType(); if (conditionType.isSignlessInteger(1)) return success(); // If the result type is a vector or tensor, the type can be a mask with the // same elements. - Type resultType = op.getType(); + Type resultType = getType(); if (!resultType.isa()) - return op.emitOpError() - << "expected condition to be a signless i1, but got " - << conditionType; + return emitOpError() << "expected condition to be a signless i1, but got " + << conditionType; Type shapedConditionType = getI1SameShape(resultType); if (conditionType != shapedConditionType) - return op.emitOpError() - << "expected condition type to have the same shape " - "as the result type, expected " - << shapedConditionType << ", but got " << conditionType; + return emitOpError() << "expected condition type to have the same shape " + "as the result type, expected " + << shapedConditionType << ", but got " + << conditionType; return success(); } @@ -868,11 +867,10 @@ // SplatOp //===----------------------------------------------------------------------===// -static LogicalResult verify(SplatOp op) { +LogicalResult SplatOp::verify() { // TODO: we could replace this by a trait. 
- if (op.getOperand().getType() != - op.getType().cast().getElementType()) - return op.emitError("operand should be of elemental type of result type"); + if (getOperand().getType() != getType().cast().getElementType()) + return emitError("operand should be of elemental type of result type"); return success(); } @@ -995,26 +993,26 @@ p.printNewline(); } -static LogicalResult verify(SwitchOp op) { - auto caseValues = op.getCaseValues(); - auto caseDestinations = op.getCaseDestinations(); +LogicalResult SwitchOp::verify() { + auto caseValues = getCaseValues(); + auto caseDestinations = getCaseDestinations(); if (!caseValues && caseDestinations.empty()) return success(); - Type flagType = op.getFlag().getType(); + Type flagType = getFlag().getType(); Type caseValueType = caseValues->getType().getElementType(); if (caseValueType != flagType) - return op.emitOpError() - << "'flag' type (" << flagType << ") should match case value type (" - << caseValueType << ")"; + return emitOpError() << "'flag' type (" << flagType + << ") should match case value type (" << caseValueType + << ")"; if (caseValues && caseValues->size() != static_cast(caseDestinations.size())) - return op.emitOpError() << "number of case values (" << caseValues->size() - << ") should match number of " - "case destinations (" - << caseDestinations.size() << ")"; + return emitOpError() << "number of case values (" << caseValues->size() + << ") should match number of " + "case destinations (" + << caseDestinations.size() << ")"; return success(); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -228,17 +228,17 @@ return {}; } -static LogicalResult verify(DimOp op) { +LogicalResult DimOp::verify() { // Assume unknown index to be in range. 
- Optional index = op.getConstantIndex(); + Optional index = getConstantIndex(); if (!index.hasValue()) return success(); // Check that constant index is not knowingly out of range. - auto type = op.source().getType(); + auto type = source().getType(); if (auto tensorType = type.dyn_cast()) { if (index.getValue() >= tensorType.getRank()) - return op.emitOpError("index is out of range"); + return emitOpError("index is out of range"); } else if (type.isa()) { // Assume index to be in range. } else { @@ -328,11 +328,11 @@ // ExtractOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ExtractOp op) { +LogicalResult ExtractOp::verify() { // Verify the # indices match if we have a ranked type. - if (auto tensorType = op.tensor().getType().dyn_cast()) - if (tensorType.getRank() != static_cast(op.indices().size())) - return op.emitOpError("incorrect number of indices for extract_element"); + if (auto tensorType = tensor().getType().dyn_cast()) + if (tensorType.getRank() != static_cast(indices().size())) + return emitOpError("incorrect number of indices for extract_element"); return success(); } @@ -480,11 +480,11 @@ // InsertOp //===----------------------------------------------------------------------===// -static LogicalResult verify(InsertOp op) { +LogicalResult InsertOp::verify() { // Verify the # indices match if we have a ranked type. 
- if (auto destType = op.dest().getType().dyn_cast()) - if (destType.getRank() != static_cast(op.indices().size())) - return op.emitOpError("incorrect number of indices"); + if (auto destType = dest().getType().dyn_cast()) + if (destType.getRank() != static_cast(indices().size())) + return emitOpError("incorrect number of indices"); return success(); } @@ -502,27 +502,26 @@ // GenerateOp //===----------------------------------------------------------------------===// -static LogicalResult verify(GenerateOp op) { +LogicalResult GenerateOp::verify() { // Ensure that the tensor type has as many dynamic dimensions as are specified // by the operands. - RankedTensorType resultTy = op.getType().cast(); - if (op.getNumOperands() != resultTy.getNumDynamicDims()) - return op.emitError("must have as many index operands as dynamic extents " - "in the result type"); + RankedTensorType resultTy = getType().cast(); + if (getNumOperands() != resultTy.getNumDynamicDims()) + return emitError("must have as many index operands as dynamic extents " + "in the result type"); // Ensure that region arguments span the index space. - if (!llvm::all_of(op.body().getArgumentTypes(), + if (!llvm::all_of(body().getArgumentTypes(), [](Type ty) { return ty.isIndex(); })) - return op.emitError("all body arguments must be index"); - if (op.body().getNumArguments() != resultTy.getRank()) - return op.emitError("must have one body argument per input dimension"); + return emitError("all body arguments must be index"); + if (body().getNumArguments() != resultTy.getRank()) + return emitError("must have one body argument per input dimension"); // Ensure that the region yields an element of the right type. 
- auto yieldOp = - llvm::cast(op.body().getBlocks().front().getTerminator()); + auto yieldOp = cast(body().getBlocks().front().getTerminator()); if (yieldOp.value().getType() != resultTy.getElementType()) - return op.emitOpError( + return emitOpError( "body must be terminated with a `yield` operation of the tensor " "element type"); @@ -686,16 +685,15 @@ return numElements; } -static LogicalResult verify(ReshapeOp op) { - TensorType operandType = op.source().getType().cast(); - TensorType resultType = op.result().getType().cast(); +LogicalResult ReshapeOp::verify() { + TensorType operandType = source().getType().cast(); + TensorType resultType = result().getType().cast(); if (operandType.getElementType() != resultType.getElementType()) - return op.emitOpError("element types of source and destination tensor " - "types should be the same"); + return emitOpError("element types of source and destination tensor " + "types should be the same"); - int64_t shapeSize = - op.shape().getType().cast().getDimSize(0); + int64_t shapeSize = shape().getType().cast().getDimSize(0); auto resultRankedType = resultType.dyn_cast(); auto operandRankedType = operandType.dyn_cast(); @@ -703,14 +701,14 @@ if (operandRankedType && resultRankedType.hasStaticShape() && operandRankedType.hasStaticShape()) { if (getNumElements(operandRankedType) != getNumElements(resultRankedType)) - return op.emitOpError("source and destination tensor should have the " - "same number of elements"); + return emitOpError("source and destination tensor should have the " + "same number of elements"); } if (ShapedType::isDynamic(shapeSize)) - return op.emitOpError("cannot use shape operand with dynamic length to " - "reshape to statically-ranked tensor type"); + return emitOpError("cannot use shape operand with dynamic length to " + "reshape to statically-ranked tensor type"); if (shapeSize != resultRankedType.getRank()) - return op.emitOpError( + return emitOpError( "length of shape operand differs from the 
result's tensor rank"); } return success(); @@ -814,12 +812,12 @@ return success(); } -static LogicalResult verify(ExpandShapeOp op) { - return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType()); +LogicalResult ExpandShapeOp::verify() { + return verifyTensorReshapeOp(*this, getResultType(), getSrcType()); } -static LogicalResult verify(CollapseShapeOp op) { - return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType()); +LogicalResult CollapseShapeOp::verify() { + return verifyTensorReshapeOp(*this, getSrcType(), getResultType()); } namespace { @@ -1052,14 +1050,12 @@ } /// Verifier for ExtractSliceOp. -static LogicalResult verify(ExtractSliceOp op) { +LogicalResult ExtractSliceOp::verify() { // Verify result type against inferred type. - auto expectedType = - ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(), - op.getMixedSizes(), op.getMixedStrides()); - auto result = - isRankReducedType(expectedType.cast(), op.getType()); - return produceSliceErrorMsg(result, op, expectedType); + auto expectedType = ExtractSliceOp::inferResultType( + getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides()); + auto result = isRankReducedType(expectedType.cast(), getType()); + return produceSliceErrorMsg(result, *this, expectedType); } /// Infer the canonical type of the result of an extract_slice op. Returns a @@ -1308,16 +1304,16 @@ } /// Verifier for InsertSliceOp. -static LogicalResult verify(InsertSliceOp op) { +LogicalResult InsertSliceOp::verify() { // insert_slice is the inverse of extract_slice, use the same type inference. 
auto expectedType = ExtractSliceOp::inferRankReducedResultType( - op.getSourceType().getRank(), op.getType(), - extractFromI64ArrayAttr(op.static_offsets()), - extractFromI64ArrayAttr(op.static_sizes()), - extractFromI64ArrayAttr(op.static_strides())); + getSourceType().getRank(), getType(), + extractFromI64ArrayAttr(static_offsets()), + extractFromI64ArrayAttr(static_sizes()), + extractFromI64ArrayAttr(static_strides())); auto result = - isRankReducedType(expectedType.cast(), op.getSourceType()); - return produceSliceErrorMsg(result, op, expectedType); + isRankReducedType(expectedType.cast(), getSourceType()); + return produceSliceErrorMsg(result, *this, expectedType); } /// If we have two consecutive InsertSliceOp writing to the same slice, we @@ -1569,40 +1565,40 @@ return success(); } -static LogicalResult verify(PadOp op) { - auto sourceType = op.source().getType().cast(); - auto resultType = op.result().getType().cast(); - auto expectedType = PadOp::inferResultType( - sourceType, extractFromI64ArrayAttr(op.static_low()), - extractFromI64ArrayAttr(op.static_high())); +LogicalResult PadOp::verify() { + auto sourceType = source().getType().cast(); + auto resultType = result().getType().cast(); + auto expectedType = + PadOp::inferResultType(sourceType, extractFromI64ArrayAttr(static_low()), + extractFromI64ArrayAttr(static_high())); for (int i = 0, e = sourceType.getRank(); i < e; ++i) { if (resultType.getDimSize(i) == expectedType.getDimSize(i)) continue; if (expectedType.isDynamicDim(i)) continue; - return op.emitError("specified type ") + return emitError("specified type ") << resultType << " does not match the inferred type " << expectedType; } - auto &region = op.region(); + auto &region = getRegion(); unsigned rank = resultType.getRank(); Block &block = region.front(); if (block.getNumArguments() != rank) - return op.emitError("expected the block to have ") << rank << " arguments"; + return emitError("expected the block to have ") << rank << " arguments"; //
Note: the number and type of yield values are checked in the YieldOp. for (const auto &en : llvm::enumerate(block.getArgumentTypes())) { if (!en.value().isIndex()) - return op.emitOpError("expected block argument ") + return emitOpError("expected block argument ") << (en.index() + 1) << " to be an index"; } // Ensure that the region yields an element of the right type. auto yieldOp = llvm::cast(block.getTerminator()); if (yieldOp.value().getType() != - op.getType().cast().getElementType()) - return op.emitOpError("expected yield type to match shape element type"); + getType().cast().getElementType()) + return emitOpError("expected yield type to match shape element type"); return success(); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -701,9 +701,9 @@ return success(); } -static LogicalResult verifyAveragePoolOp(tosa::AvgPool2dOp op) { - auto inputETy = op.input().getType().cast().getElementType(); - auto resultETy = op.getType().cast().getElementType(); +LogicalResult tosa::AvgPool2dOp::verify() { + auto inputETy = input().getType().cast().getElementType(); + auto resultETy = getType().cast().getElementType(); if (auto quantType = inputETy.dyn_cast()) inputETy = quantType.getStorageType(); @@ -718,7 +718,7 @@ if (inputETy.isInteger(16) && resultETy.isInteger(16)) return success(); - return op.emitOpError("input/output element types are incompatible."); + return emitOpError("input/output element types are incompatible."); } //===----------------------------------------------------------------------===// @@ -1010,6 +1010,8 @@ return success(); } +LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); } + LogicalResult tosa::MatMulOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1632,6 +1634,8 @@ return 
success(); } +LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); } + LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1708,6 +1712,8 @@ return success(); } +LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); } + LogicalResult AvgPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1800,6 +1806,8 @@ return success(); } +LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); } + LogicalResult TransposeConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -354,15 +354,15 @@ CombiningKindAttr::get(kind, builder.getContext())); } -static LogicalResult verify(MultiDimReductionOp op) { - auto reductionMask = op.getReductionMask(); +LogicalResult MultiDimReductionOp::verify() { + auto reductionMask = getReductionMask(); auto targetType = MultiDimReductionOp::inferDestType( - op.getSourceVectorType().getShape(), reductionMask, - op.getSourceVectorType().getElementType()); + getSourceVectorType().getShape(), reductionMask, + getSourceVectorType().getElementType()); // TODO: update to support 0-d vectors when available. 
- if (targetType != op.getDestType()) - return op.emitError("invalid output vector type: ") - << op.getDestType() << " (expected: " << targetType << ")"; + if (targetType != getDestType()) + return emitError("invalid output vector type: ") + << getDestType() << " (expected: " << targetType << ")"; return success(); } @@ -377,29 +377,29 @@ // ReductionOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReductionOp op) { +LogicalResult ReductionOp::verify() { // Verify for 1-D vector. - int64_t rank = op.getVectorType().getRank(); + int64_t rank = getVectorType().getRank(); if (rank != 1) - return op.emitOpError("unsupported reduction rank: ") << rank; + return emitOpError("unsupported reduction rank: ") << rank; // Verify supported reduction kind. - StringRef strKind = op.kind(); + StringRef strKind = kind(); auto maybeKind = symbolizeCombiningKind(strKind); if (!maybeKind) - return op.emitOpError("unknown reduction kind: ") << strKind; + return emitOpError("unknown reduction kind: ") << strKind; - Type eltType = op.dest().getType(); + Type eltType = dest().getType(); if (!isSupportedCombiningKind(*maybeKind, eltType)) - return op.emitOpError("unsupported reduction type '") - << eltType << "' for kind '" << op.kind() << "'"; + return emitOpError("unsupported reduction type '") + << eltType << "' for kind '" << strKind << "'"; // Verify optional accumulator. 
- if (!op.acc().empty()) { + if (!acc().empty()) { if (strKind != "add" && strKind != "mul") - return op.emitOpError("no accumulator for reduction kind: ") << strKind; + return emitOpError("no accumulator for reduction kind: ") << strKind; if (!eltType.isa()) - return op.emitOpError("no accumulator for type: ") << eltType; + return emitOpError("no accumulator for type: ") << eltType; } return success(); @@ -676,78 +676,78 @@ return success(); } -static LogicalResult verify(ContractionOp op) { - auto lhsType = op.getLhsType(); - auto rhsType = op.getRhsType(); - auto accType = op.getAccType(); - auto resType = op.getResultType(); +LogicalResult ContractionOp::verify() { + auto lhsType = getLhsType(); + auto rhsType = getRhsType(); + auto accType = getAccType(); + auto resType = getResultType(); // Verify that an indexing map was specified for each vector operand. - if (op.indexing_maps().size() != 3) - return op.emitOpError("expected an indexing map for each vector operand"); + if (indexing_maps().size() != 3) + return emitOpError("expected an indexing map for each vector operand"); // Verify that each index map has 'numIterators' inputs, no symbols, and // that the number of map outputs equals the rank of its associated // vector operand. - unsigned numIterators = op.iterator_types().getValue().size(); - for (const auto &it : llvm::enumerate(op.indexing_maps())) { + unsigned numIterators = iterator_types().getValue().size(); + for (const auto &it : llvm::enumerate(indexing_maps())) { auto index = it.index(); auto map = it.value().cast().getValue(); if (map.getNumSymbols() != 0) - return op.emitOpError("expected indexing map ") + return emitOpError("expected indexing map ") << index << " to have no symbols"; - auto vectorType = op.getOperand(index).getType().dyn_cast(); + auto vectorType = getOperand(index).getType().dyn_cast(); unsigned rank = vectorType ? vectorType.getShape().size() : 0; // Verify that the map has the right number of inputs, outputs, and indices. 
// This also correctly accounts for (..) -> () for rank-0 results. if (map.getNumDims() != numIterators) - return op.emitOpError("expected indexing map ") + return emitOpError("expected indexing map ") << index << " to have " << numIterators << " number of inputs"; if (map.getNumResults() != rank) - return op.emitOpError("expected indexing map ") + return emitOpError("expected indexing map ") << index << " to have " << rank << " number of outputs"; if (!map.isProjectedPermutation()) - return op.emitOpError("expected indexing map ") + return emitOpError("expected indexing map ") << index << " to be a projected permutation of its inputs"; } - auto contractingDimMap = op.getContractingDimMap(); - auto batchDimMap = op.getBatchDimMap(); + auto contractingDimMap = getContractingDimMap(); + auto batchDimMap = getBatchDimMap(); // Verify at least one contracting dimension pair was specified. if (contractingDimMap.empty()) - return op.emitOpError("expected at least one contracting dimension pair"); + return emitOpError("expected at least one contracting dimension pair"); // Verify contracting dimension map was properly constructed. if (!verifyDimMap(lhsType, rhsType, contractingDimMap)) - return op.emitOpError("invalid contracting dimension map"); + return emitOpError("invalid contracting dimension map"); // Verify batch dimension map was properly constructed. if (!verifyDimMap(lhsType, rhsType, batchDimMap)) - return op.emitOpError("invalid batch dimension map"); + return emitOpError("invalid batch dimension map"); // Verify 'accType' and 'resType' shape. - if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType, + if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType, contractingDimMap, batchDimMap))) return failure(); // Verify that either two vector masks are set or none are set. 
- auto lhsMaskType = op.getLHSVectorMaskType(); - auto rhsMaskType = op.getRHSVectorMaskType(); + auto lhsMaskType = getLHSVectorMaskType(); + auto rhsMaskType = getRHSVectorMaskType(); if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType)) - return op.emitOpError("invalid number of vector masks specified"); + return emitOpError("invalid number of vector masks specified"); if (lhsMaskType && rhsMaskType) { // Verify mask rank == argument rank. if (lhsMaskType.getShape().size() != lhsType.getShape().size() || rhsMaskType.getShape().size() != rhsType.getShape().size()) - return op.emitOpError("invalid vector mask rank"); + return emitOpError("invalid vector mask rank"); } // Verify supported combining kind. auto vectorType = resType.dyn_cast(); auto elementType = vectorType ? vectorType.getElementType() : resType; - if (!isSupportedCombiningKind(op.kind(), elementType)) - return op.emitOpError("unsupported contraction type"); + if (!isSupportedCombiningKind(kind(), elementType)) + return emitOpError("unsupported contraction type"); return success(); } @@ -923,17 +923,17 @@ result.addTypes(source.getType().cast().getElementType()); } -static LogicalResult verify(vector::ExtractElementOp op) { - VectorType vectorType = op.getVectorType(); +LogicalResult vector::ExtractElementOp::verify() { + VectorType vectorType = getVectorType(); if (vectorType.getRank() == 0) { - if (op.position()) - return op.emitOpError("expected position to be empty with 0-D vector"); + if (position()) + return emitOpError("expected position to be empty with 0-D vector"); return success(); } if (vectorType.getRank() != 1) - return op.emitOpError("unexpected >1 vector rank"); - if (!op.position()) - return op.emitOpError("expected position for 1-D vector"); + return emitOpError("unexpected >1 vector rank"); + if (!position()) + return emitOpError("expected position for 1-D vector"); return success(); } @@ -1003,16 +1003,16 @@ parser.addTypeToList(resType, result.types)); } -static 
LogicalResult verify(vector::ExtractOp op) { - auto positionAttr = op.position().getValue(); - if (positionAttr.size() > static_cast(op.getVectorType().getRank())) - return op.emitOpError( +LogicalResult vector::ExtractOp::verify() { + auto positionAttr = position().getValue(); + if (positionAttr.size() > static_cast(getVectorType().getRank())) + return emitOpError( "expected position attribute of rank smaller than vector rank"); for (const auto &en : llvm::enumerate(positionAttr)) { auto attr = en.value().dyn_cast(); if (!attr || attr.getInt() < 0 || - attr.getInt() >= op.getVectorType().getDimSize(en.index())) - return op.emitOpError("expected position attribute #") + attr.getInt() >= getVectorType().getDimSize(en.index())) + return emitOpError("expected position attribute #") << (en.index() + 1) << " to be a non-negative integer smaller than the corresponding " "vector dimension"; @@ -1565,24 +1565,21 @@ ExtractMapOp::build(builder, result, resultType, vector, ids); } -static LogicalResult verify(ExtractMapOp op) { - if (op.getSourceVectorType().getRank() != op.getResultType().getRank()) - return op.emitOpError( - "expected source and destination vectors of same rank"); +LogicalResult ExtractMapOp::verify() { + if (getSourceVectorType().getRank() != getResultType().getRank()) + return emitOpError("expected source and destination vectors of same rank"); unsigned numId = 0; - for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) { - if (op.getSourceVectorType().getDimSize(i) % - op.getResultType().getDimSize(i) != + for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; ++i) { + if (getSourceVectorType().getDimSize(i) % getResultType().getDimSize(i) != 0) - return op.emitOpError("source vector dimensions must be a multiple of " - "destination vector dimensions"); - if (op.getSourceVectorType().getDimSize(i) != - op.getResultType().getDimSize(i)) + return emitOpError("source vector dimensions must be a multiple of " + "destination 
vector dimensions"); + if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) numId++; } - if (numId != op.ids().size()) - return op.emitOpError("expected number of ids must match the number of " - "dimensions distributed"); + if (numId != ids().size()) + return emitOpError("expected number of ids must match the number of " + "dimensions distributed"); return success(); } @@ -1666,19 +1663,19 @@ return BroadcastableToResult::Success; } -static LogicalResult verify(BroadcastOp op) { +LogicalResult BroadcastOp::verify() { std::pair mismatchingDims; - BroadcastableToResult res = isBroadcastableTo( - op.getSourceType(), op.getVectorType(), &mismatchingDims); + BroadcastableToResult res = + isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims); if (res == BroadcastableToResult::Success) return success(); if (res == BroadcastableToResult::SourceRankHigher) - return op.emitOpError("source rank higher than destination rank"); + return emitOpError("source rank higher than destination rank"); if (res == BroadcastableToResult::DimensionMismatch) - return op.emitOpError("dimension mismatch (") + return emitOpError("dimension mismatch (") << mismatchingDims.first << " vs. " << mismatchingDims.second << ")"; if (res == BroadcastableToResult::SourceTypeNotAVector) - return op.emitOpError("source type is not a vector"); + return emitOpError("source type is not a vector"); llvm_unreachable("unexpected vector.broadcast op error"); } @@ -1741,36 +1738,35 @@ p << " : " << op.v1().getType() << ", " << op.v2().getType(); } -static LogicalResult verify(ShuffleOp op) { - VectorType resultType = op.getVectorType(); - VectorType v1Type = op.getV1VectorType(); - VectorType v2Type = op.getV2VectorType(); +LogicalResult ShuffleOp::verify() { + VectorType resultType = getVectorType(); + VectorType v1Type = getV1VectorType(); + VectorType v2Type = getV2VectorType(); // Verify ranks. 
int64_t resRank = resultType.getRank(); int64_t v1Rank = v1Type.getRank(); int64_t v2Rank = v2Type.getRank(); if (resRank != v1Rank || v1Rank != v2Rank) - return op.emitOpError("rank mismatch"); + return emitOpError("rank mismatch"); // Verify all but leading dimension sizes. for (int64_t r = 1; r < v1Rank; ++r) { int64_t resDim = resultType.getDimSize(r); int64_t v1Dim = v1Type.getDimSize(r); int64_t v2Dim = v2Type.getDimSize(r); if (resDim != v1Dim || v1Dim != v2Dim) - return op.emitOpError("dimension mismatch"); + return emitOpError("dimension mismatch"); } // Verify mask length. - auto maskAttr = op.mask().getValue(); + auto maskAttr = mask().getValue(); int64_t maskLength = maskAttr.size(); if (maskLength != resultType.getDimSize(0)) - return op.emitOpError("mask length mismatch"); + return emitOpError("mask length mismatch"); // Verify all indices. int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0); for (const auto &en : llvm::enumerate(maskAttr)) { auto attr = en.value().dyn_cast(); if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) - return op.emitOpError("mask index #") - << (en.index() + 1) << " out of range"; + return emitOpError("mask index #") << (en.index() + 1) << " out of range"; } return success(); } @@ -1824,17 +1820,17 @@ result.addTypes(dest.getType()); } -static LogicalResult verify(InsertElementOp op) { - auto dstVectorType = op.getDestVectorType(); +LogicalResult InsertElementOp::verify() { + auto dstVectorType = getDestVectorType(); if (dstVectorType.getRank() == 0) { - if (op.position()) - return op.emitOpError("expected position to be empty with 0-D vector"); + if (position()) + return emitOpError("expected position to be empty with 0-D vector"); return success(); } if (dstVectorType.getRank() != 1) - return op.emitOpError("unexpected >1 vector rank"); - if (!op.position()) - return op.emitOpError("expected position for 1-D vector"); + return emitOpError("unexpected >1 vector rank"); + if (!position()) + return 
emitOpError("expected position for 1-D vector"); return success(); } @@ -1860,27 +1856,27 @@ build(builder, result, source, dest, positionConstants); } -static LogicalResult verify(InsertOp op) { - auto positionAttr = op.position().getValue(); - auto destVectorType = op.getDestVectorType(); +LogicalResult InsertOp::verify() { + auto positionAttr = position().getValue(); + auto destVectorType = getDestVectorType(); if (positionAttr.size() > static_cast(destVectorType.getRank())) - return op.emitOpError( + return emitOpError( "expected position attribute of rank smaller than dest vector rank"); - auto srcVectorType = op.getSourceType().dyn_cast(); + auto srcVectorType = getSourceType().dyn_cast(); if (srcVectorType && (static_cast(srcVectorType.getRank()) + positionAttr.size() != static_cast(destVectorType.getRank()))) - return op.emitOpError("expected position attribute rank + source rank to " + return emitOpError("expected position attribute rank + source rank to " "match dest vector rank"); if (!srcVectorType && (positionAttr.size() != static_cast(destVectorType.getRank()))) - return op.emitOpError( + return emitOpError( "expected position attribute rank to match the dest vector rank"); for (const auto &en : llvm::enumerate(positionAttr)) { auto attr = en.value().dyn_cast(); if (!attr || attr.getInt() < 0 || attr.getInt() >= destVectorType.getDimSize(en.index())) - return op.emitOpError("expected position attribute #") + return emitOpError("expected position attribute #") << (en.index() + 1) << " to be a non-negative integer smaller than the corresponding " "dest vector dimension"; @@ -1933,24 +1929,21 @@ InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids); } -static LogicalResult verify(InsertMapOp op) { - if (op.getSourceVectorType().getRank() != op.getResultType().getRank()) - return op.emitOpError( - "expected source and destination vectors of same rank"); +LogicalResult InsertMapOp::verify() { + if (getSourceVectorType().getRank() != 
getResultType().getRank()) + return emitOpError("expected source and destination vectors of same rank"); unsigned numId = 0; - for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) { - if (op.getResultType().getDimSize(i) % - op.getSourceVectorType().getDimSize(i) != + for (unsigned i = 0, e = getResultType().getRank(); i < e; i++) { + if (getResultType().getDimSize(i) % getSourceVectorType().getDimSize(i) != 0) - return op.emitOpError( + return emitOpError( "destination vector size must be a multiple of source vector size"); - if (op.getResultType().getDimSize(i) != - op.getSourceVectorType().getDimSize(i)) + if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i)) numId++; } - if (numId != op.ids().size()) - return op.emitOpError("expected number of ids must match the number of " - "dimensions distributed"); + if (numId != ids().size()) + return emitOpError("expected number of ids must match the number of " + "dimensions distributed"); return success(); } @@ -2062,19 +2055,18 @@ return ArrayAttr::get(context, llvm::to_vector<8>(attrs)); } -static LogicalResult verify(InsertStridedSliceOp op) { - auto sourceVectorType = op.getSourceVectorType(); - auto destVectorType = op.getDestVectorType(); - auto offsets = op.offsets(); - auto strides = op.strides(); +LogicalResult InsertStridedSliceOp::verify() { + auto sourceVectorType = getSourceVectorType(); + auto destVectorType = getDestVectorType(); + auto offsets = offsetsAttr(); + auto strides = stridesAttr(); if (offsets.size() != static_cast(destVectorType.getRank())) - return op.emitOpError( + return emitOpError( "expected offsets of same size as destination vector rank"); if (strides.size() != static_cast(sourceVectorType.getRank())) - return op.emitOpError( - "expected strides of same size as source vector rank"); + return emitOpError("expected strides of same size as source vector rank"); if (sourceVectorType.getRank() > destVectorType.getRank()) - return op.emitOpError( + return 
emitOpError( "expected source rank to be smaller than destination rank"); auto sourceShape = sourceVectorType.getShape(); @@ -2084,13 +2076,14 @@ sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end()); auto offName = InsertStridedSliceOp::getOffsetsAttrName(); auto stridesName = InsertStridedSliceOp::getStridesAttrName(); - if (failed( - isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) || - failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName, + if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape, + offName)) || + failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, + stridesName, /*halfOpen=*/false)) || failed(isSumOfIntegerArrayAttrConfinedToShape( - op, offsets, - makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape, + *this, offsets, + makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape, offName, "source vector shape", /*halfOpen=*/false, /*min=*/1))) return failure(); @@ -2161,39 +2154,39 @@ parser.addTypeToList(resType, result.types)); } -static LogicalResult verify(OuterProductOp op) { - Type tRHS = op.getOperandTypeRHS(); - VectorType vLHS = op.getOperandVectorTypeLHS(), +LogicalResult OuterProductOp::verify() { + Type tRHS = getOperandTypeRHS(); + VectorType vLHS = getOperandVectorTypeLHS(), vRHS = tRHS.dyn_cast(), - vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType(); + vACC = getOperandVectorTypeACC(), vRES = getVectorType(); if (vLHS.getRank() != 1) - return op.emitOpError("expected 1-d vector for operand #1"); + return emitOpError("expected 1-d vector for operand #1"); if (vRHS) { // Proper OUTER operation. 
if (vRHS.getRank() != 1) - return op.emitOpError("expected 1-d vector for operand #2"); + return emitOpError("expected 1-d vector for operand #2"); if (vRES.getRank() != 2) - return op.emitOpError("expected 2-d vector result"); + return emitOpError("expected 2-d vector result"); if (vLHS.getDimSize(0) != vRES.getDimSize(0)) - return op.emitOpError("expected #1 operand dim to match result dim #1"); + return emitOpError("expected #1 operand dim to match result dim #1"); if (vRHS.getDimSize(0) != vRES.getDimSize(1)) - return op.emitOpError("expected #2 operand dim to match result dim #2"); + return emitOpError("expected #2 operand dim to match result dim #2"); } else { // An AXPY operation. if (vRES.getRank() != 1) - return op.emitOpError("expected 1-d vector result"); + return emitOpError("expected 1-d vector result"); if (vLHS.getDimSize(0) != vRES.getDimSize(0)) - return op.emitOpError("expected #1 operand dim to match result dim #1"); + return emitOpError("expected #1 operand dim to match result dim #1"); } if (vACC && vACC != vRES) - return op.emitOpError("expected operand #3 of same type as result type"); + return emitOpError("expected operand #3 of same type as result type"); // Verify supported combining kind. - if (!isSupportedCombiningKind(op.kind(), vRES.getElementType())) - return op.emitOpError("unsupported outerproduct type"); + if (!isSupportedCombiningKind(kind(), vRES.getElementType())) + return emitOpError("unsupported outerproduct type"); return success(); } @@ -2202,22 +2195,22 @@ // ReshapeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReshapeOp op) { +LogicalResult ReshapeOp::verify() { // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank. 
- auto inputVectorType = op.getInputVectorType(); - auto outputVectorType = op.getOutputVectorType(); - int64_t inputShapeRank = op.getNumInputShapeSizes(); - int64_t outputShapeRank = op.getNumOutputShapeSizes(); + auto inputVectorType = getInputVectorType(); + auto outputVectorType = getOutputVectorType(); + int64_t inputShapeRank = getNumInputShapeSizes(); + int64_t outputShapeRank = getNumOutputShapeSizes(); SmallVector fixedVectorSizes; - op.getFixedVectorSizes(fixedVectorSizes); + getFixedVectorSizes(fixedVectorSizes); int64_t numFixedVectorSizes = fixedVectorSizes.size(); if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes) - return op.emitError("invalid input shape for vector type ") + return emitError("invalid input shape for vector type ") << inputVectorType; if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes) - return op.emitError("invalid output shape for vector type ") + return emitError("invalid output shape for vector type ") << outputVectorType; // Verify that the 'fixedVectorSizes' match an input/output vector shape @@ -2226,7 +2219,7 @@ for (unsigned i = 0; i < numFixedVectorSizes; ++i) { unsigned index = inputVectorRank - numFixedVectorSizes - i; if (fixedVectorSizes[i] != inputVectorType.getShape()[index]) - return op.emitError("fixed vector size must match input vector for dim ") + return emitError("fixed vector size must match input vector for dim ") << i; } @@ -2234,7 +2227,7 @@ for (unsigned i = 0; i < numFixedVectorSizes; ++i) { unsigned index = outputVectorRank - numFixedVectorSizes - i; if (fixedVectorSizes[i] != outputVectorType.getShape()[index]) - return op.emitError("fixed vector size must match output vector for dim ") + return emitError("fixed vector size must match output vector for dim ") << i; } @@ -2243,18 +2236,18 @@ auto isDefByConstant = [](Value operand) { return isa_and_nonnull(operand.getDefiningOp()); }; - if (llvm::all_of(op.input_shape(), isDefByConstant) && - 
llvm::all_of(op.output_shape(), isDefByConstant)) { + if (llvm::all_of(input_shape(), isDefByConstant) && + llvm::all_of(output_shape(), isDefByConstant)) { int64_t numInputElements = 1; - for (auto operand : op.input_shape()) + for (auto operand : input_shape()) numInputElements *= cast(operand.getDefiningOp()).value(); int64_t numOutputElements = 1; - for (auto operand : op.output_shape()) + for (auto operand : output_shape()) numOutputElements *= cast(operand.getDefiningOp()).value(); if (numInputElements != numOutputElements) - return op.emitError("product of input and output shape sizes must match"); + return emitError("product of input and output shape sizes must match"); } return success(); } @@ -2301,42 +2294,37 @@ result.addAttribute(getStridesAttrName(), stridesAttr); } -static LogicalResult verify(ExtractStridedSliceOp op) { - auto type = op.getVectorType(); - auto offsets = op.offsets(); - auto sizes = op.sizes(); - auto strides = op.strides(); - if (offsets.size() != sizes.size() || offsets.size() != strides.size()) { - op.emitOpError( - "expected offsets, sizes and strides attributes of same size"); - return failure(); - } +LogicalResult ExtractStridedSliceOp::verify() { + auto type = getVectorType(); + auto offsets = offsetsAttr(); + auto sizes = sizesAttr(); + auto strides = stridesAttr(); + if (offsets.size() != sizes.size() || offsets.size() != strides.size()) + return emitOpError("expected offsets, sizes and strides attributes of same size"); auto shape = type.getShape(); - auto offName = ExtractStridedSliceOp::getOffsetsAttrName(); - auto sizesName = ExtractStridedSliceOp::getSizesAttrName(); - auto stridesName = ExtractStridedSliceOp::getStridesAttrName(); - if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) || - failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) || - failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape, + auto offName = getOffsetsAttrName(); + auto sizesName = 
getSizesAttrName(); + auto stridesName = getStridesAttrName(); + if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) || + failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) || + failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape, stridesName)) || - failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) || - failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName, + failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) || + failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName, /*halfOpen=*/false, /*min=*/1)) || - failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName, + failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, stridesName, /*halfOpen=*/false)) || - failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape, + failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, shape, offName, sizesName, /*halfOpen=*/false))) return failure(); - auto resultType = inferStridedSliceOpResultType( - op.getVectorType(), op.offsets(), op.sizes(), op.strides()); - if (op.getResult().getType() != resultType) { - op.emitOpError("expected result type to be ") << resultType; - return failure(); - } + auto resultType = + inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides); + if (getResult().getType() != resultType) + return emitOpError("expected result type to be ") << resultType; return success(); } @@ -2828,44 +2816,43 @@ return parser.addTypeToList(vectorType, result.types); } -static LogicalResult verify(TransferReadOp op) { +LogicalResult TransferReadOp::verify() { // Consistency of elemental types in source and vector. 
- ShapedType shapedType = op.getShapedType(); - VectorType vectorType = op.getVectorType(); - VectorType maskType = op.getMaskType(); - auto paddingType = op.padding().getType(); - auto permutationMap = op.permutation_map(); + ShapedType shapedType = getShapedType(); + VectorType vectorType = getVectorType(); + VectorType maskType = getMaskType(); + auto paddingType = padding().getType(); + auto permutationMap = permutation_map(); auto sourceElementType = shapedType.getElementType(); - if (static_cast(op.indices().size()) != shapedType.getRank()) - return op.emitOpError("requires ") << shapedType.getRank() << " indices"; + if (static_cast(indices().size()) != shapedType.getRank()) + return emitOpError("requires ") << shapedType.getRank() << " indices"; - if (failed( - verifyTransferOp(cast(op.getOperation()), - shapedType, vectorType, maskType, permutationMap, - op.in_bounds() ? *op.in_bounds() : ArrayAttr()))) + if (failed(verifyTransferOp(cast(getOperation()), + shapedType, vectorType, maskType, permutationMap, + in_bounds() ? *in_bounds() : ArrayAttr()))) return failure(); if (auto sourceVectorElementType = sourceElementType.dyn_cast()) { // Source has vector element type. // Check that 'sourceVectorElementType' and 'paddingType' types match. if (sourceVectorElementType != paddingType) - return op.emitOpError( + return emitOpError( "requires source element type and padding type to match."); } else { // Check that 'paddingType' is valid to store in a vector type. if (!VectorType::isValidElementType(paddingType)) - return op.emitOpError("requires valid padding vector elemental type"); + return emitOpError("requires valid padding vector elemental type"); // Check that padding type and vector element types match. 
if (paddingType != sourceElementType) - return op.emitOpError( + return emitOpError( "requires formal padding and source of the same elemental type"); } return verifyPermutationMap(permutationMap, - [&op](Twine t) { return op.emitOpError(t); }); + [&](Twine t) { return emitOpError(t); }); } /// This is a common class used for patterns of the form @@ -3208,29 +3195,28 @@ p << " : " << op.getVectorType() << ", " << op.getShapedType(); } -static LogicalResult verify(TransferWriteOp op) { +LogicalResult TransferWriteOp::verify() { // Consistency of elemental types in shape and vector. - ShapedType shapedType = op.getShapedType(); - VectorType vectorType = op.getVectorType(); - VectorType maskType = op.getMaskType(); - auto permutationMap = op.permutation_map(); + ShapedType shapedType = getShapedType(); + VectorType vectorType = getVectorType(); + VectorType maskType = getMaskType(); + auto permutationMap = permutation_map(); - if (llvm::size(op.indices()) != shapedType.getRank()) - return op.emitOpError("requires ") << shapedType.getRank() << " indices"; + if (llvm::size(indices()) != shapedType.getRank()) + return emitOpError("requires ") << shapedType.getRank() << " indices"; // We do not allow broadcast dimensions on TransferWriteOps for the moment, // as the semantics is unclear. This can be revisited later if necessary. - if (op.hasBroadcastDim()) - return op.emitOpError("should not have broadcast dimensions"); + if (hasBroadcastDim()) + return emitOpError("should not have broadcast dimensions"); - if (failed( - verifyTransferOp(cast(op.getOperation()), - shapedType, vectorType, maskType, permutationMap, - op.in_bounds() ? *op.in_bounds() : ArrayAttr()))) + if (failed(verifyTransferOp(cast(getOperation()), + shapedType, vectorType, maskType, permutationMap, + in_bounds() ? 
*in_bounds() : ArrayAttr()))) return failure(); return verifyPermutationMap(permutationMap, - [&op](Twine t) { return op.emitOpError(t); }); + [&](Twine t) { return emitOpError(t); }); } /// Fold: @@ -3514,25 +3500,25 @@ return success(); } -static LogicalResult verify(vector::LoadOp op) { - VectorType resVecTy = op.getVectorType(); - MemRefType memRefTy = op.getMemRefType(); +LogicalResult vector::LoadOp::verify() { + VectorType resVecTy = getVectorType(); + MemRefType memRefTy = getMemRefType(); - if (failed(verifyLoadStoreMemRefLayout(op, memRefTy))) + if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) return failure(); // Checks for vector memrefs. Type memElemTy = memRefTy.getElementType(); if (auto memVecTy = memElemTy.dyn_cast()) { if (memVecTy != resVecTy) - return op.emitOpError("base memref and result vector types should match"); + return emitOpError("base memref and result vector types should match"); memElemTy = memVecTy.getElementType(); } if (resVecTy.getElementType() != memElemTy) - return op.emitOpError("base and result element types should match"); - if (llvm::size(op.indices()) != memRefTy.getRank()) - return op.emitOpError("requires ") << memRefTy.getRank() << " indices"; + return emitOpError("base and result element types should match"); + if (llvm::size(indices()) != memRefTy.getRank()) + return emitOpError("requires ") << memRefTy.getRank() << " indices"; return success(); } @@ -3546,26 +3532,26 @@ // StoreOp //===----------------------------------------------------------------------===// -static LogicalResult verify(vector::StoreOp op) { - VectorType valueVecTy = op.getVectorType(); - MemRefType memRefTy = op.getMemRefType(); +LogicalResult vector::StoreOp::verify() { + VectorType valueVecTy = getVectorType(); + MemRefType memRefTy = getMemRefType(); - if (failed(verifyLoadStoreMemRefLayout(op, memRefTy))) + if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) return failure(); // Checks for vector memrefs. 
Type memElemTy = memRefTy.getElementType(); if (auto memVecTy = memElemTy.dyn_cast()) { if (memVecTy != valueVecTy) - return op.emitOpError( + return emitOpError( "base memref and valueToStore vector types should match"); memElemTy = memVecTy.getElementType(); } if (valueVecTy.getElementType() != memElemTy) - return op.emitOpError("base and valueToStore element type should match"); - if (llvm::size(op.indices()) != memRefTy.getRank()) - return op.emitOpError("requires ") << memRefTy.getRank() << " indices"; + return emitOpError("base and valueToStore element type should match"); + if (llvm::size(indices()) != memRefTy.getRank()) + return emitOpError("requires ") << memRefTy.getRank() << " indices"; return success(); } @@ -3578,20 +3564,20 @@ // MaskedLoadOp //===----------------------------------------------------------------------===// -static LogicalResult verify(MaskedLoadOp op) { - VectorType maskVType = op.getMaskVectorType(); - VectorType passVType = op.getPassThruVectorType(); - VectorType resVType = op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult MaskedLoadOp::verify() { + VectorType maskVType = getMaskVectorType(); + VectorType passVType = getPassThruVectorType(); + VectorType resVType = getVectorType(); + MemRefType memType = getMemRefType(); if (resVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and result element type should match"); - if (llvm::size(op.indices()) != memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and result element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (resVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected result dim to match mask dim"); + return emitOpError("expected result dim to match mask dim"); if (resVType != passVType) - return op.emitOpError("expected pass_thru 
of same type as result type"); + return emitOpError("expected pass_thru of same type as result type"); return success(); } @@ -3632,17 +3618,17 @@ // MaskedStoreOp //===----------------------------------------------------------------------===// -static LogicalResult verify(MaskedStoreOp op) { - VectorType maskVType = op.getMaskVectorType(); - VectorType valueVType = op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult MaskedStoreOp::verify() { + VectorType maskVType = getMaskVectorType(); + VectorType valueVType = getVectorType(); + MemRefType memType = getMemRefType(); if (valueVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and valueToStore element type should match"); - if (llvm::size(op.indices()) != memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and valueToStore element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected valueToStore dim to match mask dim"); + return emitOpError("expected valueToStore dim to match mask dim"); return success(); } @@ -3682,22 +3668,22 @@ // GatherOp //===----------------------------------------------------------------------===// -static LogicalResult verify(GatherOp op) { - VectorType indVType = op.getIndexVectorType(); - VectorType maskVType = op.getMaskVectorType(); - VectorType resVType = op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult GatherOp::verify() { + VectorType indVType = getIndexVectorType(); + VectorType maskVType = getMaskVectorType(); + VectorType resVType = getVectorType(); + MemRefType memType = getMemRefType(); if (resVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and result element type should match"); - if (llvm::size(op.indices()) != 
memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and result element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (resVType.getDimSize(0) != indVType.getDimSize(0)) - return op.emitOpError("expected result dim to match indices dim"); + return emitOpError("expected result dim to match indices dim"); if (resVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected result dim to match mask dim"); - if (resVType != op.getPassThruVectorType()) - return op.emitOpError("expected pass_thru of same type as result type"); + return emitOpError("expected result dim to match mask dim"); + if (resVType != getPassThruVectorType()) + return emitOpError("expected pass_thru of same type as result type"); return success(); } @@ -3730,20 +3716,20 @@ // ScatterOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ScatterOp op) { - VectorType indVType = op.getIndexVectorType(); - VectorType maskVType = op.getMaskVectorType(); - VectorType valueVType = op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult ScatterOp::verify() { + VectorType indVType = getIndexVectorType(); + VectorType maskVType = getMaskVectorType(); + VectorType valueVType = getVectorType(); + MemRefType memType = getMemRefType(); if (valueVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and valueToStore element type should match"); - if (llvm::size(op.indices()) != memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and valueToStore element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getDimSize(0) != indVType.getDimSize(0)) - return 
op.emitOpError("expected valueToStore dim to match indices dim"); + return emitOpError("expected valueToStore dim to match indices dim"); if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected valueToStore dim to match mask dim"); + return emitOpError("expected valueToStore dim to match mask dim"); return success(); } @@ -3776,20 +3762,20 @@ // ExpandLoadOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ExpandLoadOp op) { - VectorType maskVType = op.getMaskVectorType(); - VectorType passVType = op.getPassThruVectorType(); - VectorType resVType = op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult ExpandLoadOp::verify() { + VectorType maskVType = getMaskVectorType(); + VectorType passVType = getPassThruVectorType(); + VectorType resVType = getVectorType(); + MemRefType memType = getMemRefType(); if (resVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and result element type should match"); - if (llvm::size(op.indices()) != memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and result element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (resVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected result dim to match mask dim"); + return emitOpError("expected result dim to match mask dim"); if (resVType != passVType) - return op.emitOpError("expected pass_thru of same type as result type"); + return emitOpError("expected pass_thru of same type as result type"); return success(); } @@ -3824,17 +3810,17 @@ // CompressStoreOp //===----------------------------------------------------------------------===// -static LogicalResult verify(CompressStoreOp op) { - VectorType maskVType = op.getMaskVectorType(); - VectorType valueVType 
= op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult CompressStoreOp::verify() { + VectorType maskVType = getMaskVectorType(); + VectorType valueVType = getVectorType(); + MemRefType memType = getMemRefType(); if (valueVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and valueToStore element type should match"); - if (llvm::size(op.indices()) != memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and valueToStore element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected valueToStore dim to match mask dim"); + return emitOpError("expected valueToStore dim to match mask dim"); return success(); } @@ -3930,13 +3916,13 @@ return success(); } -static LogicalResult verify(ShapeCastOp op) { - auto sourceVectorType = op.source().getType().dyn_cast_or_null(); - auto resultVectorType = op.result().getType().dyn_cast_or_null(); +LogicalResult ShapeCastOp::verify() { + auto sourceVectorType = source().getType().dyn_cast_or_null(); + auto resultVectorType = result().getType().dyn_cast_or_null(); // Check if source/result are of vector type. 
if (sourceVectorType && resultVectorType) - return verifyVectorShapeCast(op, sourceVectorType, resultVectorType); + return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType); return success(); } @@ -4005,16 +3991,16 @@ // VectorBitCastOp //===----------------------------------------------------------------------===// -static LogicalResult verify(BitCastOp op) { - auto sourceVectorType = op.getSourceVectorType(); - auto resultVectorType = op.getResultVectorType(); +LogicalResult BitCastOp::verify() { + auto sourceVectorType = getSourceVectorType(); + auto resultVectorType = getResultVectorType(); for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) { if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i)) - return op.emitOpError("dimension size mismatch at: ") << i; + return emitOpError("dimension size mismatch at: ") << i; } - DataLayout dataLayout = DataLayout::closest(op); + DataLayout dataLayout = DataLayout::closest(*this); auto sourceElementBits = dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()); auto resultElementBits = @@ -4022,11 +4008,11 @@ if (sourceVectorType.getRank() == 0) { if (sourceElementBits != resultElementBits) - return op.emitOpError("source/result bitwidth of the 0-D vector element " + return emitOpError("source/result bitwidth of the 0-D vector element " "types must be equal"); } else if (sourceElementBits * sourceVectorType.getShape().back() != resultElementBits * resultVectorType.getShape().back()) { - return op.emitOpError( + return emitOpError( "source/result bitwidth of the minor 1-D vectors must be equal"); } @@ -4096,26 +4082,25 @@ memRefType.getMemorySpace())); } -static LogicalResult verify(TypeCastOp op) { - MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType()); +LogicalResult TypeCastOp::verify() { + MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType()); if (!canonicalType.getLayout().isIdentity()) - return op.emitOpError( - "expects 
operand to be a memref with identity layout"); - if (!op.getResultMemRefType().getLayout().isIdentity()) - return op.emitOpError("expects result to be a memref with identity layout"); - if (op.getResultMemRefType().getMemorySpace() != - op.getMemRefType().getMemorySpace()) - return op.emitOpError("expects result in same memory space"); - - auto sourceType = op.getMemRefType(); - auto resultType = op.getResultMemRefType(); + return emitOpError("expects operand to be a memref with identity layout"); + if (!getResultMemRefType().getLayout().isIdentity()) + return emitOpError("expects result to be a memref with identity layout"); + if (getResultMemRefType().getMemorySpace() != + getMemRefType().getMemorySpace()) + return emitOpError("expects result in same memory space"); + + auto sourceType = getMemRefType(); + auto resultType = getResultMemRefType(); if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) != getElementTypeOrSelf(getElementTypeOrSelf(resultType))) - return op.emitOpError( + return emitOpError( "expects result and operand with same underlying scalar type: ") << resultType; if (extractShape(sourceType) != extractShape(resultType)) - return op.emitOpError( + return emitOpError( "expects concatenated result and operand shapes to be equal: ") << resultType; return success(); @@ -4154,27 +4139,27 @@ return vector(); } -static LogicalResult verify(vector::TransposeOp op) { - VectorType vectorType = op.getVectorType(); - VectorType resultType = op.getResultType(); +LogicalResult vector::TransposeOp::verify() { + VectorType vectorType = getVectorType(); + VectorType resultType = getResultType(); int64_t rank = resultType.getRank(); if (vectorType.getRank() != rank) - return op.emitOpError("vector result rank mismatch: ") << rank; + return emitOpError("vector result rank mismatch: ") << rank; // Verify transposition array. 
- auto transpAttr = op.transp().getValue(); + auto transpAttr = transp().getValue(); int64_t size = transpAttr.size(); if (rank != size) - return op.emitOpError("transposition length mismatch: ") << size; + return emitOpError("transposition length mismatch: ") << size; SmallVector seen(rank, false); for (const auto &ta : llvm::enumerate(transpAttr)) { int64_t i = ta.value().cast().getInt(); if (i < 0 || i >= rank) - return op.emitOpError("transposition index out of range: ") << i; + return emitOpError("transposition index out of range: ") << i; if (seen[i]) - return op.emitOpError("duplicate position index: ") << i; + return emitOpError("duplicate position index: ") << i; seen[i] = true; if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i)) - return op.emitOpError("dimension size mismatch at: ") << i; + return emitOpError("dimension size mismatch at: ") << i; } return success(); } @@ -4236,31 +4221,30 @@ // ConstantMaskOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ConstantMaskOp &op) { - auto resultType = op.getResult().getType().cast(); +LogicalResult ConstantMaskOp::verify() { + auto resultType = getResult().getType().cast(); // Check the corner case of 0-D vectors first. if (resultType.getRank() == 0) { - if (op.mask_dim_sizes().size() != 1) - return op->emitError("array attr must have length 1 for 0-D vectors"); - auto dim = op.mask_dim_sizes()[0].cast().getInt(); + if (mask_dim_sizes().size() != 1) + return emitError("array attr must have length 1 for 0-D vectors"); + auto dim = mask_dim_sizes()[0].cast().getInt(); if (dim != 0 && dim != 1) - return op->emitError( - "mask dim size must be either 0 or 1 for 0-D vectors"); + return emitError("mask dim size must be either 0 or 1 for 0-D vectors"); return success(); } // Verify that array attr size matches the rank of the vector result. 
- if (static_cast(op.mask_dim_sizes().size()) != resultType.getRank()) - return op.emitOpError( + if (static_cast(mask_dim_sizes().size()) != resultType.getRank()) + return emitOpError( "must specify array attr of size equal vector result rank"); // Verify that each array attr element is in bounds of corresponding vector // result dimension size. auto resultShape = resultType.getShape(); SmallVector maskDimSizes; - for (const auto &it : llvm::enumerate(op.mask_dim_sizes())) { + for (const auto &it : llvm::enumerate(mask_dim_sizes())) { int64_t attrValue = it.value().cast().getInt(); if (attrValue < 0 || attrValue > resultShape[it.index()]) - return op.emitOpError( + return emitOpError( "array attr of size out of bounds of vector result dimension size"); maskDimSizes.push_back(attrValue); } @@ -4269,8 +4253,8 @@ bool anyZeros = llvm::is_contained(maskDimSizes, 0); bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; }); if (anyZeros && !allZeros) - return op.emitOpError("expected all mask dim sizes to be zeros, " - "as a result of conjunction with zero mask dim"); + return emitOpError("expected all mask dim sizes to be zeros, " + "as a result of conjunction with zero mask dim"); return success(); } @@ -4278,16 +4262,16 @@ // CreateMaskOp //===----------------------------------------------------------------------===// -static LogicalResult verify(CreateMaskOp op) { - auto vectorType = op.getResult().getType().cast(); +LogicalResult CreateMaskOp::verify() { + auto vectorType = getResult().getType().cast(); // Verify that an operand was specified for each result vector each dimension. 
if (vectorType.getRank() == 0) { - if (op->getNumOperands() != 1) - return op.emitOpError( + if (getNumOperands() != 1) + return emitOpError( "must specify exactly one operand for 0-D create_mask"); - } else if (op.getNumOperands() != - op.getResult().getType().cast().getRank()) { - return op.emitOpError( + } else if (getNumOperands() != + getResult().getType().cast().getRank()) { + return emitOpError( "must specify an operand for each result vector dimension"); } return success(); @@ -4342,20 +4326,20 @@ // ScanOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ScanOp op) { - VectorType srcType = op.getSourceType(); - VectorType initialType = op.getInitialValueType(); +LogicalResult ScanOp::verify() { + VectorType srcType = getSourceType(); + VectorType initialType = getInitialValueType(); // Check reduction dimension < rank. int64_t srcRank = srcType.getRank(); - int64_t reductionDim = op.reduction_dim(); + int64_t reductionDim = reduction_dim(); if (reductionDim >= srcRank) - return op.emitOpError("reduction dimension ") + return emitOpError("reduction dimension ") << reductionDim << " has to be less than " << srcRank; // Check that rank(initial_value) = rank(src) - 1. int64_t initialValueRank = initialType.getRank(); if (initialValueRank != srcRank - 1) - return op.emitOpError("initial value rank ") + return emitOpError("initial value rank ") << initialValueRank << " has to be equal to " << srcRank - 1; // Check shapes of initial value and src. 
@@ -4370,7 +4354,7 @@ [](std::tuple s) { return std::get<0>(s) != std::get<1>(s); })) { - return op.emitOpError("incompatible input/initial value shapes"); + return emitOpError("incompatible input/initial value shapes"); } return success(); diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -28,17 +28,15 @@ >(); } -static LogicalResult verify(x86vector::MaskCompressOp op) { - if (op.src() && op.constant_src()) - return emitError(op.getLoc(), "cannot use both src and constant_src"); +LogicalResult x86vector::MaskCompressOp::verify() { + if (src() && constant_src()) + return emitError("cannot use both src and constant_src"); - if (op.src() && (op.src().getType() != op.dst().getType())) - return emitError(op.getLoc(), - "failed to verify that src and dst have same type"); + if (src() && (src().getType() != dst().getType())) + return emitError("failed to verify that src and dst have same type"); - if (op.constant_src() && (op.constant_src()->getType() != op.dst().getType())) + if (constant_src() && (constant_src()->getType() != dst().getType())) return emitError( - op.getLoc(), "failed to verify that constant_src and dst have same type"); return success(); diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -128,19 +128,19 @@ p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); } -static LogicalResult verify(FuncOp op) { +LogicalResult FuncOp::verify() { // If this function is external there is nothing to do. - if (op.isExternal()) + if (isExternal()) return success(); // Verify that the argument list of the function and the arg list of the entry // block line up. The trait already verified that the number of arguments is // the same between the signature and the block. 
- auto fnInputTypes = op.getType().getInputs(); - Block &entryBlock = op.front(); + auto fnInputTypes = getType().getInputs(); + Block &entryBlock = front(); for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) - return op.emitOpError("type of entry block argument #") + return emitOpError("type of entry block argument #") << i << '(' << entryBlock.getArgument(i).getType() << ") must match the type of the corresponding argument in " << "function signature(" << fnInputTypes[i] << ')'; @@ -245,28 +245,28 @@ return {}; } -static LogicalResult verify(ModuleOp op) { +LogicalResult ModuleOp::verify() { // Check that none of the attributes are non-dialect attributes, except for // the symbol related attributes. - for (auto attr : op->getAttrs()) { + for (auto attr : (*this)->getAttrs()) { if (!attr.getName().strref().contains('.') && !llvm::is_contained( ArrayRef{mlir::SymbolTable::getSymbolAttrName(), mlir::SymbolTable::getVisibilityAttrName()}, attr.getName().strref())) - return op.emitOpError() << "can only contain attributes with " - "dialect-prefixed names, found: '" - << attr.getName().getValue() << "'"; + return emitOpError() << "can only contain attributes with " + "dialect-prefixed names, found: '" + << attr.getName().getValue() << "'"; } // Check that there is at most one data layout spec attribute. 
StringRef layoutSpecAttrName; DataLayoutSpecInterface layoutSpec; - for (const NamedAttribute &na : op->getAttrs()) { + for (const NamedAttribute &na : (*this)->getAttrs()) { if (auto spec = na.getValue().dyn_cast()) { if (layoutSpec) { InFlightDiagnostic diag = - op.emitOpError() << "expects at most one data layout attribute"; + emitOpError() << "expects at most one data layout attribute"; diag.attachNote() << "'" << layoutSpecAttrName << "' is a data layout attribute"; diag.attachNote() << "'" << na.getName().getValue() diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -90,7 +90,7 @@ %cst = arith.constant 1 : index %value = memref.alloc() : memref<10xf32> -// expected-error@+1 {{async attribute cannot appear with asyncOperand}} +// expected-error@+1 {{async attribute cannot appear with asyncOperand}} acc.update async(%cst: index) host(%value: memref<10xf32>) attributes {async} // ----- diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -40,10 +40,10 @@ OpBuilder<(ins CArg<"int", "0">:$integer)>]; let parser = [{ foo }]; let printer = [{ bar }]; - let verifier = [{ baz }]; let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Display a graph for debugging purposes.