diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -701,24 +701,27 @@ })); SmallVector expandedOpOperands; + expandedOpOperands.reserve(genericOp.getNumInputs()); for (OpOperand *opOperand : genericOp.getInputOperands()) { if (opOperand == fusableOpOperand) { expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src() : collapsingReshapeOp.src()); continue; } - AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); - RankedTensorType expandedOperandType = - getExpandedType(opOperand->get().getType().cast(), - indexingMap, expansionInfo); - if (expandedOperandType != opOperand->get().getType()) { - // Reshape the operand to get the right type. - SmallVector reassociation = - getReassociationForExpansion(indexingMap, expansionInfo); - expandedOpOperands.push_back(rewriter.create( - genericOp.getLoc(), expandedOperandType, opOperand->get(), - reassociation)); - continue; + if (genericOp.isInputTensor(opOperand)) { + AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); + RankedTensorType expandedOperandType = + getExpandedType(opOperand->get().getType().cast(), + indexingMap, expansionInfo); + if (expandedOperandType != opOperand->get().getType()) { + // Reshape the operand to get the right type. + SmallVector reassociation = + getReassociationForExpansion(indexingMap, expansionInfo); + expandedOpOperands.push_back(rewriter.create( + genericOp.getLoc(), expandedOperandType, opOperand->get(), + reassociation)); + continue; + } } expandedOpOperands.push_back(opOperand->get()); } @@ -1035,7 +1038,7 @@ LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : genericOp.getInputOperands()) { + for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { TensorCollapseShapeOp reshapeOp = opOperand->get().getDefiningOp(); if (!reshapeOp)