diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -284,6 +284,30 @@ return shape[dim] == 1 && exprs[dim] == zeroExpr; }; + // Provide a struct that represents an unchanged replacement result. This is + // used in place of error conditions. + Type actualType = opOperand->get().getType(); + auto generateDoNotChange = [&]() -> UnitExtentReplacementInfo { + auto reassociationMaps = llvm::to_vector<8>(llvm::map_range( + llvm::seq(0, origRank), [&](int64_t dim) -> Attribute { + return AffineMapAttr::get(AffineMap::get( + origRank, /*symbolCount = */ 0, + llvm::makeArrayRef(getAffineDimExpr(dim, context)), context)); + })); + + return {actualType, + AffineMap::get(indexingMap.getNumDims(), + indexingMap.getNumSymbols(), exprs, context), + ArrayAttr::get(context, reassociationMaps)}; + }; + + // Early return for memrefs with affine maps to represent that we will always + // leave them unchanged. + if (auto memref = actualType.dyn_cast()) { + if (!memref.getAffineMaps().empty()) + return generateDoNotChange(); + } + int64_t dim = 0; // Fold dimensions that are unit-extent at the beginning of the tensor. while (dim < origRank && isUnitExtent(dim)) @@ -302,8 +326,8 @@ reassociations.clear(); ++dim; } + // Compute the tensor or scalar replacement type. - Type actualType = opOperand->get().getType(); Type elementType = getElementTypeOrSelf(opOperand->get()); Type replacementType; if (elementType == opOperand->get().getType()) { diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -750,4 +750,50 @@ // CHECK: return %[[INIT:.+]] : memref<1xf32> +// ----- +// Test that nothing changes and no assertions are fired for memrefs with affine +// maps while still changing the other operations. + +#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)> + +#accesses = [ + affine_map<(i, j, k, l, m) -> (i, k, m)>, + affine_map<(i, j, k, l, m) -> ()>, + affine_map<(i, j, k, l, m) -> (i, k, j, l, m)> +] + +#trait = { + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"], + indexing_maps = #accesses, + library_call = "some_external_func" +} + +func @input_stays_same(%arg0 : memref, %arg1 : f32, %shape: memref) -> memref { + linalg.generic #trait + ins(%arg0, %arg1 : memref, f32) + outs(%shape : memref) { + ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) : + linalg.yield %arg3 : f32 + } + return %shape : memref +} +// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)> +// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)> +// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()> +// CHECK: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: builtin.func @input_stays_same( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref) +// CHECK-SAME -> memref { +// CHECK: memref.collapse_shape %[[ARG2]] {{\[}}[0, 1], [2, 3], [4]] +// CHECK-SAME: : memref into memref +// CHECK: linalg.generic +// CHECK-SAME: {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]} +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref, f32) +// CHECK-SAME: outs(%0 : memref) { +// CHECK: ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32): // no predecessors +// CHECK: linalg.yield %[[ARG]] : f32 +// CHECK: } +// CHECK: return %[[ARG2]] : memref