diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -11,6 +11,7 @@ include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/IR/OpAsmInterface.td" diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -10,6 +10,7 @@ #define COMPLEX_OPS include "mlir/Dialect/Complex/IR/ComplexBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" class Complex_Op traits = []> @@ -143,10 +144,6 @@ let arguments = (ins Complex:$lhs, Complex:$rhs); let results = (outs I1:$result); - let builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{ - build($_builder, $_state, $_builder.getI1Type(), lhs, rhs); - }]>]; let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)"; } @@ -292,10 +289,6 @@ let arguments = (ins Complex:$lhs, Complex:$rhs); let results = (outs I1:$result); - let builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{ - build($_builder, $_state, $_builder.getI1Type(), lhs, rhs); - }]>]; let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)"; } diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -23,6 +23,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include 
"mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -18,6 +18,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/DataLayoutInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -22,6 +22,7 @@ #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/LLVMContext.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" def FMFnnan : BitEnumAttrCase<"nnan", 0x1>; @@ -622,11 +623,6 @@ let arguments = (ins LLVM_ScalarOrVectorOf:$condition, LLVM_Type:$trueValue, LLVM_Type:$falseValue); let results = (outs LLVM_Type:$res); - let builders = [ - OpBuilder<(ins "Value":$condition, "Value":$lhs, "Value":$rhs), - [{ - build($_builder, $_state, lhs.getType(), condition, lhs, rhs); - }]>]; let assemblyFormat = "operands attr-dict `:` type($condition) `,` type($res)"; } 
def LLVM_FreezeOp : LLVM_Op<"freeze", [SameOperandsAndResultType]> { diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -10,6 +10,7 @@ #define MATH_OPS include "mlir/Dialect/Math/IR/MathBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/QuantOps.td --- a/mlir/include/mlir/Dialect/Quant/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/QuantOps.td @@ -14,6 +14,7 @@ #define DIALECT_QUANT_QUANT_OPS_ include "mlir/Dialect/Quant/QuantOpsBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -15,6 +15,7 @@ #define MLIR_DIALECT_SPIRV_IR_ARITHMETIC_OPS include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" class SPV_ArithmeticBinaryOp]; - let assemblyFormat = [{ operands attr-dict `:` type($condition) `,` type($result) }]; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td @@ -208,8 +208,6 @@ SPV_Bool:$result ); - let builders = [OpBuilder<(ins "spirv::Scope")>]; - let assemblyFormat = "$execution_scope attr-dict `:` type($result)"; } diff --git 
a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/Support/PointerLikeTypeTraits.h" diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -11,6 +11,7 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -22,6 +22,7 @@ #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -19,6 +19,7 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include 
"mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" @@ -687,12 +688,6 @@ let results = (outs Index); let verifier = ?; - let builders = [ - OpBuilder<(ins "Value":$tensor), [{ - auto indexType = $_builder.getIndexType(); - build($_builder, $_state, indexType, tensor); - }]>]; - let hasFolder = 1; let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)"; } @@ -775,13 +770,6 @@ AnyType:$false_value); let results = (outs AnyType:$result); - let builders = [ - OpBuilder<(ins "Value":$condition, "Value":$trueValue, - "Value":$falseValue), [{ - $_state.addOperands({condition, trueValue, falseValue}); - $_state.addTypes(trueValue.getType()); - }]>]; - let hasCanonicalizer = 1; let hasFolder = 1; } diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -13,6 +13,7 @@ #ifndef X86VECTOR_OPS #define X86VECTOR_OPS +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h --- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h @@ -17,6 +17,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h.inc" diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1948,8 +1948,6 @@ def SameOperandsShape : NativeOpTrait<"SameOperandsShape">; // Op 
has same operand and result shape. def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">; -// Op has the same operand and result type. -def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">; // Op has the same element type (or type itself, if scalar) for all operands. def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">; // Op has the same operand and result element type (or type itself, if scalar). diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -178,4 +178,8 @@ ]; } +// Op has the same operand and result type. +// TODO: Change from hard coded to utilizing type inference trait. +def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">; + #endif // MLIR_INFERTYPEOPINTERFACE diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2395,12 +2395,6 @@ // spv.GroupNonUniformElectOp //===----------------------------------------------------------------------===// -void spirv::GroupNonUniformElectOp::build(OpBuilder &builder, - OperationState &state, - spirv::Scope scope) { - build(builder, state, builder.getI1Type(), scope); -} - static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) { spirv::Scope scope = groupOp.execution_scope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) @@ -2849,11 +2843,6 @@ // spv.Select //===----------------------------------------------------------------------===// -void spirv::SelectOp::build(OpBuilder &builder, OperationState &state, - Value cond, Value trueValue, Value falseValue) { - build(builder, state, trueValue.getType(), cond, trueValue, falseValue); -} - static LogicalResult verify(spirv::SelectOp 
op) { if (auto conditionTy = op.condition().getType().dyn_cast()) { auto resultVectorTy = op.result().getType().dyn_cast(); diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" using namespace mlir; diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir @@ -25,7 +25,11 @@ // ----- func @bit_field_insert_invalid_insert_type(%base: vector<3xi32>, %insert: vector<2xi32>, %offset: i32, %count: i16) -> vector<3xi32> { - // expected-error @+1 {{all of {base, insert, result} have same type}} + // TODO: tighten after the verification order changes. This is currently only + // verifying that the type verification is failing but not the specific error + // message. In final state the error should refer to mismatch in base and + // insert. 
+ // expected-error @+1 {{type}} %0 = "spv.BitFieldInsert" (%base, %insert, %offset, %count) : (vector<3xi32>, vector<2xi32>, i32, i16) -> vector<3xi32> spv.ReturnValue %0 : vector<3xi32> } @@ -55,7 +59,7 @@ // ----- func @bit_field_u_extract_invalid_result_type(%base: vector<3xi32>, %offset: i32, %count: i16) -> vector<4xi32> { - // expected-error @+1 {{failed to verify that all of {base, result} have same type}} + // expected-error @+1 {{inferred type(s) 'vector<3xi32>' are incompatible with return type(s) of operation 'vector<4xi32>'}} %0 = "spv.BitFieldUExtract" (%base, %offset, %count) : (vector<3xi32>, i32, i16) -> vector<4xi32> spv.ReturnValue %0 : vector<4xi32> } diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -270,7 +270,11 @@ func @select_op(%arg1: vector<4xi1>) -> () { %0 = spv.Constant dense<[2.0, 3.0, 4.0]> : vector<3xf32> %1 = spv.Constant dense<[5, 6, 7]> : vector<3xi32> - // expected-error @+1 {{all of {true_value, false_value, result} have same type}} + // TODO: tighten after the verification order changes. This is currently only + // verifying that the type verification is failing but not the specific error + // message. In final state the error should refer to mismatch in true_value and + // false_value. + // expected-error @+1 {{type}} %2 = "spv.Select"(%arg1, %1, %0) : (vector<4xi1>, vector<3xi32>, vector<3xf32>) -> vector<3xi32> return } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -137,7 +137,11 @@ func @func_with_ops(i1, i32, i64) { ^bb0(%cond : i1, %t : i32, %f : i64): - // expected-error@+1 {{all of {true_value, false_value, result} have same type}} + // TODO: tighten after the verification order changes. 
This is currently only + // verifying that the type verification is failing but not the specific error + // message. In final state the error should refer to mismatch in true_value and + // false_value. + // expected-error@+1 {{type}} %r = "std.select"(%cond, %t, %f) : (i1, i32, i64) -> i32 } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1500,6 +1500,7 @@ srcs = ["include/mlir/Dialect/X86Vector/X86Vector.td"], includes = ["include"], deps = [ + ":InferTypeOpInterfaceTdFiles", ":LLVMOpsTdFiles", ":SideEffectInterfacesTdFiles", ], @@ -1548,6 +1549,7 @@ includes = ["include"], deps = [ ":IR", + ":InferTypeOpInterface", ":LLVMDialect", ":SideEffectInterfaces", ":X86VectorIncGen", @@ -1688,6 +1690,7 @@ ], includes = ["include"], deps = [ + ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ], @@ -1788,6 +1791,7 @@ deps = [ ":ArithmeticDialect", ":IR", + ":InferTypeOpInterface", ":SideEffectInterfaces", ":SparseTensorAttrDefsIncGen", ":SparseTensorOpsIncGen", @@ -1856,6 +1860,7 @@ ":CallInterfacesTdFiles", ":CastInterfacesTdFiles", ":ControlFlowInterfacesTdFiles", + ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ":VectorInterfacesTdFiles", @@ -2519,6 +2524,7 @@ ":CommonFolders", ":ControlFlowInterfaces", ":IR", + ":InferTypeOpInterface", ":SideEffectInterfaces", ":StandardOpsIncGen", ":Support", @@ -2750,6 +2756,7 @@ ":ControlFlowInterfaces", ":DataLayoutInterfaces", ":IR", + ":InferTypeOpInterface", ":LLVMDialectAttributesIncGen", ":LLVMDialectInterfaceIncGen", ":LLVMOpsIncGen", @@ -2915,6 +2922,7 @@ ":GPUBaseIncGen", ":GPUOpsIncGen", ":IR", + ":InferTypeOpInterface", ":LLVMDialect", ":MemRefDialect", ":SideEffectInterfaces", @@ -3011,6 +3019,7 @@ includes = ["include"], deps = [ ":ControlFlowInterfacesTdFiles", + 
":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ], @@ -3669,6 +3678,7 @@ deps = [ ":CallInterfacesTdFiles", ":ControlFlowInterfacesTdFiles", + ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ], @@ -3819,6 +3829,7 @@ ":CommonFolders", ":ControlFlowInterfaces", ":IR", + ":InferTypeOpInterface", ":Parser", ":Pass", ":SPIRVAttrUtilsGen", @@ -6037,6 +6048,7 @@ ], includes = ["include"], deps = [ + ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ], @@ -6939,6 +6951,7 @@ ], includes = ["include"], deps = [ + ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ], @@ -7073,6 +7086,7 @@ includes = ["include"], deps = [ ":CastInterfacesTdFiles", + ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ":VectorInterfacesTdFiles", @@ -7218,6 +7232,7 @@ ], includes = ["include"], deps = [ + ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ":VectorInterfacesTdFiles",