diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -579,8 +579,9 @@ OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds, "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs, "ArrayAttr":$iteratorTypes, - CArg<"function_ref", - "nullptr">:$bodyBuilderFn)>, + CArg<"function_ref", + "nullptr">:$bodyBuilderFn)>, ]; let extraClassDeclaration = [{ @@ -588,7 +589,13 @@ unsigned getNumControlOperands() { return 3 * getNumLoops(); } ValueRange getInductionVars() { - return getBody()->getArguments(); + return getBody()->getArguments().take_front(getNumLoops()); + } + ValueRange getRegionInputArgs() { + return getBody()->getArguments().slice(getNumLoops(), inputs().size()); + } + ValueRange getRegionOutputArgs() { + return getBody()->getArguments().take_back(outputs().size()); } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -834,6 +834,22 @@ parseOptionalAssignmentList(SmallVectorImpl &lhs, SmallVectorImpl &rhs) = 0; + /// Parse a list of assignments of the form + /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...) + ParseResult parseAssignmentListWithTypes(SmallVectorImpl &lhs, + SmallVectorImpl &rhs, + SmallVectorImpl &types) { + OptionalParseResult result = + parseOptionalAssignmentListWithTypes(lhs, rhs, types); + if (!result.hasValue()) + return emitError(getCurrentLocation(), "expected '('"); + return result.getValue(); + } + + virtual OptionalParseResult + parseOptionalAssignmentListWithTypes(SmallVectorImpl &lhs, + SmallVectorImpl &rhs, + SmallVectorImpl &types) = 0; /// Parse a keyword followed by a type. 
ParseResult parseKeywordType(const char *keyword, Type &result) { return failure(parseKeyword(keyword) || parseType(result)); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1805,11 +1805,13 @@ // TiledLoopOp //===----------------------------------------------------------------------===// -void TiledLoopOp::build( - OpBuilder &builder, OperationState &result, ValueRange lowerBounds, - ValueRange upperBounds, ValueRange steps, ValueRange inputs, - ValueRange outputs, ArrayAttr iteratorTypes, - function_ref bodyBuilderFn) { +void TiledLoopOp::build(OpBuilder &builder, OperationState &result, + ValueRange lowerBounds, ValueRange upperBounds, + ValueRange steps, ValueRange inputs, ValueRange outputs, + ArrayAttr iteratorTypes, + function_ref + bodyBuilderFn) { result.addOperands(lowerBounds); result.addOperands(upperBounds); result.addOperands(steps); @@ -1834,25 +1836,46 @@ OpBuilder::InsertionGuard guard(builder); unsigned numIVs = steps.size(); SmallVector argTypes(numIVs, builder.getIndexType()); + for (Type type : TypeRange(inputs)) + argTypes.push_back(type); + for (Type type : TypeRange(outputs)) + argTypes.push_back(type); Region *bodyRegion = result.addRegion(); Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes); if (bodyBuilderFn) { builder.setInsertionPointToStart(bodyBlock); - bodyBuilderFn(builder, result.location, bodyBlock->getArguments()); + bodyBuilderFn(builder, result.location, + bodyBlock->getArguments().take_front(numIVs), + bodyBlock->getArguments().slice(numIVs, inputs.size()), + bodyBlock->getArguments().take_back(outputs.size())); TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location); } } static void print(OpAsmPrinter &p, TiledLoopOp op) { - p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (" + p << op.getOperationName() << " (" << 
op.getInductionVars() << ") = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step() << ")"; - if (!op.inputs().empty()) - p << " ins (" << op.inputs() << ": " << TypeRange(op.inputs()) << ")"; - if (!op.outputs().empty()) - p << " outs (" << op.outputs() << ":" << TypeRange(op.outputs()) << ")"; + if (!op.inputs().empty()) { + p << " ins ("; + llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p, + [&](auto it) { + p << std::get<0>(it) << " = " << std::get<1>(it) + << ": " << std::get<1>(it).getType(); + }); + p << ")"; + } + if (!op.outputs().empty()) { + p << " outs ("; + llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p, + [&](auto it) { + p << std::get<0>(it) << " = " << std::get<1>(it) + << ": " << std::get<1>(it).getType(); + }); + p << ")"; + } if (llvm::any_of(op.iterator_types(), [](Attribute attr) { return attr.cast().getValue() != @@ -1900,13 +1923,13 @@ return failure(); // Parse input tensors. - SmallVector inputs; + SmallVector inputs, input_region_args; + SmallVector inputTypes; if (succeeded(parser.parseOptionalKeyword("ins"))) { - SmallVector inputTypes; llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseLParen() || parser.parseOperandList(inputs) || - parser.parseColonTypeList(inputTypes) || parser.parseRParen()) + if (parser.parseAssignmentListWithTypes(input_region_args, inputs, + inputTypes)) return failure(); if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc, @@ -1915,13 +1938,13 @@ } // Parse output tensors. 
- SmallVector outputs; + SmallVector outputs, output_region_args; + SmallVector outputTypes; if (succeeded(parser.parseOptionalKeyword("outs"))) { - SmallVector outputTypes; llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseLParen() || parser.parseOperandList(outputs) || - parser.parseColonTypeList(outputTypes) || parser.parseRParen()) + if (parser.parseAssignmentListWithTypes(output_region_args, outputs, + outputTypes)) return failure(); if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc, @@ -1963,8 +1986,16 @@ // Parse the body. Region *body = result.addRegion(); - SmallVector types(ivs.size(), builder.getIndexType()); - if (parser.parseRegion(*body, ivs, types)) + + SmallVector region_types(ivs.size(), builder.getIndexType()); + region_types.append(inputTypes); + region_types.append(outputTypes); + + SmallVector region_args(ivs); + region_args.append(input_region_args); + region_args.append(output_region_args); + + if (parser.parseRegion(*body, region_args, region_types)) return failure(); // Parse optional attributes. @@ -1991,6 +2022,33 @@ return op.emitOpError("expected iterator types array attribute size = ") << op.iterator_types().size() << " to match the number of loops = " << op.getNumLoops(); + + // Check if types of input arguments match region args types. + for (auto &item : + llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) { + Value input, inputRegionArg; + unsigned index = item.index(); + std::tie(input, inputRegionArg) = item.value(); + if (input.getType() != inputRegionArg.getType()) + return op.emitOpError("expected input arg ") + << index << " with type = " << input.getType() + << " to match region arg " << index + op.getNumLoops() + << " type = " << inputRegionArg.getType(); + } + + // Check if types of output arguments match region args types.
+ for (auto &item : + llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) { + Value output, outputRegionArg; + unsigned index = item.index(); + std::tie(output, outputRegionArg) = item.value(); + if (output.getType() != outputRegionArg.getType()) + return op.emitOpError("expected output arg ") + << index << " with type = " << output.getType() + << " to match region arg " + << index + op.getNumLoops() + op.inputs().size() + << " type = " << outputRegionArg.getType(); + } return success(); } @@ -2002,14 +2060,15 @@ // // Example: // -// %0 = linalg.tiled_loop ... outs (%out, %out_buf:tensor<...>, memref<...>) { +// %0 = linalg.tiled_loop ... outs (%o_ = %out: tensor<...>, +// %obuf_ = %out_buf: memref<...>) { // ... -// linalg.yield %out : tensor ... +// linalg.yield %o_ : tensor ... // } // // Becomes // -// linalg.tiled_loop ... outs (%out_buf:memref<...>) { +// linalg.tiled_loop ... outs (%obuf_ = %out_buf: memref<...>) { // ... // linalg.yield // } @@ -2026,16 +2085,27 @@ // Match the pattern and collect output buffers that will replace the output // tensors and also the ops that will be ignored when cloning the body. - SmallVector newOutputOperands, newYieldArgs; + SmallVector newOutputOperands, newYieldArgs, + regionOutputTensorArgs; int resultId = 0; - for (Value out : tiledLoop.outputs()) { + // Store ids of the corresponding old and new output operands. 
+ SmallVector, 2> old_out_id_to_new; + for (auto item : llvm::enumerate( + llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) { + size_t index = item.index(); + Value out = std::get<0>(item.value()); + Value outRegionArg = std::get<1>(item.value()); + if (!out.getType().isa()) { + old_out_id_to_new.push_back({index, newOutputOperands.size()}); newOutputOperands.push_back(out); + regionOutputTensorArgs.push_back(outRegionArg); continue; } Value result = tiledLoop.getResult(resultId); Value yieldArg = yieldOp.getOperand(resultId); - if (yieldArg != out || !result.use_empty()) { + if (yieldArg != outRegionArg || !result.use_empty()) { + old_out_id_to_new.push_back({index, newOutputOperands.size()}); newOutputOperands.push_back(out); newYieldArgs.push_back(yieldArg); } @@ -2053,6 +2123,10 @@ // unnecessary `subtensor_insert`, `tensor_load` and `cast` ops. BlockAndValueMapping bvm; bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars()); + bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs()); + for (const auto &item : old_out_id_to_new) + bvm.map(tiledLoop.getRegionOutputArgs()[item.first], + newTiledLoop.getRegionOutputArgs()[item.second]); OpBuilder innerBuilder = OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener()); for (auto &op : tiledLoop.getBody()->without_terminator()) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1694,6 +1694,29 @@ return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt); } + /// Parse a list of assignments of the form + /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...). 
+ OptionalParseResult + parseOptionalAssignmentListWithTypes(SmallVectorImpl &lhs, + SmallVectorImpl &rhs, + SmallVectorImpl &types) override { + if (failed(parseOptionalLParen())) + return llvm::None; + + auto parseElt = [&]() -> ParseResult { + OperandType regionArg, operand; + Type type; + if (parseRegionArgument(regionArg) || parseEqual() || + parseOperand(operand) || parseColon() || parseType(type)) + return failure(); + lhs.push_back(regionArg); + rhs.push_back(operand); + types.push_back(type); + return success(); + }; + return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt); + } + private: /// The source location of the operation name. SMLoc nameLoc; diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -861,10 +861,12 @@ %c192 = constant 192 : index %useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16) - ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>) - outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) { - call @foo(%A, %B, %C) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> () - linalg.yield %C_tensor : tensor<192x192xf32> + ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>) + outs (%CT_ = %C_tensor: tensor<192x192xf32>, + %C_ = %C: memref<192x192xf32>) { + call @foo(%A_, %B_, %C_) + : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> () + linalg.yield %CT_ : tensor<192x192xf32> } return } @@ -880,9 +882,9 @@ // CHECK-NOT: %{{.*}} = linalg.tiled_loop // CHECK: linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]]) // CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]]) -// CHECK-SAME: ins (%[[A]], %[[B]]: memref<192x192xf32>, memref<192x192xf32>) -// CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) { -// CHECK-NEXT: call @foo(%[[A]], %[[B]], %[[C]]) +// CHECK-SAME: ins (%[[A_:.*]] = 
%[[A]]: memref<192x192xf32>, %[[B_:.*]] = %[[B]]: memref<192x192xf32>) +// CHECK-SAME: outs (%[[C_:.*]] = %[[C]]: memref<192x192xf32>) { +// CHECK-NEXT: call @foo(%[[A_]], %[[B_]], %[[C_]]) // CHECK-NEXT: linalg.yield // ----- diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -776,9 +776,10 @@ %c192 = constant 192 : index %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c24) - ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>) - outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) { - call @foo(%A, %B, %C) + ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>) + outs (%CT_ = %C_tensor: tensor<192x192xf32>, + %C_ = %C: memref<192x192xf32>) { + call @foo(%A_, %B_, %C_) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> () // expected-error @+1 {{expected number of tensor output args = 1 to match the number of yield operands = 0}} linalg.yield @@ -803,9 +804,10 @@ %c192 = constant 192 : index %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c24) - ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>) - outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) { - %1 = call @foo(%A, %B, %C) + ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>) + outs (%CT_ = %C_tensor: tensor<192x192xf32>, + %C_ = %C: memref<192x192xf32>) { + %1 = call @foo(%A_, %B_, %C_) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> tensor // expected-error @+1 {{expected yield operand 0 with type = 'tensor' to match output arg type = 'tensor<192x192xf32>}} linalg.yield %1 : tensor @@ -815,10 +817,6 @@ // ----- -#map0 = affine_map<(d0) -> (24, -d0 + 192)> -#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)> -#map2 = affine_map<(d0) -> (16, -d0 + 192)> - func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>, %C: 
memref<192x192xf32>) -> () @@ -830,10 +828,12 @@ %c192 = constant 192 : index // expected-error @+1 {{expected iterator types array attribute size = 1 to match the number of loops = 2}} %0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ( { - ^bb0(%arg4: index, %arg5: index): // no predecessors - call @foo(%A, %B, %C) + ^bb0(%arg4: index, %arg5: index, %A_: memref<192x192xf32>, + %B_: memref<192x192xf32>, %CT_: tensor<192x192xf32>, + %C_: memref<192x192xf32>): + call @foo(%A_, %B_, %C_) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> () - linalg.yield %C_tensor : tensor<192x192xf32> + linalg.yield %CT_ : tensor<192x192xf32> }) { iterator_types = ["parallel"], operand_segment_sizes = dense<2> : vector<5xi32> @@ -842,3 +842,23 @@ ) -> tensor<192x192xf32> return } + +// ----- + +func private @foo(%A: memref<100xf32>) -> () + +func @tiled_loop_incorrent_block_arg_type(%A: memref<192xf32>) { + %c0 = constant 0 : index + %c192 = constant 192 : index + %c24 = constant 24 : index + // expected-error @+1 {{expected output arg 0 with type = 'memref<192xf32>' to match region arg 1 type = 'memref<100xf32>'}} + "linalg.tiled_loop"(%c0, %c192, %c24, %A) ( { + ^bb0(%arg4: index, %A_: memref<100xf32>): + call @foo(%A_) : (memref<100xf32>)-> () + linalg.yield + }) { + iterator_types = ["parallel"], + operand_segment_sizes = dense<[1, 1, 1, 0, 1]> : vector<5xi32> + } : (index, index, index, memref<192xf32>) -> () + return +} diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -804,13 +804,13 @@ %c24 = constant 24 : index %c64 = constant 64 : index %prod = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4) - ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>) - outs(%out : tensor<24x64xi8>) { - %lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1] + ins(%lhs_ = %lhs: tensor<24x64xi8>, %rhs_ = 
%rhs: tensor<24x64xi8>) + outs(%out_ = %out: tensor<24x64xi8>) { + %lhs_sub = subtensor %lhs_[%i, 0] [%c4, %c64] [1, 1] : tensor<24x64xi8> to tensor - %rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1] + %rhs_sub = subtensor %rhs_[%i, 0] [%c4, %c64] [1, 1] : tensor<24x64xi8> to tensor - %out_sub = subtensor %out[%i, 0] [%c4, %c64] [1, 1] + %out_sub = subtensor %out_[%i, 0] [%c4, %c64] [1, 1] : tensor<24x64xi8> to tensor %sum = linalg.generic #trait_4 @@ -821,7 +821,7 @@ linalg.yield %s : i8 } -> tensor - %sum_sub = subtensor_insert %sum into %out[%i, 0][%c4, %c64][1, 1] + %sum_sub = subtensor_insert %sum into %out_[%i, 0][%c4, %c64][1, 1] : tensor into tensor<24x64xi8> linalg.yield %sum_sub : tensor<24x64xi8> } @@ -860,16 +860,18 @@ %Z = memref.dim %input_3d, %c2 : tensor<16x24x32xf32> %result = linalg.tiled_loop (%i, %j, %k) = (%c0, %c0, %c0) to (%X, %Y, %Z) step (%c2, %c4, %c8) - ins(%input_3d, %input_2d: tensor<16x24x32xf32>, tensor<16x32xf32>) - outs( %output: tensor<24xf32>) + ins(%i3d_ = %input_3d: tensor<16x24x32xf32>, + %i2d_ = %input_2d: tensor<16x32xf32>, + %i1d_ = %input_1d: tensor<24xf32>) + outs(%o_ = %output: tensor<24xf32>) iterators["reduction", "parallel", "reduction"] { - %sub_3d = subtensor %input_3d[%i, %j, %k][2, 4, 8][1, 1, 1] + %sub_3d = subtensor %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1] : tensor<16x24x32xf32> to tensor<2x4x8xf32> - %sub_2d = subtensor %input_2d[%i, %k][2, 8][1, 1] + %sub_2d = subtensor %i2d_[%i, %k][2, 8][1, 1] : tensor<16x32xf32> to tensor<2x8xf32> - %sub_1d = subtensor %input_1d[%j] [4] [1] + %sub_1d = subtensor %i1d_[%j] [4] [1] : tensor<24xf32> to tensor<4xf32> - %sub_out = subtensor %output[%j] [4] [1] + %sub_out = subtensor %o_[%j] [4] [1] : tensor<24xf32> to tensor<4xf32> %acc = linalg.generic #trait_5 ins(%sub_3d, %sub_2d, %sub_1d @@ -881,7 +883,7 @@ linalg.yield %1 : f32 } -> tensor<4xf32> - %sum_sub = subtensor_insert %acc into %output[%j][%c4][1] + %sum_sub = subtensor_insert %acc into %o_[%j][%c4][1] : tensor<4xf32> 
into tensor<24xf32> linalg.yield %sum_sub : tensor<24xf32> } @@ -919,16 +921,18 @@ %Z = memref.dim %input_3d, %c2 : memref<16x24x32xf32> linalg.tiled_loop (%i, %j, %k) = (%c0, %c0, %c0) to (%X, %Y, %Z) step (%c2, %c4, %c8) - ins(%input_3d, %input_2d: memref<16x24x32xf32>, memref<16x32xf32>) - outs( %output: memref<24xf32>) + ins(%i3d_ = %input_3d: memref<16x24x32xf32>, + %i2d_ = %input_2d: memref<16x32xf32>, + %i1d_ = %input_1d: memref<24xf32>) + outs(%o_ = %output: memref<24xf32>) iterators["reduction", "parallel", "reduction"] { - %sub_3d = memref.subview %input_3d[%i, %j, %k][2, 4, 8][1, 1, 1] + %sub_3d = memref.subview %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1] : memref<16x24x32xf32> to memref<2x4x8xf32, #map_1> - %sub_2d = memref.subview %input_2d[%i, %k][2, 8][1, 1] + %sub_2d = memref.subview %i2d_[%i, %k][2, 8][1, 1] : memref<16x32xf32> to memref<2x8xf32, #map_2> - %sub_1d = memref.subview %input_1d[%j] [4] [1] + %sub_1d = memref.subview %i1d_[%j] [4] [1] : memref<24xf32> to memref<4xf32, #map_3> - %sub_out = memref.subview %output[%j] [4] [1] + %sub_out = memref.subview %o_[%j] [4] [1] : memref<24xf32> to memref<4xf32, #map_3> linalg.generic #trait_6 ins(%sub_3d, %sub_2d, %sub_1d