diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -375,7 +375,7 @@ ```c++ /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. -void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } ``` At this point, each of the necessary Toy operations provide a mechanism by which diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -237,7 +237,7 @@ /// Infer the output shape of the AddOp, this is required by the shape inference /// interface. -void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // CastOp @@ -245,7 +245,7 @@ /// Infer the output shape of the CastOp, this is required by the shape /// inference interface. -void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } /// Returns true if the given set of input and result types are compatible with /// this cast operation. This is required by the `CastOpInterface` to verify @@ -349,7 +349,7 @@ /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. 
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // ReturnOp diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -237,7 +237,7 @@ /// Infer the output shape of the AddOp, this is required by the shape inference /// interface. -void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // CastOp @@ -245,7 +245,7 @@ /// Infer the output shape of the CastOp, this is required by the shape /// inference interface. -void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } /// Returns true if the given set of input and result types are compatible with /// this cast operation. This is required by the `CastOpInterface` to verify @@ -349,7 +349,7 @@ /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. -void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // ReturnOp diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -237,7 +237,7 @@ /// Infer the output shape of the AddOp, this is required by the shape inference /// interface. 
-void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // CastOp @@ -245,7 +245,7 @@ /// Infer the output shape of the CastOp, this is required by the shape /// inference interface. -void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } /// Returns true if the given set of input and result types are compatible with /// this cast operation. This is required by the `CastOpInterface` to verify @@ -349,7 +349,7 @@ /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. -void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // ReturnOp diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -243,7 +243,9 @@ /// Infer the output shape of the ConstantOp, this is required by the shape /// inference interface. -void ConstantOp::inferShapes() { getResult().setType(getValue().getType()); } +void ConstantOp::inferShapes() { + getResult().setType(cast<TensorType>(getValue().getType())); +} //===----------------------------------------------------------------------===// // AddOp @@ -264,7 +266,7 @@ /// Infer the output shape of the AddOp, this is required by the shape inference /// interface. 
-void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // CastOp @@ -272,7 +274,7 @@ /// Infer the output shape of the CastOp, this is required by the shape /// inference interface. -void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } /// Returns true if the given set of input and result types are compatible with /// this cast operation. This is required by the `CastOpInterface` to verify @@ -376,7 +378,7 @@ /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. -void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // ReturnOp diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -632,7 +632,7 @@ : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> { public: TypedValue<ResultType> getResult() { - return this->getOperation()->getResult(0); + return cast<TypedValue<ResultType>>(this->getOperation()->getResult(0)); } /// If the operation returns a single value, then the Op can be implicitly diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -427,21 +427,13 @@ /// TypedValue can be null/empty template <typename Ty> struct TypedValue : Value { + using Value::Value; + + static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); } + /// Return the known Type Ty getType() { return Value::getType().template cast<Ty>(); } - void setType(mlir::Type ty) { - assert(ty.template isa<Ty>()); - Value::setType(ty); - } - - TypedValue(Value val) : Value(val) { - 
assert(!val || val.getType().template isa<Ty>()); - } - TypedValue &operator=(const Value &other) { - assert(!other || other.getType().template isa<Ty>()); - Value::operator=(other); - return *this; - } + void setType(Ty ty) { Value::setType(ty); } }; } // namespace detail diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -152,6 +152,9 @@ /// Get the name of the method. StringRef getName() const { return methodName; } + /// Get the return type of the method + StringRef getReturnType() const { return returnType; } + /// Get the number of parameters. unsigned getNumParameters() const { return parameters.getNumParameters(); } @@ -344,6 +347,9 @@ /// Returns the name of this method. StringRef getName() const { return methodSignature.getName(); } + /// Returns the return type of this method + StringRef getReturnType() const { return methodSignature.getReturnType(); } + /// Returns if this method makes the `other` method redundant. 
bool makesRedundant(const Method &other) const { return methodSignature.makesRedundant(other.methodSignature); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1884,7 +1884,7 @@ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { llvm::SmallVector<Value> indices; for (unsigned int i = 0; i < inputTy.getRank(); i++) { - auto index = + Value index = rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult(); if (i == axis) { auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1033,7 +1033,7 @@ auto vec = getDataVector(xferOp); auto xferVecType = xferOp.getVectorType(); int64_t dimSize = xferVecType.getShape()[0]; - auto source = xferOp.getSource(); // memref or tensor to be written to. + Value source = xferOp.getSource(); // memref or tensor to be written to. auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); // Generate fully unrolled loop of transfer ops. 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -1056,7 +1056,7 @@ // %t = sparse_tensor.ConvertOp %tmp RankedTensorType cooTp = getUnorderedCOOFromTypeWithOrdering(dstTp, encDst.getDimOrdering()); - auto cooBuffer = + Value cooBuffer = rewriter.create<AllocTensorOp>(loc, cooTp, dynSizesArray).getResult(); Value c0 = constantIndex(rewriter, loc, 0); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -173,7 +173,8 @@ class SparseTensorSpecifier { public: - explicit SparseTensorSpecifier(Value specifier) : specifier(specifier) {} + explicit SparseTensorSpecifier(Value specifier) + : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {} // Undef value for dimension sizes, all zero value for memory sizes. 
static Value getInitValue(OpBuilder &builder, Location loc, diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -43,7 +43,7 @@ // CHECK-NEXT: return getODSOperands(0); // CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpD::getInput2 -// CHECK-NEXT: return *getODSOperands(1).begin(); +// CHECK-NEXT: return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); // CHECK-LABEL: OpD::build // CHECK-NEXT: odsState.addOperands(input1); diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -100,7 +100,7 @@ // CHECK-NEXT: return getODSResults(0); // CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpI::getOutput2 -// CHECK-NEXT: return *getODSResults(1).begin(); +// CHECK-NEXT: return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSResults(1).begin()); // CHECK-LABEL: OpI::build // CHECK-NEXT: odsState.addTypes(output1); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1337,10 +1337,12 @@ : generateTypeForGetter(operand), name); ERROR_IF_PRUNED(m, name, op); - m->body().indent() << formatv( - "auto operands = getODSOperands({0});\n" - "return operands.empty() ? {1}{{} : *operands.begin();", - i, rangeElementType); + m->body().indent() << formatv("auto operands = getODSOperands({0});\n" + "return operands.empty() ? 
{1}{{} : ", + i, m->getReturnType()); + if (!isGenericAdaptorBase) + m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType()); + m->body() << "(*operands.begin());"; } else if (operand.isVariadicOfVariadic()) { std::string segmentAttr = op.getGetterName( operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); @@ -1366,7 +1368,10 @@ : generateTypeForGetter(operand), name); ERROR_IF_PRUNED(m, name, op); - m->body() << " return *getODSOperands(" << i << ").begin();"; + m->body().indent() << "return "; + if (!isGenericAdaptorBase) + m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType()); + m->body() << llvm::formatv("(*getODSOperands({0}).begin());", i); } } } @@ -1489,9 +1494,11 @@ if (result.isOptional()) { m = opClass.addMethod(generateTypeForGetter(result), name); ERROR_IF_PRUNED(m, name, op); - m->body() - << " auto results = getODSResults(" << i << ");\n" - << " return results.empty() ? ::mlir::Value() : *results.begin();"; + m->body() << " auto results = getODSResults(" << i << ");\n" + << llvm::formatv(" return results.empty()" + " ? {0}()" + " : ::llvm::cast<{0}>(*results.begin());", + m->getReturnType()); } else if (result.isVariadic()) { m = opClass.addMethod("::mlir::Operation::result_range", name); ERROR_IF_PRUNED(m, name, op); @@ -1499,7 +1506,9 @@ } else { m = opClass.addMethod(generateTypeForGetter(result), name); ERROR_IF_PRUNED(m, name, op); - m->body() << " return *getODSResults(" << i << ").begin();"; + m->body() << llvm::formatv( + " return ::llvm::cast<{0}>(*getODSResults({1}).begin());", + m->getReturnType(), i); } } } diff --git a/mlir/unittests/IR/IRMapping.cpp b/mlir/unittests/IR/IRMapping.cpp --- a/mlir/unittests/IR/IRMapping.cpp +++ b/mlir/unittests/IR/IRMapping.cpp @@ -32,7 +32,7 @@ IRMapping mapping; mapping.map(i64Val, f64Val); - TypedValue<IntegerType> typedI64Val = i64Val; + auto typedI64Val = cast<TypedValue<IntegerType>>(i64Val); EXPECT_EQ(mapping.lookup(typedI64Val), f64Val); }