[mlir][linalg] Preserve the memory space when dropping unit dims of memrefs.

The replacement type for a memref operand was built with
MemRefType::get(newShape, elementType), which drops the operand's memory
space. Forward the original memory space instead. The trailing branch now
uses cast<MemRefType>(), which asserts on any other type and so subsumes the
removed "unsupported shaped type" assert.

The new test collapses a memref<1x1xf32, 3> fully to rank 0 and checks that
both the collapse_shape result and the rewritten generic use memref<f32, 3>.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -293,10 +293,12 @@
     replacementType = elementType;
   } else if (actualType.isa<RankedTensorType>()) {
     replacementType = RankedTensorType::get(newShape, elementType);
-  } else if (actualType.isa<MemRefType>()) {
-    replacementType = MemRefType::get(newShape, elementType);
+  } else {
+    // Keep the operand's memory space; cast<> asserts on non-memref types.
+    auto memrefType = actualType.cast<MemRefType>();
+    replacementType = MemRefType::get(newShape, elementType, {},
+                                      memrefType.getMemorySpaceAsInt());
   }
-  assert(replacementType && "unsupported shaped type");
   UnitExtentReplacementInfo info = {replacementType,
                                     AffineMap::get(indexingMap.getNumDims(),
                                                    indexingMap.getNumSymbols(),
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -847,3 +847,29 @@
   }
   return %res: tensor<4x2xf32>
 }
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+// Dropping every unit dim of a memref operand must keep its memory space (3).
+func.func @drop_all_loops(%arg0 : memref<1x1xf32, 3>) -> memref<1x1xf32, 3>
+{
+  linalg.generic #trait
+     ins(%arg0 : memref<1x1xf32, 3>)
+    outs(%arg0 : memref<1x1xf32, 3>) {
+  ^bb0(%arg1: f32, %arg2: f32) :
+     linalg.yield %arg1 : f32
+  }
+  return %arg0 : memref<1x1xf32, 3>
+}
+
+// CHECK-LABEL: func @drop_all_loops
+// CHECK:       memref.collapse_shape %{{.*}} [] : memref<1x1xf32, 3> into memref<f32, 3>
+// CHECK:       linalg.generic{{.*}}memref<f32, 3>