diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -267,9 +267,9 @@ /// - modified index map that can be used to access the replaced result/operand /// - the reassociation that converts from the original tensor type to the /// modified tensor type. -static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp, - OpOperand *opOperand, - MLIRContext *context) { +static llvm::Optional +replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand, + MLIRContext *context) { AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); ArrayRef shape = genericOp.getShape(opOperand); ArrayRef exprs = indexingMap.getResults(); @@ -284,6 +284,14 @@ return shape[dim] == 1 && exprs[dim] == zeroExpr; }; + // Memrefs with affine (layout) maps are never collapsed: the replacement + // MemRefType built below has no layout, so return llvm::None to the caller. + Type actualType = opOperand->get().getType(); + if (auto memref = actualType.dyn_cast()) { + if (!memref.getAffineMaps().empty()) + return llvm::None; + } + int64_t dim = 0; // Fold dimensions that are unit-extent at the beginning of the tensor. while (dim < origRank && isUnitExtent(dim)) @@ -302,8 +310,8 @@ reassociations.clear(); ++dim; } + // Compute the tensor or scalar replacement type. 
- Type actualType = opOperand->get().getType(); Type elementType = getElementTypeOrSelf(opOperand->get()); Type replacementType; if (elementType == opOperand->get().getType()) { @@ -311,8 +319,6 @@ } else if (actualType.isa()) { replacementType = RankedTensorType::get(newShape, elementType); } else if (actualType.isa()) { - assert(actualType.cast().getAffineMaps().empty() && - "unsupported strided memrefs"); replacementType = MemRefType::get(newShape, elementType); } assert(replacementType && "unsupported shaped type"); @@ -390,12 +396,28 @@ SmallVector newInputOutputTypes; bool doCanonicalization = false; for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { - UnitExtentReplacementInfo replacementInfo = - replaceUnitExtents(genericOp, opOperand, context); - reassociationMaps.push_back(replacementInfo.reassociation); - newIndexingMaps.push_back(replacementInfo.indexMap); - newInputOutputTypes.push_back(replacementInfo.type); - doCanonicalization |= replacementInfo.type != opOperand->get().getType(); + auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context); + if (replacementInfo) { + reassociationMaps.push_back(replacementInfo->reassociation); + newIndexingMaps.push_back(replacementInfo->indexMap); + newInputOutputTypes.push_back(replacementInfo->type); + doCanonicalization |= + replacementInfo->type != opOperand->get().getType(); + } else { + // replaceUnitExtents returned llvm::None (memref with affine maps): keep the + // operand's type and indexing map unchanged and use an identity reassociation + // whose i-th map is (d0, ..., dN-1) -> (di), one group per dimension. 
+ newInputOutputTypes.push_back(opOperand->get().getType()); + newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand)); + int64_t origRank = genericOp.getRank(opOperand); + auto maps = llvm::to_vector<8>(llvm::map_range( + llvm::seq(0, origRank), [&](int64_t dim) -> Attribute { + return AffineMapAttr::get( + AffineMap::get(origRank, /*symbolCount = */ 0, + getAffineDimExpr(dim, context), context)); + })); + reassociationMaps.push_back(ArrayAttr::get(context, maps)); + } } // If the indexing maps of the result operation are not invertible (i.e. not diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -750,4 +750,50 @@ // CHECK: return %[[INIT:.+]] : memref<1xf32> +// ----- +// Test that a memref operand with an affine layout map is left unchanged (and no +// assertion fires) while unit dimensions elsewhere in the op are still dropped. 
+
+#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> ()>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
+  linalg.generic #trait
+     ins(%arg0, %arg1 : memref<?x1x?xf32, #map0>, f32)
+    outs(%shape : memref<?x1x?x1x?xf32>) {
+       ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+         linalg.yield %arg3 : f32
+       }
+  return %shape : memref<?x1x?x1x?xf32>
+}
+// CHECK:     #[[MAP0:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
+// CHECK:     #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+// CHECK:     #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK:     #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK:     builtin.func @input_stays_same(
+// CHECK-SAME:  %[[ARG0:.*]]: memref<?x1x?xf32, #[[MAP0]]>,
+// CHECK-SAME:  %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<?x1x?x1x?xf32>)
+// CHECK-SAME:  -> memref<?x1x?x1x?xf32> {
+// CHECK:     %[[OUT:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1], [2, 3], [4]]
+// CHECK-SAME:  : memref<?x1x?x1x?xf32> into memref<?x?x?xf32>
+// CHECK:     linalg.generic
+// CHECK-SAME:  {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[ARG0]], %[[ARG1]] : memref<?x1x?xf32, #[[MAP0]]>, f32)
+// CHECK-SAME:  outs(%[[OUT]] : memref<?x?x?xf32>) {
+// CHECK:     ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32): // no predecessors
+// CHECK:     linalg.yield %[[ARG]] : f32
+// CHECK:     }
+// CHECK:     return %[[ARG2]] : memref<?x1x?x1x?xf32>