diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -157,7 +157,7 @@
   AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
   llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
   SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
-  if (genericOp.isScalar(opOperand))
+  if (genericOp.isScalar(opOperand) || exprs.size() == 0)
     return std::make_tuple(opOperand->get(),
                            AffineMap::get(numLoops, 0, exprs, b.getContext()));
 
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -652,3 +652,30 @@
 // CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]]
 // CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
 // CHECK-SAME: into %[[ALLOC]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> ()>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @scalar_tensor(%arg0 : tensor<f32>) -> tensor<1x32x7x7x32xf32> {
+  %empty_gen = tensor.empty() : tensor<1x7x7x1024xf32>
+  %gen = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<f32>) outs(%empty_gen : tensor<1x7x7x1024xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  } -> tensor<1x7x7x1024xf32>
+  %empty_pack = tensor.empty() : tensor<1x32x7x7x32xf32>
+  %pack = tensor.pack %gen outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %empty_pack : tensor<1x7x7x1024xf32> -> tensor<1x32x7x7x32xf32>
+  return %pack : tensor<1x32x7x7x32xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @scalar_tensor
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<f32>)
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x32x7x7x32xf32>
+// CHECK: linalg.generic
+// CHECK-SAME:   indexing_maps = [#[[MAP]], #[[MAP1]]]
+// CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:   ins(%[[ARG0]]
+// CHECK-SAME:   outs(%[[EMPTY]]