diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -361,11 +361,10 @@ if (!cst || cst.getValue() != 0) return op.emitOpError("expected indexing_map #") << idx << " to be 0 to match 0-D view: " << view; - } - - if (m.getNumResults() != view.getRank()) + } else if (m.getNumResults() != view.getRank()) { return op.emitOpError("expected indexing_map #") << idx << " results to match view rank: " << view; + } } auto concatMap = concatAffineMaps(indexingMaps); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -238,9 +238,14 @@ // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getInputIndexingMap(i), allIvs)); - indexedValues[i] = std_load(genericOp.getInput(i), indexing); + Value input = genericOp.getInput(i); + if (!input.getType().cast().getRank()) { + indexedValues[i] = std_load(input); + } else { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getInputIndexingMap(i), allIvs)); + indexedValues[i] = std_load(input, indexing); + } } // 1.b. Emit std_load from output views. diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -351,12 +351,12 @@ AffineMap mlir::concatAffineMaps(ArrayRef maps) { unsigned numResults = 0; for (auto m : maps) - numResults += m ? m.getNumResults() : 0; + numResults += (m && !m.isSingleConstant()) ? m.getNumResults() : 0; unsigned numDims = 0; SmallVector results; results.reserve(numResults); for (auto m : maps) { - if (!m) + if (!m || m.isSingleConstant()) continue; assert(m.getNumSymbols() == 0 && "expected map without symbols"); results.append(m.getResults().begin(), m.getResults().end()); diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -356,3 +356,36 @@ // CHECK: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32 // CHECK: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]] // CHECK: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]] + +// ----- + + +#broadcast_access = [ + affine_map<(i, j) -> (0)>, + affine_map<(i,j) -> (i,j)> +] + +#trait_broadcast = { + args_in = 1, + args_out = 1, + indexing_maps = #broadcast_access, + iterator_types = ["parallel", "parallel"], + library_call = "some_broadcast_external_fn" +} + +func @generic_op_zero_rank(%arg0 : memref, %arg1: memref<3x4xf32>) +{ + linalg.generic #trait_broadcast %arg0, %arg1 { + ^bb(%a: f32, %b : f32) : + linalg.yield %a : f32 + } : memref, memref<3x4xf32> + return +} + +// CHECK-LABEL: @generic_op_zero_rank +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32> +// CHECK: loop.for %[[i:.*]] = {{.*}} +// CHECK: loop.for %[[j:.*]] = {{.*}} +// CHECK: %[[a:.*]] = load %[[ARG0]][] +// CHECK: store %[[a]], %[[ARG1]][%[[i]], %[[j]]] diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -345,6 +345,30 @@ // ----- +#broadcast_access = [ + affine_map<(i, j) -> (0)>, + affine_map<(i,j) -> (i,j)> +] + +#trait_broadcast = { + args_in = 1, + args_out = 1, + indexing_maps = #broadcast_access, + iterator_types = ["parallel", "parallel"], + library_call = "some_broadcast_external_fn" +} + +func @generic_op_zero_rank(%arg0 : tensor) -> (tensor<3x4xf32>) +{ + %0 = linalg.generic #trait_broadcast %arg0 { + ^bb(%a: f32) : + linalg.yield %a : f32 + } : tensor -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> +} + +// ----- + // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>