diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -187,54 +187,60 @@ } template -static LogicalResult verifyBlockArgs(GenericOpType op, Block &block); - -template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { - auto nOperands = op.getNumOperands(); - if (block.getNumArguments() != nOperands) - return op.emitOpError("expected number of block arguments to match number " - "of operands"); - +static LogicalResult verifyBlockArgsWithOffset(GenericOpType op, Block &block, + int offset) { // Note: the number and type of yield values are checked in the YieldOp. + auto nOperands = op.getNumOperands(); auto nInputViews = op.getNumInputs(); for (unsigned i = 0; i < nOperands; ++i) { + int idx = i + offset; auto viewType = op.getShapedType(i); - if (viewType.getElementType() != block.getArgument(i).getType()) + if (viewType.getElementType() != block.getArgument(idx).getType()) return op.emitOpError("expected block argument ") - << (i + 1) << " of the same type as elemental type of " + << (idx + 1) << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") << "operand: " << viewType; } + + auto nResults = op.getNumResults(); + for (unsigned i = 0; i < nResults; ++i) { + int shapedTypeIndex = i + nOperands; + int blockArgIndex = shapedTypeIndex + offset; + auto viewType = op.getShapedType(shapedTypeIndex); + if (viewType.getElementType() != block.getArgument(blockArgIndex).getType()) + return op.emitOpError("expected block argument ") + << (blockArgIndex + 1) + << " of the same type as elemental type of result " + << "operand: " << viewType; + } return success(); } +template +static LogicalResult verifyBlockArgs(GenericOpType op, Block &block); + +template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { + auto nOperands = op.getNumOperands(); + auto nResults = op.getNumResults(); + if (block.getNumArguments() != nOperands + nResults) + return op.emitOpError("expected number of block arguments to match number " + "of operands + number of results"); + return verifyBlockArgsWithOffset(op, block, /*offset=*/0); +} + template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { - auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); auto nOperands = op.getNumOperands(); - if (block.getNumArguments() != nOperands + nLoops) + auto nResults = op.getNumResults(); + if (block.getNumArguments() != nOperands + nLoops + nResults) return op.emitOpError( "expected number of block arguments to match number of operands + " - "number of loops"); - - // Note: the number and type of yield values are checked in the YieldOp. + "number of loops + number of results"); for (unsigned i = 0; i < nLoops; ++i) if (!block.getArgument(i).getType().isIndex()) return op.emitOpError("expected block argument ") << (i + 1) << " to be an index"; - - for (unsigned i = 0; i < nOperands; ++i) { - unsigned memrefArgIndex = i + nLoops; - auto viewType = op.getShapedType(i); - if (viewType.getElementType() != - block.getArgument(memrefArgIndex).getType()) - return op.emitOpError("expected block argument ") - << (memrefArgIndex + 1) - << " of the same type as elemental type of " - << ((i < nInputViews) ? "input " : "output ") - << "operand: " << viewType; - } - return success(); + return verifyBlockArgsWithOffset(op, block, nLoops); } template @@ -259,27 +265,48 @@ return success(); } +template +static LogicalResult +verifyFuncArgsWithOffset(GenericOpType op, FunctionType funType, int offset) { + // linalg.generic operands element types are exactly the first function + // arguments. + auto nOperands = op.getNumOperands(); + for (unsigned i = 0; i < nOperands; ++i) { + int funcArgIndex = i + offset; + ShapedType shapedType = op.getShapedType(i); + if (funType.getInput(funcArgIndex) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (funcArgIndex + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of input " << (i + 1); + } + + auto nResults = op.getNumResults(); + for (unsigned i = 0; i < nResults; ++i) { + int shapedTypeIndex = i + nOperands; + int funcArgIndex = shapedTypeIndex + offset; + ShapedType shapedType = op.getShapedType(shapedTypeIndex); + if (funType.getInput(funcArgIndex) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (funcArgIndex + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of result " << (i + 1); + } + + return success(); +} + template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) { auto nOperands = op.getNumOperands(); - if (funType.getNumInputs() != nOperands) + auto nResults = op.getNumResults(); + if (funType.getNumInputs() != nOperands + nResults) return op.emitOpError( - "expected function arguments to match number of operands"); + "expected function arguments to match number of operands + number of " + "results"); if (funType.getNumResults() != op.getNumOutputs()) return op.emitOpError("expected function results(") << funType.getNumResults() << ") to match number of outputs(" << op.getNumOutputs() << ")"; - // linalg.generic operands element types are exactly the first function - // arguments. - for (unsigned idx = 0; idx < nOperands; ++idx) { - ShapedType shapedType = op.getShapedType(idx); - if (funType.getInput(idx) != shapedType.getElementType()) - return op.emitOpError("expected function argument ") - << (idx + 1) << " of the same type as elemental type " - << shapedType.getElementType() << " of operand " << (idx + 1); - } - - return success(); + return verifyFuncArgsWithOffset(op, funType, /*offset=*/0); } template <> @@ -287,9 +314,10 @@ auto nLoops = op.getNumLoops(); auto nOutputs = op.getNumOutputs(); auto nOperands = op.getNumOperands(); - if (funType.getNumInputs() != nOperands + nLoops) + auto nResults = op.getNumResults(); + if (funType.getNumInputs() != nOperands + nLoops + nResults) return op.emitOpError("expected function arguments to match number of " - "loops + number of operands"); + "loops + number of operands + number of results"); if (funType.getNumResults() != nOutputs) return op.emitOpError( "expected function results to match number of outputs"); @@ -298,17 +326,7 @@ return op.emitOpError("expected function argument ") << (i + 1) << " to be an index"; - // linalg.generic operands element types are exactly the first function - // arguments. - for (unsigned idx = 0; idx < nOperands; ++idx) { - ShapedType shapedType = op.getShapedType(idx); - if (funType.getInput(idx + nLoops) != shapedType.getElementType()) - return op.emitOpError("expected function argument ") - << (idx + nLoops + 1) << " of the same type as elemental type " - << shapedType.getElementType() << " of input " << (idx + 1); - } - - return success(); + return verifyFuncArgsWithOffset(op, funType, nLoops); } template diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -441,9 +441,14 @@ for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) { if (consumerOpArg.index() == consumerIdx) { // Map the arguments for the args from the producer. - for (auto producerOpArg : producerOpBlock.getArguments()) - mapper.map(producerOpArg, - fusedBlock->addArgument(producerOpArg.getType())); + for (auto producerOpArg : + llvm::enumerate(producerOpBlock.getArguments())) { + // Skip the operands corresponding to the results. + if (producerOpArg.index() >= producerOp.getNumInputs()) + continue; + mapper.map(producerOpArg.value(), + fusedBlock->addArgument(producerOpArg.value().getType())); + } continue; } mapper.map(consumerOpArg.value(), diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -7,7 +7,7 @@ func @add_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 { - ^bb0(%arg3: f32, %arg4: f32): // no predecessors + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 }: tensor, tensor -> tensor @@ -18,12 +18,12 @@ // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]] // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]] // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]] - ^bb0(%arg5: f32, %arg6: f32): // no predecessors + ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors // CHECK: [[T1:%[a-zA-Z0-9_]*]] = addf [[ARG0]], [[ARG1]] // CHECK-NOT: linalg.yield // CHECK: mulf [[T1]], [[ARG2]] // CHECK: linalg.yield - %3 = mulf %arg5, %arg6 : f32 + %3 = mulf %arg6, %arg7 : f32 linalg.yield %3 : f32 }: tensor, tensor -> tensor return %2 : tensor @@ -40,15 +40,15 @@ func @transpose_add_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 { - ^bb0(%arg3: f32, %arg4: f32): // no predecessors + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 }: tensor, tensor -> tensor // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64 // CHECK-SAME: indexing_maps = {{\[}}[[MAP0]], [[MAP1]], [[MAP0]], [[MAP0]]{{\]}} %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 { - ^bb0(%arg5: f32, %arg6: f32): // no predecessors - %3 = mulf %arg5, %arg6 : f32 + ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors + %3 = mulf %arg6, %arg7 : f32 linalg.yield %3 : f32 }: tensor, tensor -> tensor return %2 : tensor @@ -65,15 +65,15 @@ func @add_transpose_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 { - ^bb0(%arg3: f32, %arg4: f32): // no predecessors + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 }: tensor, tensor -> tensor // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64 // CHECK-SAME: indexing_maps = {{\[}}[[MAP1]], [[MAP0]], [[MAP0]], [[MAP0]]{{\]}} %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 { - ^bb0(%arg5: f32, %arg6: f32): // no predecessors - %3 = mulf %arg5, %arg6 : f32 + ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors + %3 = mulf %arg6, %arg7 : f32 linalg.yield %3 : f32 }: tensor, tensor -> tensor return %2 : tensor @@ -92,15 +92,15 @@ func @add_broadcast_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]} %arg0, %arg1 { - ^bb0(%arg3: f32, %arg4: f32): // no predecessors + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 }: tensor, tensor -> tensor // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64 // CHECK-SAME: indexing_maps = {{\[}}[[MAP1]], [[MAP1]], [[MAP0]], [[MAP0]] %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 { - ^bb0(%arg5: f32, %arg6: f32): // no predecessors - %3 = mulf %arg5, %arg6 : f32 + ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors + %3 = mulf %arg6, %arg7 : f32 linalg.yield %3 : f32 }: tensor, tensor -> tensor return %2 : tensor diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -126,7 +126,7 @@ func @foo(%0: i32, %1: i32, %2: i32) { return } func @generic_mismatched_num_returns(%0: memref, %1: memref) { - // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of operand 2}} + // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of input 2}} linalg.generic { args_in = 3, args_out = 0, @@ -219,7 +219,7 @@ } func @generic_fun_arg_0_element_type(%arg0: memref(off + i)>>) { - // expected-error @+1 {{op expected function argument 1 of the same type as elemental type 'f32' of operand 1}} + // expected-error @+1 {{op expected function argument 1 of the same type as elemental type 'f32' of input 1}} linalg.generic { args_in = 0, args_out = 1, diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -275,7 +275,7 @@ // ----- -func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 { +func @foo(%0: vector<3x4xi4>, %1: f32, %2: f32) -> f32 { %f0 = constant 0.0 : f32 return %f0 : f32 } @@ -294,14 +294,14 @@ library_call = "some_external_function_name_1" } -func @generic_with_tensor_input_and_output( +func @generic_function_with_tensor_input_and_output( %arg0: tensor>, %arg1: tensor) -> (tensor) { %0 = linalg.generic #trait2 %arg0, %arg1 {foo = 1} : tensor>, tensor -> tensor return %0 : tensor } -// CHECK-LABEL: func @generic_with_tensor_input_and_output +// CHECK-LABEL: func @generic_function_with_tensor_input_and_output // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo, // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: @@ -310,6 +310,40 @@ // ----- +#accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i)> +] + +#trait = { + args_in = 1, + args_out = 1, + indexing_maps = #accesses, + iterator_types = ["parallel", "reduction"], + library_call = "some_external_function_name_1" +} + +func @generic_block_with_tensor_input_and_output( + %arg0: tensor<2x4xf32>, %arg1: tensor<2xf32>) -> (tensor<2xf32>) { + %0 = linalg.generic #trait %arg0 { + ^bb0(%arg2: f32, %arg3: f32): // no predecessors + %res = addf %arg2, %arg3 : f32 + linalg.yield %res : f32 + }: tensor<2x4xf32> -> tensor<2xf32> + return %0 : tensor<2xf32> +} +// CHECK-LABEL: func @generic_block_with_tensor_input_and_output +// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "reduction"] +// CHECK-SAME: } %{{.*}} { +// CHECK: ^bb0([[in:%.*]]: f32, [[out:%.*]]: f32): +// CHECK: [[res:%.*]] = addf [[in]], [[out]] : f32 +// CHECK: linalg.yield [[res]] : f32 +// CHECK: }: tensor<2x4xf32> -> tensor<2xf32> +// CHECK: return {{.*}} : tensor<2xf32> + +// ----- + // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>