Index: mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
===================================================================
--- mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
+++ mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
 
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
 
Index: mlir/include/mlir/IR/Value.h
===================================================================
--- mlir/include/mlir/IR/Value.h
+++ mlir/include/mlir/IR/Value.h
@@ -419,6 +419,53 @@
   return cast<InlineOpResult>(this)->getResultNumber();
 }
 
+/// Wrapper around an mlir::Value, or a class derived from it, whose type is
+/// statically known to be `Ty`; the type is asserted whenever a non-null
+/// value is stored. A TypedValue may hold a null/empty value.
+template <typename Ty, typename ValueClass>
+struct TypedValueImpl : ValueClass {
+  /// Return the statically known type.
+  Ty getType() { return Value::getType().template cast<Ty>(); }
+  void setType(Ty ty) {
+    assert(ty.template isa<Ty>());
+    Value::setType(ty);
+  }
+
+  TypedValueImpl(ValueClass val) : ValueClass(val) {
+    assert(!val || val.getType().template isa<Ty>());
+  }
+  TypedValueImpl &operator=(const ValueClass &other) {
+    assert(!other || other.getType().template isa<Ty>());
+    return static_cast<TypedValueImpl &>(Value::operator=(other));
+  }
+
+  /// Stay a TypedValue when casting to another kind of Value.
+  template <typename OtherValueClass>
+  TypedValueImpl<Ty, OtherValueClass> dyn_cast() const {
+    return ValueClass::template dyn_cast<OtherValueClass>();
+  }
+  template <typename OtherValueClass>
+  TypedValueImpl<Ty, OtherValueClass> dyn_cast_or_null() const {
+    return ValueClass::template dyn_cast_or_null<OtherValueClass>();
+  }
+  template <typename OtherValueClass>
+  TypedValueImpl<Ty, OtherValueClass> cast() const {
+    return ValueClass::template cast<OtherValueClass>();
+  }
+};
+
+/// If Ty is mlir::Type, select ValueClass directly instead of wrapping it.
+/// This avoids many ambiguous-conversion issues.
+template <typename Ty, typename ValueClass>
+struct TypedValueSelector {
+  using type = TypedValueImpl<Ty, ValueClass>;
+};
+
+template <typename ValueClass>
+struct TypedValueSelector<Type, ValueClass> {
+  using type = ValueClass;
+};
+
 } // namespace detail
 
 /// This is a value defined by a result of an operation.
@@ -459,6 +506,10 @@
   return ::llvm::hash_value(arg.getImpl());
 }
 
+/// A ValueClass (Value by default) whose type is statically known to be `Ty`.
+template <typename Ty, typename ValueClass = Value>
+using TypedValue = typename detail::TypedValueSelector<Ty, ValueClass>::type;
+
 } // namespace mlir
 
 namespace llvm {
Index: mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
===================================================================
--- mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -784,7 +784,7 @@
       launchOp.getKernelName().getValue(), loc, rewriter);
   auto function = moduleGetFunctionCallBuilder.create(
       loc, rewriter, {module.getResult(), kernelName});
-  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
+  Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
   Value stream =
       adaptor.asyncDependencies().empty()
           ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
Index: mlir/lib/Dialect/SCF/Utils/Utils.cpp
===================================================================
--- mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -428,7 +428,7 @@
   // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
   OpBuilder boundsBuilder(forOp);
   auto loc = forOp.getLoc();
-  auto step = forOp.getStep();
+  Value step = forOp.getStep();
   Value upperBoundUnrolled;
   Value stepUnrolled;
   bool generateEpilogueLoop = true;
Index: mlir/lib/Dialect/Shape/IR/Shape.cpp
===================================================================
--- mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -778,7 +778,7 @@
                                 PatternRewriter &rewriter) const override {
     // Canonicalize operands.
     bool anyChange = false;
-    auto canonicalizeOperand = [&](Value operand) {
+    auto canonicalizeOperand = [&](Value operand) -> Value {
       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
         // Only eliminate the cast if it holds no shape information.
         bool isInformationLoosingCast =
Index: mlir/test/lib/Dialect/Test/TestDialect.h
===================================================================
--- mlir/test/lib/Dialect/Test/TestDialect.h
+++ mlir/test/lib/Dialect/Test/TestDialect.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_TESTDIALECT_H
 #define MLIR_TESTDIALECT_H
 
+#include "TestTypes.h"
 #include "TestAttributes.h"
 #include "TestInterfaces.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
Index: mlir/test/mlir-tblgen/op-attribute.td
===================================================================
--- mlir/test/mlir-tblgen/op-attribute.td
+++ mlir/test/mlir-tblgen/op-attribute.td
@@ -453,8 +453,8 @@
 // DECL: static void build({{.*}}, bool dv_bool_attr, ::mlir::BlockRange succ)
 
 // DEF-LABEL: MixOperandsAndAttrs definitions
-// DEF-DAG: ::mlir::Value MixOperandsAndAttrs::operand()
-// DEF-DAG: ::mlir::Value MixOperandsAndAttrs::otherArg()
+// DEF-DAG: ::mlir::TypedValue<::mlir::FloatType> MixOperandsAndAttrs::operand()
+// DEF-DAG: ::mlir::TypedValue<::mlir::FloatType> MixOperandsAndAttrs::otherArg()
 // DEF-DAG: void MixOperandsAndAttrs::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::FloatAttr attr, ::mlir::Value operand, ::mlir::FloatAttr otherAttr, ::mlir::Value otherArg)
 // DEF-DAG: ::llvm::APFloat MixOperandsAndAttrs::attr()
 // DEF-DAG: ::llvm::APFloat MixOperandsAndAttrs::otherAttr()
Index: mlir/test/mlir-tblgen/op-decl-and-defs.td
===================================================================
--- mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -78,7 +78,7 @@
 // CHECK: return ::llvm::StringLiteral("test.a_op");
 // CHECK: }
 // CHECK: ::mlir::Operation::operand_range getODSOperands(unsigned index);
-// CHECK: ::mlir::Value getA();
+// CHECK: ::mlir::TypedValue<::mlir::IntegerType> getA();
 // CHECK: ::mlir::Operation::operand_range getB();
 // CHECK: ::mlir::MutableOperandRange getAMutable();
 // CHECK: ::mlir::MutableOperandRange getBMutable();
@@ -169,7 +169,7 @@
 // CHECK-LABEL: NS::EOp declarations
 // CHECK: ::mlir::Value getA();
 // CHECK: ::mlir::MutableOperandRange getAMutable();
-// CHECK: ::mlir::Value getB();
+// CHECK: ::mlir::TypedValue<::mlir::FloatType> getB();
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type b, /*optional*/::mlir::Value a)
 
 // Check that all types match constraint results in generating builder.
Index: mlir/test/mlir-tblgen/op-operand.td
===================================================================
--- mlir/test/mlir-tblgen/op-operand.td
+++ mlir/test/mlir-tblgen/op-operand.td
@@ -54,7 +54,7 @@
 // CHECK-LABEL: ::mlir::Operation::operand_range OpD::input1
 // CHECK-NEXT: return getODSOperands(0);
 
-// CHECK-LABEL: ::mlir::Value OpD::input2
+// CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpD::input2
 // CHECK-NEXT: return *getODSOperands(1).begin();
 
 // CHECK-LABEL: OpD::build
Index: mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
===================================================================
--- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1152,6 +1152,22 @@
   }
 }
 
+/// Returns the C++ return type for the getter of operand/result `value`.
+static std::string generateTypeForGetter(bool isAdaptor,
+                                         const NamedTypeConstraint &value) {
+  std::string str = "::mlir::Value";
+  // Only emit a TypedValue when the constraint's C++ class name is fully
+  // qualified; an unqualified name does not resolve for types that live in
+  // another dialect's namespace. Adaptors always get ::mlir::Value because
+  // their values are not expected to match the operation's types.
+  if (!isAdaptor &&
+      StringRef(value.constraint.getCPPClassName()).startswith("::"))
+    str = llvm::formatv("::mlir::TypedValue<{0}>",
+                        value.constraint.getCPPClassName())
+              .str();
+  return str;
+}
+
 // Generates the named operand getter methods for the given Operator `op` and
 // puts them in `opClass`. Uses `rangeType` as the return type of getters that
 // return a range of operands (individual operands are `Value ` and each
@@ -1216,7 +1232,7 @@
       continue;
     for (StringRef name : op.getGetterNames(operand.name)) {
       if (operand.isOptional()) {
-        m = opClass.addMethod("::mlir::Value", name);
+        m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
         ERROR_IF_PRUNED(m, name, op);
         m->body() << "  auto operands = getODSOperands(" << i << ");\n"
                   << "  return operands.empty() ? ::mlir::Value() : "
@@ -1242,7 +1258,7 @@
         ERROR_IF_PRUNED(m, name, op);
         m->body() << "  return getODSOperands(" << i << ");";
       } else {
-        m = opClass.addMethod("::mlir::Value", name);
+        m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
         ERROR_IF_PRUNED(m, name, op);
         m->body() << "  return *getODSOperands(" << i << ").begin();";
       }
@@ -1365,7 +1381,7 @@
       continue;
     for (StringRef name : op.getGetterNames(result.name)) {
       if (result.isOptional()) {
-        m = opClass.addMethod("::mlir::Value", name);
+        m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result), name);
         ERROR_IF_PRUNED(m, name, op);
         m->body() << "  auto results = getODSResults(" << i << ");\n"