Index: mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h =================================================================== --- mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h +++ mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H +#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" Index: mlir/include/mlir/IR/Value.h =================================================================== --- mlir/include/mlir/IR/Value.h +++ mlir/include/mlir/IR/Value.h @@ -419,6 +419,27 @@ return cast(this)->getResultNumber(); } +/// TypedValue is a Value with a statically known type. +/// A TypedValue can be null/empty. +template <typename Ty> +struct TypedValue : Value { + /// Return the known type (const, like Value::getType(), so it is not hidden). + Ty getType() const { return Value::getType().template cast<Ty>(); } + void setType(mlir::Type ty) { + assert(ty.template isa<Ty>()); + Value::setType(ty); + } + + TypedValue(Value val) : Value(val) { + assert(!val || val.getType().template isa<Ty>()); + } + TypedValue &operator=(const Value &other) { + assert(!other || other.getType().template isa<Ty>()); + Value::operator=(other); + return *this; + } +}; + } // namespace detail /// This is a value defined by a result of an operation. @@ -459,6 +480,12 @@ return ::llvm::hash_value(arg.getImpl()); } +/// If Ty is mlir::Type this will select `Value` instead of having a wrapper +/// around it. This helps resolve ambiguous conversion issues. +template <typename Ty>
+using TypedValue = std::conditional_t<std::is_same_v<Ty, mlir::Type>, + mlir::Value, detail::TypedValue<Ty>>; + } // namespace mlir namespace llvm { Index: mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp =================================================================== --- mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -784,7 +784,7 @@ launchOp.getKernelName().getValue(), loc, rewriter); auto function = moduleGetFunctionCallBuilder.create( loc, rewriter, {module.getResult(), kernelName}); - auto zero = rewriter.create(loc, llvmInt32Type, 0); + Value zero = rewriter.create(loc, llvmInt32Type, 0); Value stream = adaptor.asyncDependencies().empty() ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult() Index: mlir/lib/Dialect/SCF/Utils/Utils.cpp =================================================================== --- mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -428,7 +428,7 @@ // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases. OpBuilder boundsBuilder(forOp); auto loc = forOp.getLoc(); - auto step = forOp.getStep(); + Value step = forOp.getStep(); Value upperBoundUnrolled; Value stepUnrolled; bool generateEpilogueLoop = true; Index: mlir/lib/Dialect/Shape/IR/Shape.cpp =================================================================== --- mlir/lib/Dialect/Shape/IR/Shape.cpp +++ mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -778,7 +778,7 @@ PatternRewriter &rewriter) const override { // Canonicalize operands. bool anyChange = false; - auto canonicalizeOperand = [&](Value operand) { + auto canonicalizeOperand = [&](Value operand) -> Value { if (auto castOp = operand.getDefiningOp()) { // Only eliminate the cast if it holds no shape information. 
bool isInformationLoosingCast = Index: mlir/test/lib/Dialect/Test/TestDialect.h =================================================================== --- mlir/test/lib/Dialect/Test/TestDialect.h +++ mlir/test/lib/Dialect/Test/TestDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_TESTDIALECT_H #define MLIR_TESTDIALECT_H +#include "TestTypes.h" #include "TestAttributes.h" #include "TestInterfaces.h" #include "mlir/Dialect/DLTI/DLTI.h" Index: mlir/test/mlir-tblgen/op-attribute.td =================================================================== --- mlir/test/mlir-tblgen/op-attribute.td +++ mlir/test/mlir-tblgen/op-attribute.td @@ -453,8 +453,8 @@ // DECL: static void build({{.*}}, bool dv_bool_attr, ::mlir::BlockRange succ) // DEF-LABEL: MixOperandsAndAttrs definitions -// DEF-DAG: ::mlir::Value MixOperandsAndAttrs::operand() -// DEF-DAG: ::mlir::Value MixOperandsAndAttrs::otherArg() +// DEF-DAG: ::mlir::TypedValue<::mlir::FloatType> MixOperandsAndAttrs::operand() +// DEF-DAG: ::mlir::TypedValue<::mlir::FloatType> MixOperandsAndAttrs::otherArg() // DEF-DAG: void MixOperandsAndAttrs::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::FloatAttr attr, ::mlir::Value operand, ::mlir::FloatAttr otherAttr, ::mlir::Value otherArg) // DEF-DAG: ::llvm::APFloat MixOperandsAndAttrs::attr() // DEF-DAG: ::llvm::APFloat MixOperandsAndAttrs::otherAttr() Index: mlir/test/mlir-tblgen/op-decl-and-defs.td =================================================================== --- mlir/test/mlir-tblgen/op-decl-and-defs.td +++ mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -78,12 +78,12 @@ // CHECK: return ::llvm::StringLiteral("test.a_op"); // CHECK: } // CHECK: ::mlir::Operation::operand_range getODSOperands(unsigned index); -// CHECK: ::mlir::Value getA(); +// CHECK: ::mlir::TypedValue<::mlir::IntegerType> getA(); // CHECK: ::mlir::Operation::operand_range getB(); // CHECK: ::mlir::MutableOperandRange getAMutable(); // CHECK: ::mlir::MutableOperandRange getBMutable(); // 
CHECK: ::mlir::Operation::result_range getODSResults(unsigned index); -// CHECK: ::mlir::Value getR(); +// CHECK: ::mlir::TypedValue<::mlir::IntegerType> getR(); // CHECK: ::mlir::Region &getSomeRegion(); // CHECK: ::mlir::MutableArrayRef<::mlir::Region> getSomeRegions(); // CHECK: ::mlir::IntegerAttr getAttr1Attr() @@ -169,7 +169,7 @@ // CHECK-LABEL: NS::EOp declarations // CHECK: ::mlir::Value getA(); // CHECK: ::mlir::MutableOperandRange getAMutable(); -// CHECK: ::mlir::Value getB(); +// CHECK: ::mlir::TypedValue<::mlir::FloatType> getB(); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type b, /*optional*/::mlir::Value a) // Check that all types match constraint results in generating builder. Index: mlir/test/mlir-tblgen/op-operand.td =================================================================== --- mlir/test/mlir-tblgen/op-operand.td +++ mlir/test/mlir-tblgen/op-operand.td @@ -54,7 +54,7 @@ // CHECK-LABEL: ::mlir::Operation::operand_range OpD::input1 // CHECK-NEXT: return getODSOperands(0); -// CHECK-LABEL: ::mlir::Value OpD::input2 +// CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpD::input2 // CHECK-NEXT: return *getODSOperands(1).begin(); // CHECK-LABEL: OpD::build Index: mlir/test/mlir-tblgen/op-result.td =================================================================== --- mlir/test/mlir-tblgen/op-result.td +++ mlir/test/mlir-tblgen/op-result.td @@ -102,7 +102,7 @@ // CHECK-LABEL: ::mlir::Operation::result_range OpI::output1 // CHECK-NEXT: return getODSResults(0); -// CHECK-LABEL: ::mlir::Value OpI::output2 +// CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpI::output2 // CHECK-NEXT: return *getODSResults(1).begin(); // CHECK-LABEL: OpI::build Index: mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp =================================================================== --- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1152,6 
+1152,25 @@ } } +static std::string generateTypeForGetter(bool isAdaptor, + const NamedTypeConstraint &value) { + // Getters return a plain `::mlir::Value` unless a `TypedValue` can safely + // be spelled: + // - Adaptors may hold values whose types do not satisfy the op's type + // constraints (this is expected), so they never use `TypedValue`. + // - The generic `::mlir::Type` constraint is emitted as a plain Value. + // - Class names that are not fully qualified can fail to resolve when the + // type is used across dialects, since they are not in the correct + // namespace. For example, getCPPClassName does not return the fully + // qualified path for `mlir::pdl::OperationType`, see + // https://github.com/llvm/llvm-project/issues/57279. + std::string cppClassName = value.constraint.getCPPClassName(); + if (isAdaptor || cppClassName == "::mlir::Type" || + !StringRef(cppClassName).startswith("::")) + return "::mlir::Value"; + return llvm::formatv("::mlir::TypedValue<{0}>", cppClassName).str(); +} + // Generates the named operand getter methods for the given Operator `op` and // puts them in `opClass`. Uses `rangeType` as the return type of getters that // return a range of operands (individual operands are `Value ` and each @@ -1216,7 +1235,7 @@ continue; for (StringRef name : op.getGetterNames(operand.name)) { if (operand.isOptional()) { - m = opClass.addMethod("::mlir::Value", name); + m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name); ERROR_IF_PRUNED(m, name, op); m->body() << " auto operands = getODSOperands(" << i << ");\n" << " return operands.empty() ? 
::mlir::Value() : " @@ -1242,7 +1261,7 @@ ERROR_IF_PRUNED(m, name, op); m->body() << " return getODSOperands(" << i << ");"; } else { - m = opClass.addMethod("::mlir::Value", name); + m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name); ERROR_IF_PRUNED(m, name, op); m->body() << " return *getODSOperands(" << i << ").begin();"; } @@ -1365,7 +1384,8 @@ continue; for (StringRef name : op.getGetterNames(result.name)) { if (result.isOptional()) { - m = opClass.addMethod("::mlir::Value", name); + m = opClass.addMethod( + generateTypeForGetter(/*isAdaptor=*/false, result), name); ERROR_IF_PRUNED(m, name, op); m->body() << " auto results = getODSResults(" << i << ");\n" @@ -1375,7 +1395,8 @@ ERROR_IF_PRUNED(m, name, op); m->body() << " return getODSResults(" << i << ");"; } else { - m = opClass.addMethod("::mlir::Value", name); + m = opClass.addMethod( + generateTypeForGetter(/*isAdaptor=*/false, result), name); ERROR_IF_PRUNED(m, name, op); m->body() << " return *getODSResults(" << i << ").begin();"; }