diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -183,9 +183,7 @@ AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps)); if (!invertedMap) return failure(); - SmallVector dims; - for (ShapedType shapedType : genericOp.getShapedOperandTypes()) - dims.append(shapedType.getShape().begin(), shapedType.getShape().end()); + SmallVector dims = genericOp.getStaticShape(); // Find all the reduction iterators. Those need some special consideration // (see below). @@ -267,17 +265,18 @@ /// - modified index map that can be used to access the replaced result/operand /// - the reassociation that converts from the original tensor type to the /// modified tensor type. -static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap, - RankedTensorType type, +static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp, + OpOperand *opOperand, MLIRContext *context) { - ArrayRef shape = type.getShape(); - ArrayRef exprs = indexMap.getResults(); + AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); + ArrayRef shape = genericOp.getShape(opOperand); + ArrayRef exprs = indexingMap.getResults(); SmallVector reassociations; SmallVector reassociationMaps; SmallVector newIndexExprs; SmallVector newShape; - int64_t origRank = type.getRank(); + int64_t origRank = genericOp.getRank(opOperand); AffineExpr zeroExpr = getAffineConstantExpr(0, context); auto isUnitExtent = [&](int64_t dim) -> bool { return shape[dim] == 1 && exprs[dim] == zeroExpr; @@ -302,8 +301,9 @@ ++dim; } UnitExtentReplacementInfo info = { - RankedTensorType::get(newShape, type.getElementType()), - AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(), + RankedTensorType::get(newShape, + getElementTypeOrSelf(opOperand->get().getType())), + AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(), newIndexExprs, context), ArrayAttr::get(context, reassociationMaps)}; return info; @@ -335,15 +335,13 @@ SmallVector reassociationMaps; SmallVector newInputOutputTypes; bool doCanonicalization = false; - for (auto it : llvm::zip(genericOp.getIndexingMaps(), - genericOp.getShapedOperandTypes())) { - auto replacementInfo = replaceUnitExtents( - std::get<0>(it), std::get<1>(it).template cast(), - context); + + for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { + auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context); reassociationMaps.push_back(replacementInfo.reassociation); newIndexingMaps.push_back(replacementInfo.indexMap); newInputOutputTypes.push_back(replacementInfo.type); - doCanonicalization |= replacementInfo.type != std::get<1>(it); + doCanonicalization |= replacementInfo.type != opOperand->get().getType(); } // If the indexing maps of the result operation are not invertible (i.e. not