diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -392,9 +392,11 @@
 /// Note: This function does not erase the operation on a successful fold.
 LogicalResult OpBuilder::tryFold(Operation *op,
                                  SmallVectorImpl<Value> &results) {
-  results.reserve(op->getNumResults());
+  ResultRange opResults = op->getResults();
+
+  results.reserve(opResults.size());
   auto cleanupFailure = [&] {
-    results.assign(op->result_begin(), op->result_end());
+    results.assign(opResults.begin(), opResults.end());
     return failure();
   };

@@ -405,7 +407,7 @@
   // Check to see if any operands to the operation is constant and whether
   // the operation knows how to constant fold itself.
   SmallVector<Attribute, 4> constOperands(op->getNumOperands());
-  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
     matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));

   // Try to fold the operation.
@@ -419,9 +421,23 @@

   // Populate the results with the folded results.
   Dialect *dialect = op->getDialect();
-  for (auto &it : llvm::enumerate(foldResults)) {
+  // Each fold result replaces the operation result at the same index;
+  // llvm::zip would silently drop any excess entries, so check up front.
+  assert(foldResults.size() == opResults.size() &&
+         "expected one fold result per operation result");
+  for (auto it : llvm::zip(foldResults, opResults.getTypes())) {
+    Type expectedType = std::get<1>(it);
+
     // Normal values get pushed back directly.
-    if (auto value = it.value().dyn_cast<Value>()) {
+    if (auto value = std::get<0>(it).dyn_cast<Value>()) {
+      // Refuse to fold on a type mismatch. Like the materialization failure
+      // path below, erase any constants already generated for earlier results.
+      if (value.getType() != expectedType) {
+        for (Operation *cst : generatedConstants)
+          cst->erase();
+        return cleanupFailure();
+      }
+
       results.push_back(value);
       continue;
     }
@@ -431,9 +447,9 @@
       return cleanupFailure();

     // Ask the dialect to materialize a constant operation for this value.
-    Attribute attr = it.value().get<Attribute>();
-    auto *constOp = dialect->materializeConstant(
-        cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
+    Attribute attr = std::get<0>(it).get<Attribute>();
+    auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
+                                                 op->getLoc());
     if (!constOp) {
       // Erase any generated constants.
       for (Operation *cst : generatedConstants)
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -307,3 +307,13 @@
   }
 }

+
+// -----
+
+// The "passthrough_fold" folder will naively return its operand, but we don't
+// want to fold here because of the type mismatch.
+func @typemismatch(%arg: f32) -> i32 {
+  // expected-remark@+1 {{op 'test.passthrough_fold' is not legalizable}}
+  %0 = "test.passthrough_fold"(%arg) : (f32) -> (i32)
+  "test.return"(%0) : (i32) -> ()
+}