diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -344,10 +344,14 @@ // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs)); - indexedValues[nLoops + i] = - std_load(indexedGenericOp.getInput(i), indexing); + Value input = indexedGenericOp.getInput(i); + if (!input.getType().cast().getRank()) { + indexedValues[nLoops + i] = std_load(input); + } else { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs)); + indexedValues[nLoops + i] = std_load(input, indexing); + } } // 1.b. Emit std_load from output views. diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -359,10 +359,9 @@ // ----- - #broadcast_access = [ affine_map<(i, j) -> (0)>, - affine_map<(i,j) -> (i,j)> + affine_map<(i, j) -> (i, j)> ] #trait_broadcast = { @@ -373,10 +372,10 @@ library_call = "some_broadcast_external_fn" } -func @generic_op_zero_rank(%arg0 : memref, %arg1: memref<3x4xf32>) +func @generic_op_zero_rank(%arg0: memref, %arg1: memref<3x4xf32>) { linalg.generic #trait_broadcast %arg0, %arg1 { - ^bb(%a: f32, %b : f32) : + ^bb(%a: f32, %b: f32) : linalg.yield %a : f32 } : memref, memref<3x4xf32> return @@ -389,3 +388,26 @@ // CHECK: loop.for %[[j:.*]] = {{.*}} // CHECK: %[[a:.*]] = load %[[ARG0]][] // CHECK: store %[[a]], %[[ARG1]][%[[i]], %[[j]]] + +func @indexed_generic_op_zero_rank(%arg0: memref, %arg1: memref<3x4xi32>) +{ + linalg.indexed_generic #trait_broadcast %arg0, %arg1 { + ^bb(%i: index, %j: index, %a: i32, %b: i32) : + %ij = addi %i, %j : index + %ij_int = index_cast %ij : index to i32 + %result = addi %a, %ij_int : i32 + linalg.yield %result : i32 + } : memref, memref<3x4xi32> + return +} + +// CHECK-LABEL: @indexed_generic_op_zero_rank +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32> +// CHECK: loop.for %[[i:.*]] = {{.*}} +// CHECK: loop.for %[[j:.*]] = {{.*}} +// CHECK: %[[a:.*]] = load %[[ARG0]][ +// CHECK: %[[ij:.*]] = addi %[[i]], %[[j]] : index +// CHECK: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32 +// CHECK: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32 +// CHECK: store %[[result]], %[[ARG1]][%[[i]], %[[j]]] diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -347,7 +347,7 @@ #broadcast_access = [ affine_map<(i, j) -> (0)>, - affine_map<(i,j) -> (i,j)> + affine_map<(i, j) -> (i, j)> ] #trait_broadcast = { @@ -358,7 +358,7 @@ library_call = "some_broadcast_external_fn" } -func @generic_op_zero_rank(%arg0 : tensor) -> (tensor<3x4xf32>) +func @generic_op_zero_rank(%arg0: tensor) -> (tensor<3x4xf32>) { %0 = linalg.generic #trait_broadcast %arg0 { ^bb(%a: f32) : @@ -367,6 +367,15 @@ return %0 : tensor<3x4xf32> } +func @indexed_generic_op_zero_rank(%arg0: tensor) -> (tensor<3x4xf32>) +{ + %0 = linalg.indexed_generic #trait_broadcast %arg0 { + ^bb(%i: index, %j: index, %a: f32) : + linalg.yield %a : f32 + } : tensor -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> +} + // ----- // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>