diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -242,21 +242,25 @@ // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { Value input = genericOp.getInput(i); - if (!input.getType().cast<MemRefType>().getRank()) { - indexedValues[i] = std_load(input); - } else { + if (input.getType().cast<MemRefType>().getRank()) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getInputIndexingMap(i), allIvs)); indexedValues[i] = std_load(input, indexing); + } else { + indexedValues[i] = std_load(input); } } // 1.b. Emit std_load from output views. for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - indexedValues[nInputs + i] = - std_load(genericOp.getOutputBuffer(i), indexing); + Value output = genericOp.getOutputBuffer(i); + if (output.getType().cast<MemRefType>().getRank()) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + indexedValues[nInputs + i] = std_load(output, indexing); + } else { + indexedValues[nInputs + i] = std_load(output); + } } auto funcOp = genericOp.getFunction(); @@ -267,9 +271,14 @@ // 3. Emit std_store.
for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing); + Value output = genericOp.getOutputBuffer(i); + if (output.getType().cast<MemRefType>().getRank()) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + std_store(callOp->getResult(i), output, indexing); + } else { + std_store(callOp->getResult(i), output); + } } return; } @@ -288,10 +297,15 @@ auto *yieldOp = cast<YieldOp>(block.back()).getOperation(); assert(yieldOp->getNumOperands() == nOutputs); for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), - genericOp.getOutputBuffer(i), indexing); + Value output = genericOp.getOutputBuffer(i); + if (output.getType().cast<MemRefType>().getRank()) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + std_store(map.lookup(yieldOp->getOperand(i)), + genericOp.getOutputBuffer(i), indexing); + } else { + std_store(map.lookup(yieldOp->getOperand(i)), output); + } } } }; @@ -348,21 +362,25 @@ // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { Value input = indexedGenericOp.getInput(i); - if (!input.getType().cast<MemRefType>().getRank()) { - indexedValues[nLoops + i] = std_load(input); - } else { + if (input.getType().cast<MemRefType>().getRank()) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs)); indexedValues[nLoops + i] = std_load(input, indexing); + } else { + indexedValues[nLoops + i] = std_load(input); } } // 1.b. Emit std_load from output views.
for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - indexedValues[nLoops + nInputs + i] = - std_load(indexedGenericOp.getOutputBuffer(i), indexing); + Value output = indexedGenericOp.getOutputBuffer(i); + if (output.getType().cast<MemRefType>().getRank()) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + indexedValues[nLoops + nInputs + i] = std_load(output, indexing); + } else { + indexedValues[nLoops + nInputs + i] = std_load(output); + } } if (auto funcOp = indexedGenericOp.getFunction()) { @@ -372,10 +390,14 @@ // 3. Emit std_store. for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i), - indexing); + Value output = indexedGenericOp.getOutputBuffer(i); + if (output.getType().cast<MemRefType>().getRank()) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + std_store(callOp->getResult(i), output, indexing); + } else { + std_store(callOp->getResult(i), output); + } } return; } @@ -394,10 +416,14 @@ auto *yieldOp = cast<YieldOp>(block.back()).getOperation(); assert(yieldOp->getNumOperands() == nOutputs); for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), - indexedGenericOp.getOutputBuffer(i), indexing); + Value output = indexedGenericOp.getOutputBuffer(i); + if (output.getType().cast<MemRefType>().getRank()) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + std_store(map.lookup(yieldOp->getOperand(i)), output, indexing); + } else { + 
std_store(map.lookup(yieldOp->getOperand(i)), output); + } } } }; diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -411,3 +411,75 @@ // CHECK: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32 // CHECK: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32 // CHECK: store %[[result]], %[[ARG1]][%[[i]], %[[j]]] + +#reduce_1D_access = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (0)> ] + +#trait_reduce_1D = { + args_in = 1, + args_out = 1, + indexing_maps = #reduce_1D_access, + iterator_types = ["reduction"], + library_call = "some_reduce_external_fn" +} + +func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>) +{ + linalg.generic #trait_reduce_1D %arg0, %arg1 { + ^bb(%a: f32, %b: f32) : + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } : memref<?xf32>, memref<f32> + return +} +// CHECK-LABEL: @generic_op_1D_reduce +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32> +// CHECK: loop.for %[[i:.*]] = {{.*}} +// CHECK: %[[a:.*]] = load %[[ARG0]][%[[i]]] +// CHECK: %[[b:.*]] = load %[[ARG1]][] +// CHECK: %[[c:.*]] = addf %[[a]], %[[b]] : f32 +// CHECK: store %[[c]], %[[ARG1]][] + + +#reduce_init_1D_access = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (0)>, + affine_map<(i) -> (0)> ] + +#trait_reduce_init_1D = { + args_in = 2, + args_out = 1, + indexing_maps = #reduce_init_1D_access, + iterator_types = ["reduction"], + library_call = "some_reduce_external_fn" +} + +func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>, + %arg1: memref<f32>, + %arg2: memref<f32>) +{ + linalg.indexed_generic #trait_reduce_init_1D %arg0, %arg1, %arg2 { + ^bb(%i : index, %a: f32, %b: f32, %c: f32) : + %0 = constant 0 : index + %1 = cmpi "eq", %0, %i : index + %2 = select %1, %b, %c : f32 + %3 = addf %a, %2 : f32 + linalg.yield %3 : f32 + } : memref<?xf32>, memref<f32>, memref<f32> + return +} +// CHECK-LABEL: @indexed_generic_op_1D_reduce +// CHECK-SAME: 
%[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32> +// CHECK: loop.for %[[i:.*]] = {{.*}} +// CHECK: %[[a:.*]] = load %[[ARG0]][%[[i]]] +// CHECK: %[[b:.*]] = load %[[ARG1]][] +// CHECK: %[[c:.*]] = load %[[ARG2]][] +// CHECK: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]] +// CHECK: %[[e:.*]] = addf %[[a]], %[[d]] +// CHECK: store %[[e]], %[[ARG2]][]