diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -249,7 +249,7 @@ }; struct UnitExtentReplacementInfo { - RankedTensorType type; + Type type; AffineMap indexMap; ArrayAttr reassociation; }; @@ -271,10 +271,10 @@ AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); ArrayRef shape = genericOp.getShape(opOperand); ArrayRef exprs = indexingMap.getResults(); - SmallVector reassociations; - SmallVector reassociationMaps; - SmallVector newIndexExprs; - SmallVector newShape; + SmallVector reassociations; + SmallVector reassociationMaps; + SmallVector newIndexExprs; + SmallVector newShape; int64_t origRank = genericOp.getRank(opOperand); AffineExpr zeroExpr = getAffineConstantExpr(0, context); @@ -282,7 +282,7 @@ return shape[dim] == 1 && exprs[dim] == zeroExpr; }; - unsigned dim = 0; + int64_t dim = 0; // Fold dimensions that are unit-extent at the beginning of the tensor. while (dim < origRank && isUnitExtent(dim)) reassociations.push_back(getAffineDimExpr(dim++, context)); @@ -300,12 +300,16 @@ reassociations.clear(); ++dim; } - UnitExtentReplacementInfo info = { - RankedTensorType::get(newShape, - getElementTypeOrSelf(opOperand->get().getType())), - AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(), - newIndexExprs, context), - ArrayAttr::get(context, reassociationMaps)}; + // Compute the tensor or scalar replacement type. + Type elementType = getElementTypeOrSelf(opOperand->get().getType()); + Type replacementType = elementType == opOperand->get().getType() + ? elementType + : RankedTensorType::get(newShape, elementType); + UnitExtentReplacementInfo info = {replacementType, + AffineMap::get(indexingMap.getNumDims(), + indexingMap.getNumSymbols(), + newIndexExprs, context), + ArrayAttr::get(context, reassociationMaps)}; return info; } @@ -331,13 +335,14 @@ MLIRContext *context = rewriter.getContext(); Location loc = genericOp.getLoc(); - SmallVector newIndexingMaps; - SmallVector reassociationMaps; - SmallVector newInputOutputTypes; + SmallVector newIndexingMaps; + SmallVector reassociationMaps; + SmallVector newInputOutputTypes; bool doCanonicalization = false; for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { - auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context); + UnitExtentReplacementInfo replacementInfo = + replaceUnitExtents(genericOp, opOperand, context); reassociationMaps.push_back(replacementInfo.reassociation); newIndexingMaps.push_back(replacementInfo.indexMap); newInputOutputTypes.push_back(replacementInfo.type);