diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp @@ -34,26 +34,27 @@ SmallVector scalarOperands; SmallVector newIndexingMaps; SmallVector newOperands; - for (auto it : llvm::enumerate(llvm::zip(genericOp.getInputIndexingMaps(), - genericOp.getInputTensors()))) { - AffineMap map = std::get<0>(it.value()); - if (map.isConstant()) { - scalarOperands.emplace_back(it.index()); + for (OpOperand *opOperand : genericOp.getInputOperands()) { + AffineMap map = genericOp.getTiedIndexingMap(opOperand); + if (genericOp.isInputTensor(opOperand) && map.isConstant()) { + scalarOperands.emplace_back(opOperand->getOperandNumber()); } else { newIndexingMaps.emplace_back(map); - newOperands.emplace_back(std::get<1>(it.value())); + newOperands.emplace_back(opOperand->get()); } } if (scalarOperands.empty()) return failure(); - newIndexingMaps.append(genericOp.getOutputIndexingMaps()); + for (OpOperand *opOperand : genericOp.getOutputOperands()) + newIndexingMaps.emplace_back(genericOp.getTiedIndexingMap(opOperand)); Location loc = genericOp->getLoc(); + SmallVector outputOperands = genericOp.getOutputOperands(); auto newOp = rewriter.create( - loc, genericOp->getResultTypes(), newOperands, - genericOp.getOutputTensors(), newIndexingMaps, + loc, genericOp->getResultTypes(), newOperands, outputOperands, + newIndexingMaps, llvm::to_vector<4>( genericOp.iterator_types().template getAsValueRange())); rewriter.cloneRegionBefore(genericOp.region(), newOp.region(), @@ -64,14 +65,15 @@ rewriter.setInsertionPointToStart(body); for (auto idx : llvm::reverse(scalarOperands)) { - Value operand = genericOp.getInput(idx); - AffineMap map = genericOp.getInputIndexingMap(idx); + OpOperand *opOperand = genericOp.getInputOperand(idx); + AffineMap map = genericOp.getTiedIndexingMap(opOperand); SmallVector indices = map.getConstantResults(); SmallVector indicesValues; for (auto idx : indices) indicesValues.emplace_back(rewriter.create(loc, idx)); - operand = rewriter.create(loc, operand, indicesValues); - body->getArgument(idx).replaceAllUsesWith(operand); + Value extractedValue = rewriter.create( + loc, opOperand->get(), indicesValues); + body->getArgument(idx).replaceAllUsesWith(extractedValue); body->eraseArgument(idx); }