[mlir][scf] Return replacement values from replaceTensorCastForOpIterArg

replaceTensorCastForOpIterArg now returns the values that should replace
the results of the original scf.for (a SmallVector, `return newResults;`)
instead of returning the newly created ForOp.

Why the caller shrinks: the helper already rebuilds the result list and
re-casts the yielded value back to the old iter_arg type
(`newResults[yieldIdx] = rewriter.create(newForOp.getLoc(), oldType, ...)`).
Previously the caller ignored those values. It took `newForOp.getResults()`
and inserted a second outgoing tensor.cast to
`incomingCast.getDest().getType()`. The replaced iter operand is the cast's
result, so both casts presumably target the same type (confirm that
`oldType == incomingCast.getDest().getType()`). The caller now hands the
helper's result straight to `rewriter.replaceOp(op, ...)`.

Contract change in the helper: the early exit
`if (operand.get().getType() == replacement.getType()) return forOp;`
has no equivalent once the return type is a list of values. It is replaced
by `assert(... != ... && "Expected a different type")`. To keep that assert
satisfiable, the canonicalization loop now also skips incoming casts whose
source type already equals their result type
(`incomingCast.getSource().getType() == incomingCast.getType()`).

NOTE(review):
* This patch text appears mangled. All line breaks and indentation are
  collapsed, so hunk context will not match as-is. Regenerate it from the
  upstream commit rather than hand-repairing it.
* Template argument lists appear to be missing, presumably stripped during
  extraction. Examples: `static SmallVector` (no element type),
  `cast(operand.getOwner())`, `oldType.isa()`, `getDefiningOp()` on the
  incoming cast, and `rewriter.create(`. Restore them before applying.
* The helper's `///` doc comment (context lines of the first hunk) still
  describes only the replacement. Consider documenting that it now returns
  the replacement values for the old loop's results and requires the two
  types to differ.
* The caller no longer calls `rewriter.setInsertionPointAfter(newForOp)`
  itself. The outgoing cast is now whichever one the helper creates. Verify
  the helper's insertion point is after the new loop; this is not visible in
  the diff.
* The context comment "incoming/ougoing" has a typo. Fix it in a separate
  change so this patch's context still matches.

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -886,9 +886,9 @@ /// Perform a replacement of one iter OpOperand of an scf.for to the /// `replacement` value which is expected to be the source of a tensor.cast. /// tensor.cast ops are inserted inside the block to account for the type cast. -static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter, - OpOperand &operand, - Value replacement) { +static SmallVector +replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand, + Value replacement) { Type oldType = operand.get().getType(), newType = replacement.getType(); assert(oldType.isa() && newType.isa() && "expected ranked tensor types"); @@ -897,8 +897,8 @@ ForOp forOp = cast(operand.getOwner()); assert(operand.getOperandNumber() >= forOp.getNumControlOperands() && "expected an iter OpOperand"); - if (operand.get().getType() == replacement.getType()) - return forOp; + assert(operand.get().getType() != replacement.getType() && + "Expected a different type"); SmallVector newIterOperands; for (OpOperand &opOperand : forOp.getIterOpOperands()) { if (opOperand.getOperandNumber() == operand.getOperandNumber()) { @@ -949,7 +949,7 @@ newResults[yieldIdx] = rewriter.create( newForOp.getLoc(), oldType, newResults[yieldIdx]); - return newForOp; + return newResults; } /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing @@ -986,7 +986,8 @@ for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) { OpOperand &iterOpOperand = std::get<0>(it); auto incomingCast = iterOpOperand.get().getDefiningOp(); - if (!incomingCast) + if (!incomingCast || + incomingCast.getSource().getType() == incomingCast.getType()) continue; // If the dest type of the cast does not preserve static information in // the source type. @@ -998,18 +999,9 @@ continue; // Create a new ForOp with that iter operand replaced. 
- auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand, - incomingCast.getSource()); - - // Insert outgoing cast and use it to replace the corresponding result. - rewriter.setInsertionPointAfter(newForOp); - SmallVector replacements = newForOp.getResults(); - unsigned returnIdx = - iterOpOperand.getOperandNumber() - op.getNumControlOperands(); - replacements[returnIdx] = rewriter.create( - op.getLoc(), incomingCast.getDest().getType(), - replacements[returnIdx]); - rewriter.replaceOp(op, replacements); + rewriter.replaceOp( + op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand, + incomingCast.getSource())); return success(); } return failure();