diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -719,27 +719,6 @@ << ")"; } - // Output tensor indexing map may not depend on reduction indices. - for (OpOperand *opOperand : linalgOp.getOutputOperands()) { - AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); - for (AffineExpr expr : indexingMap.getResults()) { - for (unsigned pos : redDims) { - if (expr.isFunctionOfDim(pos)) { - std::string exprStr; - { - llvm::raw_string_ostream os(exprStr); - os << expr; - } - return op->emitOpError( - "unexpected output tensor expression in indexing map #") - << (opOperand->getOperandNumber() - linalgOp.getNumInputs()) - << " a.k.a '" << exprStr - << "' is function of reduction iterator 'd" << pos << "'"; - } - } - } - } - // Check the region has exactly one block. if (linalgOp->getNumRegions() != 1 || !llvm::hasSingleElement(linalgOp->getRegion(0))) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -539,6 +539,10 @@ return failure(); } for (OpOperand *opOperand : op.getOutputOperands()) { + AffineMap indexingMap = op.getTiedIndexingMap(opOperand); + if (indexingMap.isPermutation()) + continue; + Operation *reduceOp = matchLinalgReduction(opOperand); if (!reduceOp || !getCombinerOpKind(reduceOp)) { LDBG("reduction precondition failed: reduction detection failed"); diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -269,21 +269,6 @@ // ----- -func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>, - %arg1: tensor<?xf32>) { - // expected-error @+1 {{unexpected output tensor 
expression in indexing map #0 a.k.a 'd0' is function of reduction iterator 'd0'}} - %0 = linalg.generic { - indexing_maps = [ affine_map<(i) -> (i)> , affine_map<(i) -> (i)> ], - iterator_types = ["reduction"]} - ins(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) - outs(%arg1 : tensor<?xf32>) { - ^bb(%i: f32, %j: f32): - linalg.yield %i: f32 - } -> tensor<?xf32> -} - -// ----- - func.func @generic(%arg0: memref) { // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32}} linalg.generic { diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -350,3 +350,24 @@ return %1 : tensor } // CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor) -> tensor + +// ----- + +func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>, + %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?x?xf32>, %arg3 : tensor<?x?xf32>) -> + (tensor<?x?x?xf32>, tensor<?x?xf32>) { + %0:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?xf32>) + outs(%arg2, %arg3 : tensor<?x?x?xf32>, tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32): + %1 = arith.mulf %b0, %b1 : f32 + %2 = arith.addf %1, %b3 : f32 + linalg.yield %1, %2 : f32, f32 + } -> (tensor<?x?x?xf32>, tensor<?x?xf32>) + return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?xf32> +} +// CHECK-LABEL: func @mixed_parallel_reduced_results +// CHECK: linalg.generic diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1102,3 +1102,35 @@ } -> tensor<6x6x3x3xf32> return %result : tensor<6x6x3x3xf32> } + +// ----- + +// Check vectorization can handle cases where outputs are a mix of reduced and 
non-reduced values. +func.func @mixed_parallel_reduced_results(%arg0 : tensor<2x4x8xf32>, + %arg1 : tensor<2x4xf32>, %arg2 : tensor<2x4x8xf32>, %arg3 : tensor<2x4xf32>) -> + (tensor<2x4x8xf32>, tensor<2x4xf32>) { + %0:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0, %arg1 : tensor<2x4x8xf32>, tensor<2x4xf32>) + outs(%arg2, %arg3 : tensor<2x4x8xf32>, tensor<2x4xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32): + %1 = arith.mulf %b0, %b1 : f32 + %2 = arith.addf %1, %b3 : f32 + linalg.yield %1, %2 : f32, f32 + } -> (tensor<2x4x8xf32>, tensor<2x4xf32>) + return %0#0, %0#1 : tensor<2x4x8xf32>, tensor<2x4xf32> +} +// CHECK-LABEL: func @mixed_parallel_reduced_results( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x4x8xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x4xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<2x4x8xf32> +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<2x4xf32> +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[ARG0]] +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[ARG1]] +// CHECK-DAG: %[[V2:.+]] = vector.transfer_read %[[ARG3]] +// CHECK-DAG: %[[MUL:.+]] = arith.mulf %[[V0]], %[[V1]] +// CHECK-DAG: %[[ADD:.+]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]] +// CHECK-DAG: vector.transfer_write %[[MUL]], %[[ARG2]] +// CHECK-DAG: vector.transfer_write %[[ADD]], %[[ARG3]] \ No newline at end of file