diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -48,7 +48,8 @@ /// result in surprising behavior when combined with region definition. Operation *makeGenericLinalgOp( ArrayRef iteratorTypes, ArrayRef inputs, - ArrayRef outputs, + ArrayRef outputBuffers, ArrayRef initTensors, + ArrayRef resultTensorTypes, function_ref regionBuilder = defaultRegionBuilder, ArrayRef otherValues = {}, ArrayRef otherAttributes = {}); @@ -134,18 +135,6 @@ linalg_generic_matmul(Value vA, Value vB, Value vC, MatmulRegionBuilder regionBuilder = macRegionBuilder); -/// Build a linalg.generic, under the current ScopedContext, at the current -/// insert point, that computes: -/// ``` -/// (m, n, k) = (par, par, seq) -/// | -/// | C(m, n) = sum_k(A(m, k) * B(k, n)) -/// ``` -/// and returns the tensor `C`. -Operation * -linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC, - MatmulRegionBuilder regionBuilder = mulRegionBuilder); - /// Build a linalg.generic, under the current ScopedContext, at the current /// insert point, that computes: /// ``` diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -40,11 +40,12 @@ // depending on the specific Linalg op. class LinalgStructuredBase_Op props> : Op { + !listconcat(props, [LinalgStructuredInterface])> { } class LinalgStructured_Op props> - : LinalgStructuredBase_Op { + : LinalgStructuredBase_Op { code libraryCallName = [{ std::string getLibraryCallName() { return generateLibraryCallName(getOperation()); @@ -457,43 +458,53 @@ CPred<"$_self.cast().getRank() == " # rank>] >>; -class GenericOpBase : LinalgStructuredBase_Op]> { - let arguments = (ins Variadic:$views, - I64Attr:$args_in, - I64Attr:$args_out, - AffineMapArrayAttr:$indexing_maps, - ArrayAttr:$iterator_types, - OptionalAttr:$doc, - OptionalAttr:$library_call, - Confined, - [IntMinValue<0>]>:$symbol_source); - let results = (outs Variadic:$output_tensors); +class GenericOpBase : LinalgStructuredBase_Op]> { + let arguments = (ins Variadic:$inputs, + Variadic:$output_buffers, + Variadic:$init_tensors, + AffineMapArrayAttr:$indexing_maps, + ArrayAttr:$iterator_types, + OptionalAttr:$doc, + OptionalAttr:$library_call, + Confined, [IntMinValue<0>]> + :$symbol_source); + let results = (outs Variadic:$result_tensors); let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder< + "OpBuilder &builder, OperationState &result, " + "ValueRange inputs, ValueRange outputBuffers, " + "ArrayRef indexingMaps, ArrayRef iteratorTypes, " + "StringRef = \"\", StringRef = \"\", " + "IntegerAttr = IntegerAttr(), " + "function_ref = nullptr">, + OpBuilder< + "OpBuilder &builder, OperationState &result, ArrayRef resultTensorTypes," + "ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors, " + "ArrayRef indexingMaps, ArrayRef iteratorTypes, " + "StringRef = \"\", StringRef = \"\", IntegerAttr = IntegerAttr(), " + "function_ref = nullptr"> + ]; let extraClassDeclaration = [{ SmallVector linalgTraitAttrNames() { return SmallVector{ - getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(), + getDocAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(), getIteratorTypesAttrName(), getSymbolSourceAttrName() }; } - 
- unsigned getNumInputs() { return args_in(); } - - unsigned getNumOutputs() { return args_out(); } - StringRef getLibraryCallName() { return library_call().hasValue() ? library_call().getValue() : ""; } - llvm::Optional getSymbolSource() { auto ss = symbol_source(); return ss.hasValue() ? llvm::Optional(ss.getValue()) : llvm::None; } }]; - let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseGenericOp(parser, result); }]; } @@ -502,18 +513,19 @@ def GenericOp : GenericOpBase<"generic"> { let description = [{ Generic Linalg op form where the key properties of the computation are - specified as attributes. In pretty form, a linalg.generic op is written as: + specified as attributes. In pretty form, a `linalg.generic` op is written + as: ```mlir - linalg.generic #trait_attribute %A, %B, %C {other-attributes} : - memref, - memref, - memref + linalg.generic #trait_attribute + ins(%A, %B : memref, + memref) + outs(%C : memref) + [other-attributes] + {region} ``` Where #trait_attributes is an alias of a dictionary attribute containing: - - args_in: an I64Attr representing the number of input (readonly) views - - args_out: an I64Attr representing the number of output (readwrite) views - doc [optional]: a documentation string - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input and output view. Such AffineMapAttr specifies the mapping between the @@ -544,22 +556,22 @@ doc = "C(m, n) += A(m, k) * B(k, n)", indexing_maps = #matmul_accesses, library_call = "linalg_matmul", - args_in = 2, - args_out = 1, iterator_types = ["parallel", "parallel", "reduction"] } ``` And can be reused in multiple places as: ```mlir - linalg.generic #matmul_trait %A, %B, %C [other-attributes] { + linalg.generic #matmul_trait + ins(%A, %B : memref, + memref) + outs(%C : memref) + [other-attributes] { ^bb0(%a: f32, %b: f32, %c: f32) : %d = mulf %a, %b: f32 %e = addf %c, %d: f32 linalg.yield %e : f32 - } : memref, - memref, - memref + } ``` This may lower to either: @@ -588,30 +600,29 @@ ``` To allow progressive lowering from the value world (a.k.a tensor values) to - the buffer world (a.k.a memref values), a `linalg.generic` op accepts - mixing input and output ranked tensor values with input and output memrefs. + the buffer world (a.k.a memref values), a `linalg.generic` op allows mixing + tensor and buffer operands and tensor results. ```mlir - %C = linalg.generic #trait_attribute %A, %B {other-attributes} {region} : - tensor, - memref + %C = linalg.generic #trait_attribute + ins(%A, %B : tensor, memref) + init(%C : tensor) + [other-attributes] + {region} -> (tensor) ``` - In this case, the number of outputs (args_out) must match the sum of (1) the - number of output buffer operands and (2) the number of tensor return values. - The semantics is that the `linalg.indexed_generic` op produces (i.e. - allocates and fills) its tensor return values. + The `init` operand and the conventions around mixing tensors and buffers are + described in more detail in the "Tensors and Buffers: Conventions and + Limitations" section in the [Linalg Document](../docs/Linalg.md). Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. Such legalization moves tensor return values + transformations can be applied. Such legalizations move tensor return values into output buffer operands and update the region arguments accordingly.
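For reference, a minimal sketch of how the buffer-only `GenericOp` builder declared above might be invoked from C++. This is not part of the patch: the helper name `buildMatmulGeneric` is hypothetical, the operands `A`, `B`, `C` are assumed to be pre-existing memref-typed `Value`s, and the template parameters elided in this extract are assumed to be `ArrayRef<AffineMap>`, `ArrayRef<StringRef>` and `function_ref<void(OpBuilder &, Location, ValueRange)>`.

```c++
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Emits a matmul-shaped linalg.generic on memrefs A, B, C using the
// buffer-form builder (no result tensors, no init tensors).
static void buildMatmulGeneric(OpBuilder &builder, Location loc, Value A,
                               Value B, Value C) {
  MLIRContext *ctx = builder.getContext();
  AffineExpr m, n, k;
  bindDims(ctx, m, n, k);
  SmallVector<AffineMap, 3> maps = {AffineMap::get(3, 0, {m, k}, ctx),
                                    AffineMap::get(3, 0, {k, n}, ctx),
                                    AffineMap::get(3, 0, {m, n}, ctx)};
  builder.create<linalg::GenericOp>(
      loc, /*inputs=*/ValueRange{A, B}, /*outputBuffers=*/ValueRange{C},
      /*indexingMaps=*/maps,
      /*iteratorTypes=*/ArrayRef<StringRef>{"parallel", "parallel", "reduction"},
      /*doc=*/"", /*libraryCall=*/"", /*symbolSource=*/IntegerAttr(),
      [](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // Block arguments: one scalar per input, then one per output buffer.
        Value prod = b.create<MulFOp>(nestedLoc, args[0], args[1]);
        Value acc = b.create<AddFOp>(nestedLoc, args[2], prod);
        b.create<linalg::YieldOp>(nestedLoc, acc);
      });
}
```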
- Transformations that create control-flow around linalg.indexed_generic - operations are not expected to work with tensors because SSA values do not - escape naturally. Still, transformations and rewrites that take advantage of - tensor SSA values are expected to be useful and will be added in the near - future. + The `symbol_source` attribute allows selecting a particular operand and + introducing symbols for each operand dimension. Such symbols can then be + used in the indexing maps. Example of 1D convolution with symbols: ```mlir @@ -629,28 +640,20 @@ symbol_source = 1 } - linalg.generic #conv_1d_trait %in, %filter, %out { + linalg.generic #conv_1d_trait + ins(%in, %filter : memref, memref) + outs(%out : memref) { ^bb0(%a: f32, %b: f32, %c: f32) : %d = mulf %a, %b : f32 %e = addf %c, %d : f32 linalg.yield %e : f32 - } : memref, - memref, - memref + } ``` where symbol s0 will be substituted with `dim %filter, %c0` i.e. the first and only dimension of the second operand as specified by the symbol_source attribute. }]; - let builders = [ - OpBuilder< - "OpBuilder &builder, OperationState &result, ArrayRef resultTypes, " - "ValueRange args, int64_t argsIn, int64_t argsOut, " - "ArrayRef indexingMaps, ArrayRef iteratorTypes, " - "function_ref = nullptr"> - ]; - let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; @@ -662,19 +665,19 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> { let description = [{ Indexed Generic Linalg op form where the key properties of the computation - are specified as attributes. In pretty form, a linalg.indexed_generic op is - written as: + are specified as attributes. In pretty form, a `linalg.indexed_generic` op + is written as: ```mlir - linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes} : - memref, - memref, - memref + linalg.indexed_generic #trait_attribute + ins(%A, %B : memref, + memref) + outs(%C : memref) + [other-attributes] + {region} ``` Where #trait_attributes is an alias of a dictionary attribute containing: - - args_in: an I64Attr representing the number of input (readonly) views - - args_out: an I64Attr representing the number of output (readwrite) views - doc [optional]: a documentation string - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input and output view. Such AffineMapAttr specifies the mapping between the @@ -702,8 +705,6 @@ doc = "C(m, n) += A(m, k) * B(k, n)", indexing_maps = #matmul_accesses, library_call = "linalg_matmul", - args_in = 2, - args_out = 1, iterator_types = ["parallel", "parallel", "reduction"] } ``` @@ -711,23 +712,25 @@ And can be reused in multiple places as: ```mlir - linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] { + linalg.indexed_generic #matmul_trait + ins(%A, %B : memref, + memref) + outs(%C : memref) (%offset_m: index, %offset_n: index, %offset_k: index, %a: f32, %b: f32, %c: f32) : "some_optional_computation"(%offset_m, %offset_n, %offset_k) %d = mulf %a, %b: f32 %e = addf %c, %d: f32 linalg.yield %e : f32 - } : memref, - memref, - memref + } ``` This may lower to either: ```mlir call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) : - (memref, + (index, index, index, + memref, memref, memref) -> () @@ -753,41 +756,58 @@ To allow progressive lowering from the value world (a.k.a tensor values) to the buffer world (a.k.a memref values), a `linalg.indexed_generic` op - accepts mixing input and output ranked tensor values with input and output - memrefs. + allows mixing tensor and buffer operands and tensor results.
```mlir - %C = linalg.indexed_generic #trait_attribute %A, %B {other-attributes} - : tensor, - memref + %C = linalg.indexed_generic #trait_attribute + ins(%A, %B : tensor, memref) + init(%C : tensor) + [other-attributes] + {region_with_index_arguments} -> (tensor) ``` - In this case, the number of outputs (args_out) must match the sum of (1) the - number of output buffer operands and (2) the number of tensor return values. - The semantics is that the `linalg.indexed_generic` op produces (i.e. - allocates and fills) its return values. + The `init` operand and the conventions around mixing tensors and buffers are + described in more detail in the "Tensors and Buffers: Conventions and + Limitations" section in the [Linalg Document](../docs/Linalg.md). Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. Such legalization moves tensor return values - into output buffer operands and updates the region argument accordingly. - - Transformations that create control-flow around linalg.indexed_generic - operations are not expected to work with tensors because SSA values do not - escape naturally. Still, transformations and rewrites that take advantage of - tensor SSA values are expected to be useful and will be added in the near - future. - }]; + transformations can be applied. Such legalizations move tensor return values + into output buffer operands and update the region arguments accordingly. - let builders = [ - OpBuilder< - "OpBuilder &builder, OperationState &result, ArrayRef resultTypes, " - "ValueRange args, int64_t argsIn, int64_t argsOut, " - "ArrayRef indexingMaps, ArrayRef iteratorTypes, " - "function_ref " - "= nullptr"> - ]; + The `symbol_source` attribute allows selecting a particular operand and + introducing symbols for each operand dimension. Such symbols can then be + used in the indexing maps. + + Example of 1D convolution with symbols: + ```mlir + #conv_1d_accesses = [ + affine_map<(m, n)[dimN] -> (m + n - dimN floordiv 2)>, // in + affine_map<(m, n)[dimN] -> (n)>, // filter + affine_map<(m, n)[dimN] -> (m)> // out + ] + #conv_1d_trait = { + doc = "O(m) += I(m + n - size(n) floordiv 2) * K(n)", + indexing_maps = #conv_1d_accesses, + library_call = "linalg_conv_1d", + iterator_types = ["parallel", "parallel"], + symbol_source = 1 + } + + linalg.generic #conv_1d_trait + ins(%in, %filter : memref, memref) + outs(%out : memref) { + ^bb0(%a: f32, %b: f32, %c: f32) : + %d = mulf %a, %b : f32 + %e = addf %c, %d : f32 + linalg.yield %e : f32 + } + ``` + where symbol s0 will be substituted with `dim %filter, %c0` i.e. the first + and only dimension of the second operand as specified by the symbol_source + attribute.
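A similar hedged sketch, not part of the patch, for the buffer-form `IndexedGenericOp` builder: with the unified callback introduced in LinalgOps.cpp below, the region body is handed a single `ValueRange` whose leading values are the loop indices, followed by one scalar per input and per output buffer. The helper name `buildIndexedCopy` and the operand values are assumptions.

```c++
// Copies a 1-D memref element-wise; args[0] is the loop index.
static void buildIndexedCopy(OpBuilder &builder, Location loc, Value in,
                             Value out) {
  MLIRContext *ctx = builder.getContext();
  SmallVector<AffineMap, 2> maps(2, AffineMap::getMultiDimIdentityMap(1, ctx));
  builder.create<linalg::IndexedGenericOp>(
      loc, /*inputs=*/ValueRange{in}, /*outputBuffers=*/ValueRange{out},
      /*indexingMaps=*/maps, /*iteratorTypes=*/ArrayRef<StringRef>{"parallel"},
      /*doc=*/"", /*libraryCall=*/"", /*symbolSource=*/IntegerAttr(),
      [](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // args[0]: induction variable, args[1]: input element,
        // args[2]: output element (unused here).
        b.create<linalg::YieldOp>(nestedLoc, args[1]);
      });
}
```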
+ }]; let verifier = [{ return ::verify(*this); }]; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -90,7 +90,7 @@ unsigned getNumOutputs() { ConcreteType concreteOp = cast(this->getOperation()); return concreteOp.output_buffers().size() + - concreteOp.output_tensors().size(); + concreteOp.result_tensors().size(); } static LogicalResult verifyTrait(Operation *op) { ConcreteType concreteOp = cast(op); diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -58,14 +58,6 @@ /// op's iterators. constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; } -/// Attribute name for the IntegerAttr which encodes the number of input buffer -/// arguments. -constexpr StringRef getArgsInAttrName() { return "args_in"; } - -/// Attribute name for the IntegerAttr which encodes the number of input buffer -/// arguments. -constexpr StringRef getArgsOutAttrName() { return "args_out"; } - /// Attribute name for the StringAttr which encodes an optional documentation /// string of the structured op. constexpr StringRef getDocAttrName() { return "doc"; } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -71,7 +71,7 @@ return llvm::None; // Make sure this is reduction with one input and one output. - if (genericOp.args_in() != 1 || genericOp.args_out() != 1) + if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1) return llvm::None; auto originalInputType = op->getOperand(0).getType().cast(); diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -23,37 +23,36 @@ Operation *mlir::edsc::makeGenericLinalgOp( ArrayRef iteratorTypes, ArrayRef inputs, - ArrayRef outputs, + ArrayRef outputBuffers, ArrayRef initTensors, + ArrayRef resultTensorTypes, function_ref regionBuilder, ArrayRef otherValues, ArrayRef otherAttributes) { - for (unsigned i = 0, e = outputs.size(); i + 1 < e; ++i) - assert(!(outputs[i].getType().isa() && - outputs[i + 1].getType().isa()) && - "output tensors must be passed after output buffers"); - auto &builder = edsc::ScopedContext::getBuilderRef(); - auto *ctx = builder.getContext(); - unsigned nInputs = inputs.size(); - unsigned nOutputs = outputs.size(); + OpBuilder &builder = edsc::ScopedContext::getBuilderRef(); + // Build maps SmallVector, 4> exprsList; - exprsList.reserve(nInputs + nOutputs); - for (auto structuredIndexed : inputs) - exprsList.emplace_back(structuredIndexed.getExprs().begin(), - structuredIndexed.getExprs().end()); - for (auto structuredIndexed : outputs) - exprsList.emplace_back(structuredIndexed.getExprs().begin(), - structuredIndexed.getExprs().end()); + exprsList.reserve(inputs.size() + outputBuffers.size() + initTensors.size()); + for (auto container : {inputs, outputBuffers, resultTensorTypes}) + for (const StructuredIndexed &s : container) + exprsList.emplace_back(s.getExprs().begin(), s.getExprs().end()); auto maps = 
AffineMap::inferFromExprList(exprsList); - unsigned nViews = nInputs + nOutputs; - SmallVector values; - values.reserve(nViews); - values.append(inputs.begin(), inputs.end()); - std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(values), - [](StructuredIndexed s) { return s.hasValue(); }); SmallVector types; - std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(types), - [](StructuredIndexed s) { return !s.hasValue(); }); + assert(llvm::all_of(resultTensorTypes, [](const StructuredIndexed &s) { + return !s.hasValue(); + })); + std::copy(resultTensorTypes.begin(), resultTensorTypes.end(), + std::back_inserter(types)); + + SmallVector inputValues, outputBufferValues, initTensorValues; + inputValues.reserve(inputs.size()); + outputBufferValues.reserve(outputBuffers.size()); + initTensorValues.reserve(initTensors.size()); + std::copy(inputs.begin(), inputs.end(), std::back_inserter(inputValues)); + std::copy(outputBuffers.begin(), outputBuffers.end(), + std::back_inserter(outputBufferValues)); + std::copy(initTensors.begin(), initTensors.end(), + std::back_inserter(initTensorValues)); auto iteratorStrTypes = llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString)); @@ -63,9 +62,9 @@ .create( edsc::ScopedContext::getLocation(), types, - values, - IntegerAttr::get(IntegerType::get(64, ctx), nInputs), - IntegerAttr::get(IntegerType::get(64, ctx), nOutputs), + inputValues, + outputBufferValues, + initTensorValues, builder.getAffineMapArrayAttr(maps), builder.getStrArrayAttr(iteratorStrTypes), StringAttr() /*doc*/, @@ -78,11 +77,12 @@ using namespace edsc; SmallVector blockTypes; - blockTypes.reserve(values.size()); - for (auto it : llvm::enumerate(values)) - blockTypes.push_back((it.index() < nViews) - ? getElementTypeOrSelf(it.value()) - : it.value().getType()); + blockTypes.reserve(inputs.size() + outputBuffers.size() + initTensors.size()); + for (auto container : {inputs, outputBuffers}) + for (const StructuredIndexed &s : container) + blockTypes.push_back(getElementTypeOrSelf(s.getType())); + for (Value v : initTensors) + blockTypes.push_back(getElementTypeOrSelf(v.getType())); assert(op->getNumRegions() == 1); assert(op->getRegion(0).empty()); @@ -113,20 +113,17 @@ UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O) { SmallVector iterTypes(O.getExprs().size(), IteratorType::Parallel); - if (O.getType().isa()) { - auto fun = [&unaryOp](ValueRange args) { - assert(args.size() == 1 && "expected 1 block arguments"); - Value a(args[0]); - linalg_yield(unaryOp(a)); - }; - return makeGenericLinalgOp(iterTypes, {I}, {O}, fun); - } auto fun = [&unaryOp](ValueRange args) { - assert(args.size() == 2 && "expected 2 block arguments"); + assert(args.size() >= 1 && "expected >= 1 block arguments"); Value a(args[0]); linalg_yield(unaryOp(a)); }; - return makeGenericLinalgOp(iterTypes, {I}, {O}, fun); + if (O.getType().isa()) + return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputBuffers=*/{}, + /*initTensors=*/{}, /*resultTensorTypes=*/{O}, + fun); + return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputBuffers=*/{O}, + /*initTensors=*/{}, /*resultTensorTypes=*/{}, fun); } Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I, @@ -141,20 +138,18 @@ StructuredIndexed I2, StructuredIndexed O) { SmallVector iterTypes(O.getExprs().size(), IteratorType::Parallel); - if (O.getType().isa()) { - auto fun = [&binaryOp](ValueRange args) { - assert(args.size() == 2 && "expected 2 block arguments"); - Value a(args[0]), 
b(args[1]); - linalg_yield(binaryOp(a, b)); - }; - return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun); - } auto fun = [&binaryOp](ValueRange args) { - assert(args.size() == 3 && "expected 3 block arguments"); + assert(args.size() >= 2 && "expected >= 1 block arguments"); Value a(args[0]), b(args[1]); linalg_yield(binaryOp(a, b)); }; - return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun); + if (O.getType().isa()) + return makeGenericLinalgOp( + iterTypes, /*inputs=*/{I1, I2}, /*outputBuffers=*/{}, + /*initTensors=*/{}, /*resultTensorTypes=*/{O}, fun); + return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2}, + /*outputBuffers=*/{O}, + /*initTensors=*/{}, /*resultTensorTypes=*/{}, fun); } Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1, @@ -185,23 +180,10 @@ StructuredIndexed A(vA), B(vB), C(vC); return makeGenericLinalgOp( {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction}, - {A({m, k}), B({k, n})}, - {C({m, n})}, - regionBuilder); - // clang-format on -} - -Operation * -mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC, - MatmulRegionBuilder regionBuilder) { - // clang-format off - AffineExpr m, n, k; - bindDims(ScopedContext::getContext(), m, n, k); - StructuredIndexed A(vA), B(vB), C(tC); - return makeGenericLinalgOp( - {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction}, - {A({m, k}), B({k, n})}, - {C({m, n})}, + /*inputs=*/{A({m, k}), B({k, n})}, + /*outputBuffers=*/{C({m, n})}, + /*initTensors=*/{}, + /*resultTensorTypes=*/{}, regionBuilder); // clang-format on } @@ -216,8 +198,10 @@ StructuredIndexed A(vA), B(vB), C(vC), D(tD); return makeGenericLinalgOp( {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction}, - {A({m, k}), B({k, n}), C({m, n})}, - {D({m, n})}, + /*inputs=*/{A({m, k}), B({k, n})}, + /*outputBuffers=*/{}, + /*initTensors=*/{C({m, n})}, + /*resultTensorTypes=*/{D({m, n})}, regionBuilder); // clang-format on } @@ -243,15 +227,18 @@ StructuredIndexed I(vI), W(vW), O(vO); // clang-format off return makeGenericLinalgOp( - {par, par, par, par, red, red, red}, { + {par, par, par, par, red, red, red}, + /*inputs=*/{ I({b, // Roundtrip to flattened form to serve as canonicalization and ensure // consistent ordering of subexpressions. simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0), simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0), c}), - W({kh, kw, c, f})}, { - O({b, h, w, f})}, + W({kh, kw, c, f}) }, + /*outputBuffers=*/{ O({b, h, w, f}) }, + /*initTensors=*/{}, + /*resultTensorTypes=*/{}, macRegionBuilder); // clang-format on } @@ -276,15 +263,19 @@ unsigned numDims = kw.cast().getPosition() + 1; StructuredIndexed I(vI), W(vW), O(vO); return makeGenericLinalgOp( - {par, par, par, par, par, red, red}, { + {par, par, par, par, par, red, red}, + /*inputs=*/{ I({b, // Roundtrip to flattened form to serve as canonicalization and ensure // consistent ordering of subexpressions. 
simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0), simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0), c}), - W({kh, kw, c, dm})}, { + W({kh, kw, c, dm})}, + /*outputBuffers=*/{ O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})}, + /*initTensors=*/{}, + /*resultTensorTypes=*/{}, macRegionBuilder); // clang-format on } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -40,6 +40,12 @@ TypeRange outputBufferTypes, TypeRange initTensorTypes, TypeRange resultTypes); +static ParseResult +parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, + SmallVectorImpl &inputsTypes, + SmallVectorImpl &outputBuffersTypes, + SmallVectorImpl &initTensorsTypes); + template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, @@ -53,6 +59,10 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result); +template +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op); + static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes); @@ -87,24 +97,52 @@ //===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// +void GenericOp::build( + OpBuilder &builder, OperationState &result, ValueRange inputs, + ValueRange outputBuffers, ArrayRef indexingMaps, + ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, + IntegerAttr symbolSource, + function_ref bodyBuild) { + build(builder, result, ArrayRef{}, inputs, outputBuffers, ValueRange{}, + builder.getAffineMapArrayAttr(indexingMaps), + builder.getStrArrayAttr(iteratorTypes), + doc.empty() ? StringAttr() : builder.getStringAttr(doc), + libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall), + symbolSource); + if (!bodyBuild) + return; + + SmallVector blockArgTypes; + for (ValueRange container : {inputs, outputBuffers}) + for (Value v : container) + blockArgTypes.push_back(v.getType().cast().getElementType()); + + OpBuilder::InsertionGuard guard(builder); + auto ®ion = *result.regions.front(); + Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes); + bodyBuild(builder, result.location, bodyBlock->getArguments()); +} void GenericOp::build( - OpBuilder &builder, OperationState &result, ArrayRef resultTypes, - ValueRange args, int64_t argsIn, int64_t argsOut, + OpBuilder &builder, OperationState &result, + ArrayRef resultTensorTypes, ValueRange inputs, + ValueRange outputBuffers, ValueRange initTensors, ArrayRef indexingMaps, ArrayRef iteratorTypes, + StringRef doc, StringRef libraryCall, IntegerAttr symbolSource, function_ref bodyBuild) { - build(builder, result, resultTypes, args, builder.getI64IntegerAttr(argsIn), - builder.getI64IntegerAttr(argsOut), + build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors, builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), - /*doc=*/nullptr, /*library_call=*/nullptr, - /*symbol_source=*/nullptr); + doc.empty() ? StringAttr() : builder.getStringAttr(doc), + libraryCall.empty() ? 
StringAttr() : builder.getStringAttr(libraryCall), + symbolSource); if (!bodyBuild) return; SmallVector blockArgTypes; - for (Value arg : args) - blockArgTypes.push_back(arg.getType().cast().getElementType()); + for (ValueRange container : {inputs, outputBuffers, initTensors}) + for (Value v : container) + blockArgTypes.push_back(v.getType().cast().getElementType()); OpBuilder::InsertionGuard guard(builder); auto ®ion = *result.regions.front(); @@ -113,53 +151,99 @@ } void IndexedGenericOp::build( - OpBuilder &builder, OperationState &result, ArrayRef resultTypes, - ValueRange args, int64_t argsIn, int64_t argsOut, + OpBuilder &builder, OperationState &result, ValueRange inputs, + ValueRange outputBuffers, ArrayRef indexingMaps, + ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, + IntegerAttr symbolSource, + function_ref bodyBuild) { + build(builder, result, ArrayRef{}, inputs, outputBuffers, ValueRange{}, + builder.getAffineMapArrayAttr(indexingMaps), + builder.getStrArrayAttr(iteratorTypes), + doc.empty() ? StringAttr() : builder.getStringAttr(doc), + libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall), + symbolSource); + if (!bodyBuild) + return; + + unsigned nLoops = iteratorTypes.size(); + SmallVector blockArgTypes(nLoops, builder.getIndexType()); + for (ValueRange container : {inputs, outputBuffers}) + for (Value v : container) + blockArgTypes.push_back(v.getType().cast().getElementType()); + + OpBuilder::InsertionGuard guard(builder); + auto ®ion = *result.regions.front(); + Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes); + bodyBuild(builder, result.location, bodyBlock->getArguments()); +} + +void IndexedGenericOp::build( + OpBuilder &builder, OperationState &result, + ArrayRef resultTensorTypes, ValueRange inputs, + ValueRange outputBuffers, ValueRange initTensors, ArrayRef indexingMaps, ArrayRef iteratorTypes, - function_ref - bodyBuild) { - build(builder, result, resultTypes, args, builder.getI64IntegerAttr(argsIn), - builder.getI64IntegerAttr(argsOut), + StringRef doc, StringRef libraryCall, IntegerAttr symbolSource, + function_ref bodyBuild) { + build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors, builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), - /*doc=*/nullptr, /*library_call=*/nullptr, - /*symbol_source=*/nullptr); + doc.empty() ? StringAttr() : builder.getStringAttr(doc), + libraryCall.empty() ? 
StringAttr() : builder.getStringAttr(libraryCall), + symbolSource); if (!bodyBuild) return; unsigned nLoops = iteratorTypes.size(); SmallVector blockArgTypes(nLoops, builder.getIndexType()); - for (Value arg : args) - blockArgTypes.push_back(arg.getType().cast().getElementType()); + for (ValueRange container : {inputs, outputBuffers, initTensors}) + for (Value v : container) + blockArgTypes.push_back(v.getType().cast().getElementType()); OpBuilder::InsertionGuard guard(builder); auto ®ion = *result.regions.front(); Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes); - bodyBuild(builder, result.location, - bodyBlock->getArguments().take_front(nLoops), - bodyBlock->getArguments().drop_front(nLoops)); + bodyBuild(builder, result.location, bodyBlock->getArguments()); } template static void printGenericOp(OpAsmPrinter &p, GenericOpType op) { - auto attrNames = op.linalgTraitAttrNames(); - llvm::StringSet<> linalgTraitAttrsSet; - linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end()); - SmallVector attrs; + p << op.getOperationName() << " "; + + // Print extra attributes. + auto genericAttrNames = op.linalgTraitAttrNames(); + + llvm::StringSet<> genericAttrNamesSet; + genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); + SmallVector genericAttrs; for (auto attr : op.getAttrs()) - if (linalgTraitAttrsSet.count(attr.first.strref()) > 0) - attrs.push_back(attr); + if (genericAttrNamesSet.count(attr.first.strref()) > 0) + genericAttrs.push_back(attr); + if (!genericAttrs.empty()) { + auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext()); + p << genericDictAttr; + } + + // Printing is shared with named ops, except for the region and attributes + printCommonStructuredOpParts(p, op); - auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); - p << op.getOperationName() << " " << dictAttr; - p.printOptionalAttrDict(op.getAttrs(), attrNames); - p << " " << op.getOperands(); + genericAttrNames.push_back("operand_segment_sizes"); + + bool hasExtraAttrs = false; + for (NamedAttribute n : op.getAttrs()) { + if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.first.strref()))) + break; + } + if (hasExtraAttrs) { + p << " attrs = "; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/genericAttrNames); + } + + // Print region. if (!op.region().empty()) p.printRegion(op.region()); - p << ": " << op.getOperandTypes(); - auto outputTensorTypes = op.getResultTypes(); - if (!outputTensorTypes.empty()) - p << " -> " << outputTensorTypes; + + // Print results. + printNamedStructuredOpResults(p, op.result_tensors().getTypes()); } static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); } @@ -169,7 +253,6 @@ } static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) { - SmallVector operandsInfo, regionOperandsInfo; DictionaryAttr dictAttr; // Parse the core linalg traits that must check into a dictAttr. // The name is unimportant as we will overwrite result.attributes. @@ -180,26 +263,35 @@ result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); - // Optional attributes may be added. - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseOperandList(operandsInfo)) + // Parsing is shared with named ops, except for the region. 
+ SmallVector inputsTypes, outputBuffersTypes, initTensorsTypes; + if (parseCommonStructuredOpParts(parser, result, inputsTypes, + outputBuffersTypes, initTensorsTypes)) return failure(); - Region ®ion = *result.addRegion(); + // Optional attributes may be added. + if (succeeded(parser.parseOptionalKeyword("attrs"))) + if (failed(parser.parseEqual()) || + failed(parser.parseOptionalAttrDict(result.attributes))) + return failure(); + + SmallVector regionOperands; + std::unique_ptr region = std::make_unique(); SmallVector operandTypes, regionTypes; - if (parser.parseRegion(region, regionOperandsInfo, regionTypes)) - return failure(); - if (parser.parseColonTypeList(operandTypes)) + if (parser.parseRegion(*region, regionOperands, regionTypes)) return failure(); + result.addRegion(std::move(region)); + // Generic ops may specify that a subset of its outputs are tensors. Such // outputs are specified in the result type. - SmallVector tensorResultTypes; - if (parser.parseOptionalArrowTypeList(tensorResultTypes)) + // TODO: may need to move output parsing before region parsing. + // Need to wait for declarative assembly resolution to decide. + SmallVector outputTensorsTypes; + if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) return failure(); - if (!tensorResultTypes.empty()) - result.addTypes(tensorResultTypes); - return parser.resolveOperands(operandsInfo, operandTypes, - parser.getCurrentLocation(), result.operands); + result.addTypes(outputTensorsTypes); + + return success(); } namespace { @@ -265,6 +357,11 @@ auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); + if (op.inputs().size() + op.output_buffers().size() + + op.init_tensors().size() + op.getNumResults() == + 0) + return op.emitOpError("expected at least 1 Shaped operand or return"); + auto ®ion = op.region(); if (!llvm::hasSingleElement(region)) return op.emitOpError("expected region with 1 block"); @@ -313,27 +410,9 @@ return success(); } -static LogicalResult verify(GenericOp op) { - // Temporarily hoisted here to avoid duplicating more code. - // TODO: uniformize with named structured ops. - auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); - if (nInputsAndOutputBuffers != llvm::size(op.views())) - return op.emitOpError("expected exactly ") - << nInputsAndOutputBuffers - << " inputs (tensor or buffer) and output buffer operands"; - return verifyGenericOp(op); -} +static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); } -static LogicalResult verify(IndexedGenericOp op) { - // Temporarily hoisted here to avoid duplicating more code. - // TODO: uniformize with named structured ops. - auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); - if (nInputsAndOutputBuffers != llvm::size(op.views())) - return op.emitOpError("expected exactly ") - << nInputsAndOutputBuffers - << " inputs (tensor or buffer) and output buffer operands"; - return verifyGenericOp(op); -} +static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); } //===----------------------------------------------------------------------===// // ReshapeOp @@ -1140,6 +1219,9 @@ /// Assumes `op` is a LinalgOp. 
void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName, SmallVectorImpl &res) { + if (!cast(op).iterator_types()) + return; + unsigned dim = 0; MLIRContext *ctx = op->getContext(); for (auto tn : @@ -1340,25 +1422,27 @@ return success(); } -template -static ParseResult parseNamedStructuredOp(OpAsmParser &parser, - OperationState &result) { +static ParseResult +parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, + SmallVectorImpl &inputsTypes, + SmallVectorImpl &outputBuffersTypes, + SmallVectorImpl &initTensorsTypes) { llvm::SMLoc inputsOperandsLoc, outputBuffersOperandsLoc, initTensorsOperandsLoc; SmallVector inputsOperands, outputBuffersOperands, initTensorsOperands; - SmallVector inputsTypes, outputBuffersTypes, initTensorsTypes, - outputTensorsTypes; - std::unique_ptr regionRegion = std::make_unique(); - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseKeyword("ins") || parser.parseLParen()) - return failure(); + parser.parseOptionalAttrDict(result.attributes); - inputsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(inputsOperands) || parser.parseColon() || - parser.parseTypeList(inputsTypes) || parser.parseRParen()) - return failure(); + if (succeeded(parser.parseOptionalKeyword("ins"))) { + if (parser.parseLParen()) + return failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands) || parser.parseColon() || + parser.parseTypeList(inputsTypes) || parser.parseRParen()) + return failure(); + } if (succeeded(parser.parseOptionalKeyword("outs"))) { outputBuffersOperandsLoc = parser.getCurrentLocation(); @@ -1375,14 +1459,6 @@ return failure(); } - if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) - return failure(); - - if (parseNamedStructuredOpRegion( - parser, *regionRegion, inputsTypes, outputBuffersTypes, - initTensorsTypes, outputTensorsTypes)) - return failure(); - if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc, result.operands) || parser.resolveOperands(outputBuffersOperands, outputBuffersTypes, @@ -1391,8 +1467,6 @@ initTensorsOperandsLoc, result.operands)) return failure(); - result.addTypes(outputTensorsTypes); - result.addRegion(std::move(regionRegion)); result.addAttribute("operand_segment_sizes", parser.getBuilder().getI32VectorAttr( {static_cast(inputsOperands.size()), @@ -1401,6 +1475,31 @@ return success(); } +template +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result) { + SmallVector inputsTypes, outputBuffersTypes, initTensorsTypes; + if (parseCommonStructuredOpParts(parser, result, inputsTypes, + outputBuffersTypes, initTensorsTypes)) + return failure(); + + // TODO: consider merging results parsing into region parsing. + // Need to wait for declarative assembly resolution to decide. 
+ SmallVector outputTensorsTypes; + if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) + return failure(); + result.addTypes(outputTensorsTypes); + + std::unique_ptr region = std::make_unique(); + if (parseNamedStructuredOpRegion( + parser, *region, inputsTypes, outputBuffersTypes, initTensorsTypes, + outputTensorsTypes)) + return failure(); + result.addRegion(std::move(region)); + + return success(); +} + static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes) { if (resultTypes.empty()) @@ -1409,20 +1508,28 @@ } template -static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { - p << op.getOperationName(); - p.printOptionalAttrDict(op.getAttrs(), - /*elidedAttrs=*/{"operand_segment_sizes"}); +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op) { p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; if (!op.output_buffers().empty()) p << " outs(" << op.output_buffers() << " : " << op.output_buffers().getTypes() << ")"; if (!op.init_tensors().empty()) p << " init(" << op.init_tensors() << " : " << op.init_tensors().getTypes() - << ")"; - p << " "; - printNamedStructuredOpResults(p, op.output_tensors().getTypes()); - p << " "; + << ") "; +} + +template +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { + p << op.getOperationName(); + p.printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{"operand_segment_sizes"}); + + // Printing is shared with generic ops, except for the region and attributes. + printCommonStructuredOpParts(p, op); + + // Results printing. + printNamedStructuredOpResults(p, op.result_tensors().getTypes()); // Region is elided. } diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -261,12 +261,14 @@ } namespace { + /// Pattern to replace tensors operands/results that are unit extents. struct ReplaceUnitExtentTensors : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - if (!genericOp.hasTensorSemantics()) + // TODO: support init_tensors and reductions. + if (!genericOp.hasTensorSemantics() || !genericOp.init_tensors().empty()) return failure(); MLIRContext *context = rewriter.getContext(); @@ -295,17 +297,28 @@ // If any operand type change, insert a reshape to convert from the original // type to the new type. - SmallVector newOperands; - newOperands.reserve(genericOp.getNumOperands()); - for (auto operand : llvm::enumerate(genericOp.getOperands())) { - if (operand.value().getType() == newInputOutputTypes[operand.index()]) { - newOperands.push_back(operand.value()); - } else { - newOperands.push_back(rewriter.create( - loc, newInputOutputTypes[operand.index()], operand.value(), - reassociationMaps[operand.index()])); + // TODO: get rid of flattenedIdx which assumes operand order and contiguity. 
+ unsigned flattenedIdx = 0; + auto insertReshapes = [&](ValueRange values) { + SmallVector res; + res.reserve(values.size()); + for (auto operand : llvm::enumerate(values)) { + if (operand.value().getType() == newInputOutputTypes[flattenedIdx]) + res.push_back(operand.value()); + else + res.push_back(rewriter.create( + loc, newInputOutputTypes[flattenedIdx], operand.value(), + reassociationMaps[flattenedIdx])); + ++flattenedIdx; } - } + return res; + }; + + SmallVector newInputs = insertReshapes(genericOp.inputs()); + SmallVector newOutputBuffers = + insertReshapes(genericOp.output_buffers()); + SmallVector newInitTensors = + insertReshapes(genericOp.init_tensors()); // If any result type change, insert a reshape to convert from the original // type to the new type. @@ -315,8 +328,8 @@ resultTypes.push_back( newInputOutputTypes[i + genericOp.getNumOperands()]); GenericOp replacementOp = rewriter.create( - loc, resultTypes, newOperands, genericOp.args_in(), - genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps), + loc, resultTypes, newInputs, newOutputBuffers, newInitTensors, + rewriter.getAffineMapArrayAttr(newIndexingMaps), genericOp.iterator_types(), /*doc = */ nullptr, /*library_call = */ nullptr, @@ -332,12 +345,11 @@ RankedTensorType origResultType = genericOp.getResult(result.index()) .getType() .cast(); - if (origResultType != result.value().getType()) { + if (origResultType != result.value().getType()) resultReplacements.push_back(rewriter.create( loc, origResultType, result.value(), reassociationMaps[index])); - } else { + else resultReplacements.push_back(result.value()); - } } rewriter.replaceOp(genericOp, resultReplacements); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -443,6 +443,10 @@ if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) return false; + // TODO: maybe allow init_tensors and reductions. + // if (producer.init_tensors().empty() || consumer.init_tensors().empty()) + // return false; + // Verify that // - the producer has all "parallel" iterator type. if (producer.getNumParallelLoops() != producer.getNumLoops()) @@ -499,38 +503,37 @@ consumerIndexMaps.end()); // Generate the fused op. + // Tensor-level fusion is only on ops without initTensors and outputBuffers. 
LinalgOp fusedOp; if (isa(producer.getOperation()) && isa(consumer.getOperation())) { fusedOp = rewriter - .create( - rewriter.getUnknownLoc(), - consumer.getOperation()->getResultTypes(), fusedOperands, - rewriter.getI64IntegerAttr(fusedOperands.size()), - rewriter.getI64IntegerAttr( - consumer.getOperation()->getNumResults()), - rewriter.getArrayAttr(fusedIndexMaps), - consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr, - /*symbol_source=*/nullptr) + .create(rewriter.getUnknownLoc(), + consumer.getOperation()->getResultTypes(), + /*inputs=*/fusedOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, + rewriter.getArrayAttr(fusedIndexMaps), + consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr, + /*symbol_source=*/nullptr) .getOperation(); } else { - fusedOp = - rewriter - .create( - rewriter.getUnknownLoc(), - consumer.getOperation()->getResultTypes(), fusedOperands, - rewriter.getI64IntegerAttr(fusedOperands.size()), - rewriter.getI64IntegerAttr( - consumer.getOperation()->getNumResults()), - rewriter.getArrayAttr(fusedIndexMaps), - consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr, - /*symbol_source=*/nullptr) - .getOperation(); + fusedOp = rewriter + .create( + rewriter.getUnknownLoc(), + consumer.getOperation()->getResultTypes(), + /*inputs=*/fusedOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, + rewriter.getArrayAttr(fusedIndexMaps), + consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr, + /*symbol_source=*/nullptr) + .getOperation(); } // Construct an AffineMap from consumer loops to producer loops. @@ -812,9 +815,10 @@ })); LinalgOp fusedOp = createLinalgOpOfSameType( consumer, rewriter, rewriter.getUnknownLoc(), - consumerOp->getResultTypes(), fusedOperands, - rewriter.getI64IntegerAttr(fusedOperands.size()), - rewriter.getI64IntegerAttr(consumerOp->getNumResults()), + consumerOp->getResultTypes(), + /*inputs=*/fusedOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, // no init tensors for now. rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(), /*doc=*/nullptr, /*library_call=*/nullptr, @@ -871,10 +875,10 @@ Operation *producerOp = producer.getOperation(); LinalgOp fusedOp = createLinalgOpOfSameType( producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(), - producerOp->getOperands(), - rewriter.getI64IntegerAttr(producerOp->getNumOperands()), - rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs), - producer.iterator_types(), + /*inputs=*/producerOp->getOperands(), + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, // no init tensors for now. 
+ rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(), /*doc=*/nullptr, /*library_call=*/nullptr, /*symbol_source=*/nullptr); @@ -932,10 +936,10 @@ } int rank = dstShape.size(); - int numArgsIn = producer.getNumInputs(); - int numArgsOut = producer.getNumOutputs(); auto genericOp = rewriter.create( - loc, resultTypes, args, numArgsIn, numArgsOut, + loc, resultTypes, /*inputs=*/args, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, SmallVector(args.size() + resultTypes.size(), rewriter.getMultiDimIdentityMap(rank)), SmallVector(rank, getParallelIteratorTypeName())); @@ -995,9 +999,10 @@ LinalgOp fusedOp = createLinalgOpOfSameType( consumer, rewriter, rewriter.getUnknownLoc(), - consumerOp->getResultTypes(), fusedOperands, - rewriter.getI64IntegerAttr(consumerOp->getNumOperands() - 1), - rewriter.getI64IntegerAttr(consumerOp->getNumResults()), + consumerOp->getResultTypes(), + /*inputs=*/fusedOperands, + /*outputBuffers=*/ValueRange{}, + /*initTensors=*/ValueRange{}, // no init tensors for now. rewriter.getAffineMapArrayAttr(fusedIndexMaps), consumer.iterator_types(), /*doc=*/nullptr, diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -36,32 +36,45 @@ LogicalResult matchAndRewrite(linalg::GenericOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { + linalg::GenericOpAdaptor adaptor(operands, + op.getOperation()->getAttrDictionary()); + + // TODO: support ops with reduction. + if (!op.init_tensors().empty()) + return failure(); + + // All inputs need to be turned into buffers first. Until then, bail out. + if (llvm::any_of(adaptor.inputs(), + [](Value in) { return !in.getType().isa(); })) + return failure(); + Location loc = op.getLoc(); - ResultRange results = op.getOperation()->getResults(); - SmallVector newArgs, newResults; - newArgs.reserve(operands.size() + results.size()); - newArgs.append(operands.begin(), operands.end()); - newResults.reserve(results.size()); + SmallVector outputBuffers, newOutputBuffers; + outputBuffers.assign(adaptor.output_buffers().begin(), + adaptor.output_buffers().end()); + newOutputBuffers.reserve(op.getNumOutputs()); + newOutputBuffers.append(adaptor.output_buffers().begin(), + adaptor.output_buffers().end()); // Update all types to memref types. - for (auto result : results) { - auto type = result.getType().cast(); - assert(type && "tensor to buffer conversion expects ranked results"); + for (Type t : op.getResultTypes()) { + auto type = t.cast(); if (!type.hasStaticShape()) return rewriter.notifyMatchFailure( op, "dynamic shapes not currently supported"); auto memrefType = MemRefType::get(type.getShape(), type.getElementType()); auto alloc = rewriter.create(loc, memrefType); - newArgs.push_back(alloc); - newResults.push_back(alloc); + newOutputBuffers.push_back(alloc); } // Generate a new linalg operation that works on buffers. 
auto linalgOp = rewriter.create( - loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()), - rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(), - op.iterator_types(), op.docAttr(), op.library_callAttr(), - op.symbol_sourceAttr()); + loc, + /*resultTensorTypes=*/ArrayRef{}, + /*inputs=*/adaptor.inputs(), + /*outputBuffers=*/newOutputBuffers, + /*initTensors=*/ValueRange{}, op.indexing_maps(), op.iterator_types(), + op.docAttr(), op.library_callAttr(), op.symbol_sourceAttr()); // Create a new block in the region of the new Generic Op. Block &oldBlock = op.getRegion().front(); @@ -70,23 +83,23 @@ oldBlock.getArgumentTypes()); // Add the result arguments to the new block. - for (auto result : newResults) - newBlock->addArgument( - result.getType().cast().getElementType()); + for (Value v : newOutputBuffers) + newBlock->addArgument(v.getType().cast().getElementType()); // Clone the body of the old block to the new block. BlockAndValueMapping mapping; for (unsigned i = 0; i < oldBlock.getNumArguments(); i++) mapping.map(oldBlock.getArgument(i), newBlock->getArgument(i)); + + OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToEnd(newBlock); for (auto &op : oldBlock.getOperations()) { Operation *clonedOp = rewriter.clone(op, mapping); mapping.map(op.getResults(), clonedOp->getResults()); } - // Replace the results of the old Generic Op with the results of the new - // one. - rewriter.replaceOp(op, newResults); + // Replace the results of the old op with the new output buffers. + rewriter.replaceOp(op, newOutputBuffers); return success(); } }; diff --git a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir --- a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir +++ b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir @@ -5,8 +5,6 @@ //===----------------------------------------------------------------------===// #single_workgroup_reduction_trait = { - args_in = 1, - args_out = 1, iterator_types = ["reduction"], indexing_maps = [ affine_map<(i) -> (i)>, @@ -49,11 +47,13 @@ func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes { spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>} } { - linalg.generic #single_workgroup_reduction_trait %input, %output { + linalg.generic #single_workgroup_reduction_trait + ins(%input : memref<16xi32>) + outs(%output : memref<1xi32>) { ^bb(%in: i32, %out: i32): %sum = addi %in, %out : i32 linalg.yield %sum : i32 - } : memref<16xi32>, memref<1xi32> + } spv.Return } } @@ -63,8 +63,6 @@ // Missing shader entry point ABI #single_workgroup_reduction_trait = { - args_in = 1, - args_out = 1, iterator_types = ["reduction"], indexing_maps = [ affine_map<(i) -> (i)>, @@ -78,11 +76,13 @@ } { func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) { // expected-error @+1 {{failed to legalize operation 'linalg.generic'}} - linalg.generic #single_workgroup_reduction_trait %input, %output { + linalg.generic #single_workgroup_reduction_trait + ins(%input : memref<16xi32>) + outs(%output : memref<1xi32>) { ^bb(%in: i32, %out: i32): %sum = addi %in, %out : i32 linalg.yield %sum : i32 - } : memref<16xi32>, memref<1xi32> + } return } } @@ -92,8 +92,6 @@ // Mismatch between shader entry point ABI and input memref shape #single_workgroup_reduction_trait = { - args_in = 1, - args_out = 1, iterator_types = ["reduction"], indexing_maps = [ affine_map<(i) -> (i)>, @@ -109,11 +107,13 @@ 
spv.entry_point_abi = {local_size = dense<[32, 1, 1]>: vector<3xi32>} } { // expected-error @+1 {{failed to legalize operation 'linalg.generic'}} - linalg.generic #single_workgroup_reduction_trait %input, %output { + linalg.generic #single_workgroup_reduction_trait + ins(%input : memref<16xi32>) + outs(%output : memref<1xi32>) { ^bb(%in: i32, %out: i32): %sum = addi %in, %out : i32 linalg.yield %sum : i32 - } : memref<16xi32>, memref<1xi32> + } spv.Return } } @@ -123,8 +123,6 @@ // Unsupported multi-dimension input memref #single_workgroup_reduction_trait = { - args_in = 1, - args_out = 1, iterator_types = ["parallel", "reduction"], indexing_maps = [ affine_map<(i, j) -> (i, j)>, @@ -140,11 +138,13 @@ spv.entry_point_abi = {local_size = dense<[16, 8, 1]>: vector<3xi32>} } { // expected-error @+1 {{failed to legalize operation 'linalg.generic'}} - linalg.generic #single_workgroup_reduction_trait %input, %output { + linalg.generic #single_workgroup_reduction_trait + ins(%input : memref<16x8xi32>) + outs(%output : memref<16xi32>) { ^bb(%in: i32, %out: i32): %sum = addi %in, %out : i32 linalg.yield %sum : i32 - } : memref<16x8xi32>, memref<16xi32> + } spv.Return } } diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -182,8 +182,6 @@ ] #trait = { - args_in = 1, - args_out = 1, indexing_maps = #accesses, iterator_types = ["parallel"] } @@ -193,10 +191,10 @@ linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32> // tensor<0xf32> cannot be dce'ed - %1 = linalg.generic #trait %arg1 { + %1 = linalg.generic #trait ins(%arg1 : tensor<0xf32>) { ^bb(%0: f32) : linalg.yield %0 : f32 - } : tensor<0xf32> -> tensor<0xf32> + } -> tensor<0xf32> return %1: tensor<0xf32> } diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -6,8 +6,6 @@ ] #trait = { - args_in = 1, - args_out = 1, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"], indexing_maps = #accesses, library_call = "some_external_func" @@ -15,10 +13,11 @@ func @drop_one_trip_loops(%arg0 : tensor) -> tensor { - %0 = linalg.generic #trait %arg0 { + %0 = linalg.generic #trait + ins(%arg0 : tensor) { ^bb0(%arg1 : f32) : linalg.yield %arg1 : f32 - } : tensor -> tensor + } -> tensor return %0 : tensor } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -40,8 +39,6 @@ #map0 = affine_map<(i, j) -> (i, j)> #access = [#map0, #map0] #trait = { - args_in = 1, - args_out = 1, iterator_types = ["parallel", "parallel"], indexing_maps = #access, library_call = "some_external_func" @@ -49,10 +46,11 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32> { - %0 = linalg.generic #trait %arg0 { + %0 = linalg.generic #trait + ins(%arg0 : tensor<1x1xf32>) { ^bb0(%arg1: f32) : linalg.yield %arg1 : f32 - } : tensor<1x1xf32> -> tensor<1x1xf32> + } -> tensor<1x1xf32> return %0 : tensor<1x1xf32> } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> ()> @@ -70,18 +68,17 @@ ] #trait = { - args_in = 1, - args_out = 1, indexing_maps = #accesses, iterator_types = ["parallel"], library_call = "some_external_fn" } func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> { - %0 = linalg.generic #trait %arg0 { + %0 = linalg.generic #trait + ins(%arg0 : tensor<1x5xf32>) { ^bb0(%arg2: 
f32): // no predecessors linalg.yield %arg2 : f32 - } : tensor<1x5xf32> -> tensor<5xf32> + } -> tensor<5xf32> return %0 : tensor<5xf32> } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> @@ -100,8 +97,6 @@ ] #trait = { - args_in = 2, - args_out = 1, indexing_maps = #accesses, iterator_types = ["parallel", "parallel"], library_call = "some_external_fn" @@ -113,11 +108,12 @@ tensor<5xf32> into tensor<1x5xf32> %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<5xf32> into tensor<5x1xf32> - %2 = linalg.generic #trait %0, %1 { + %2 = linalg.generic #trait + ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>) { ^bb0(%arg2: f32, %arg3: f32): %3 = addf %arg2, %arg3 : f32 linalg.yield %3 : f32 - } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32> + } -> tensor<5x5xf32> return %2 : tensor<5x5xf32> } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)> @@ -138,8 +134,6 @@ ] #trait = { - args_in = 1, - args_out = 1, indexing_maps = #accesses, iterator_types = ["parallel", "parallel"], library_call = "some_external_fn" @@ -147,10 +141,11 @@ func @broadcast_scalar(%arg0 : tensor<1x1xf32>) -> tensor { - %0 = linalg.generic #trait %arg0 { - ^bb0(%arg1 : f32): - linalg.yield %arg1 : f32 - } : tensor<1x1xf32> -> tensor + %0 = linalg.generic #trait + ins(%arg0 : tensor<1x1xf32>) { + ^bb0(%arg1 : f32): + linalg.yield %arg1 : f32 + } -> tensor return %0 : tensor } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()> diff --git a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir --- a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir +++ b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir @@ -6,8 +6,6 @@ ] #trait = { - args_in = 1, - args_out = 1, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"], indexing_maps = #accesses, library_call = "some_external_func" @@ -15,10 +13,11 @@ func @drop_one_trip_loops(%arg0 : tensor) -> tensor { - %0 = linalg.generic #trait %arg0 { + %0 = linalg.generic #trait + ins(%arg0 : tensor) { ^bb0(%arg1 : f32) : linalg.yield %arg1 : f32 - } : tensor -> tensor + } -> tensor return %0 : tensor } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)> @@ -33,8 +32,6 @@ #map0 = affine_map<(i, j) -> (i, j)> #access = [#map0, #map0] #trait = { - args_in = 1, - args_out = 1, iterator_types = ["parallel", "parallel"], indexing_maps = #access, library_call = "some_external_func" @@ -42,10 +39,11 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32> { - %0 = linalg.generic #trait %arg0 { + %0 = linalg.generic #trait + ins(%arg0 : tensor<1x1xf32>) { ^bb0(%arg1: f32) : linalg.yield %arg1 : f32 - } : tensor<1x1xf32> -> tensor<1x1xf32> + } -> tensor<1x1xf32> return %0 : tensor<1x1xf32> } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)> @@ -59,8 +57,6 @@ #map0 = affine_map<(i, j) -> (i, j)> #access = [#map0, #map0] #trait = { - args_in = 1, - args_out = 1, iterator_types = ["parallel", "parallel"], indexing_maps = #access, library_call = "some_external_func" @@ -68,10 +64,12 @@ func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>) { - linalg.generic #trait %arg0, %arg1 { + linalg.generic #trait + ins(%arg0 : memref<1x1xf32>) + outs(%arg1 : memref<1x1xf32>) { ^bb0(%arg2: f32, %arg3 : f32) : linalg.yield %arg2 : f32 - } : memref<1x1xf32>, memref<1x1xf32> + } return } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)> @@ -88,18 +86,17 @@ ] #trait = { - args_in = 1, - args_out = 1, indexing_maps = #accesses, 
iterator_types = ["parallel", "parallel"], library_call = "some_external_fn" } func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> { - %0 = linalg.generic #trait %arg0 { - ^bb0(%arg2: f32): // no predecessors - linalg.yield %arg2 : f32 - } : tensor<1x5xf32> -> tensor<5xf32> + %0 = linalg.generic #trait + ins(%arg0 : tensor<1x5xf32>) { + ^bb0(%arg2: f32): // no predecessors + linalg.yield %arg2 : f32 + } -> tensor<5xf32> return %0 : tensor<5xf32> } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (0, d0)> diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -6,14 +6,16 @@ // CHECK-LABEL: @add_mul_fusion func @add_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 { + %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1 : tensor, tensor) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 - }: tensor, tensor -> tensor - // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64 + } -> tensor + // CHECK: linalg.generic { // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}} - %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 { + %2 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} + ins(%0, %arg2 : tensor, tensor) { // CHECK: ^{{[a-zA-Z0-9_]*}} // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]] // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]] @@ -25,7 +27,7 @@ // CHECK: linalg.yield %3 = mulf %arg5, %arg6 : f32 linalg.yield %3 : f32 - }: tensor, tensor -> tensor + } -> tensor return %2 : tensor } @@ -39,18 +41,20 @@ // CHECK-LABEL: @transpose_add_mul_fusion func @transpose_add_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 { + %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1 : tensor, tensor) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 - }: tensor, tensor -> tensor - // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64 + } -> tensor + // CHECK: linalg.generic { // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP1]], [[$MAP0]], [[$MAP0]]{{\]}} - %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 { + %2 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} + ins(%0, %arg2 : tensor, tensor) { ^bb0(%arg5: f32, %arg6: f32): // no predecessors %3 = mulf %arg5, %arg6 : f32 linalg.yield %3 : f32 - }: tensor, tensor -> tensor + } -> tensor return %2 : tensor } @@ -64,18 +68,20 @@ // CHECK-LABEL: @add_transpose_mul_fusion func @add_transpose_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map1, #map0], 
iterator_types = ["parallel", "parallel"]} %arg0, %arg1 { + %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1 : tensor, tensor) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 - }: tensor, tensor -> tensor - // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64 + } -> tensor + // CHECK: linalg.generic { // CHECK-SAME: indexing_maps = {{\[}}[[$MAP1]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}} - %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 { + %2 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} + ins(%0, %arg2 : tensor, tensor) { ^bb0(%arg5: f32, %arg6: f32): // no predecessors %3 = mulf %arg5, %arg6 : f32 linalg.yield %3 : f32 - }: tensor, tensor -> tensor + } -> tensor return %2 : tensor } @@ -90,18 +96,20 @@ // CHECK-LABEL: @add_broadcast_mul_fusion func @add_broadcast_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]} %arg0, %arg1 { + %0 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]} + ins(%arg0, %arg1 : tensor, tensor) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 - }: tensor, tensor -> tensor - // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64 + } -> tensor + // CHECK: linalg.generic { // CHECK-SAME: indexing_maps = {{\[}}[[$MAP1]], [[$MAP1]], [[$MAP0]], [[$MAP0]] - %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 { + %2 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} + ins(%0, %arg2 : tensor, tensor) { ^bb0(%arg5: f32, %arg6: f32): // no predecessors %3 = mulf %arg5, %arg6 : f32 linalg.yield %3 : f32 - }: tensor, tensor -> tensor + } -> tensor return %2 : tensor } @@ -113,19 +121,21 @@ // CHECK-LABEL: @add_mul_scalar_fusion func @add_mul_scalar_fusion(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = []} %arg0, %arg1 { + %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []} + ins(%arg0, %arg1 : tensor, tensor) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 - }: tensor, tensor -> tensor - // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64 + } -> tensor + // CHECK: linalg.generic { // CHECK: addf // CHECK: mulf - %1 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = []} %0, %arg2 { + %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []} + ins(%0, %arg2 : tensor, tensor) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %1 = mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 - }: tensor, tensor -> tensor + } -> tensor return %1 : tensor } @@ -144,25 +154,23 @@ affine_map<(i, j, k, l) -> (j, k)>, affine_map<(i, j, k, l) -> (l)>] : tensor into tensor - %1 = linalg.generic - {args_in = 2 : i64, args_out = 1 : i64, + %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], iterator_types = 
["parallel", "parallel", "parallel", "parallel"]} - %0, %arg1 { + ins(%0, %arg1 : tensor, tensor) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %1 = mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 - }: tensor, tensor -> tensor + } -> tensor return %1 : tensor } // CHECK-LABEL: func @generic_op_reshape_producer_fusion // CHECK: linalg.generic -// CHECK-SAME: args_in = 2 -// CHECK-SAME: args_out = 1 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]] // CHECK-NOT: linalg.generic + // ----- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> @@ -173,15 +181,14 @@ %arg1 : tensor) -> tensor { - %0 = linalg.generic - {args_in = 2 : i64, args_out = 1 : i64, + %0 = linalg.generic { indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} - %arg0, %arg1 { + ins(%arg0, %arg1 : tensor, tensor) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %1 = mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 - }: tensor, tensor -> tensor + } -> tensor %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k, l)>] : tensor into tensor @@ -190,8 +197,6 @@ // CHECK-LABEL: func @generic_op_reshape_consumer_fusion // CHECK: linalg.generic -// CHECK-SAME: args_in = 2 -// CHECK-SAME: args_out = 1 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]] // CHECK-NOT: linalg.generic @@ -202,15 +207,14 @@ %arg1 : tensor) -> tensor { - %0 = linalg.generic - {args_in = 2 : i64, args_out = 1 : i64, + %0 = linalg.generic { indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} - %arg0, %arg1 { + ins(%arg0, %arg1 : tensor, tensor) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %1 = mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 - }: tensor, tensor -> tensor + } -> tensor %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k, l)>] : tensor into tensor @@ -229,15 +233,14 @@ func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>) -> tensor<8x33x4xf32> { %cst = constant dense<2.000000e+00> : tensor<264x4xf32> - %0 = linalg.generic - {args_in = 2 : i64, args_out = 1 : i64, + %0 = linalg.generic { indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} - %arg0, %cst { + ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors %2 = mulf %arg1, %arg2 : f32 linalg.yield %2 : f32 - }: tensor<264x4xf32>, tensor<264x4xf32> -> tensor<264x4xf32> + } -> tensor<264x4xf32> %1 = linalg.tensor_reshape %0 [#map1, #map2] : tensor<264x4xf32> into tensor<8x33x4xf32> return %1 : tensor<8x33x4xf32> @@ -251,7 +254,8 @@ // CHECK: %[[CST:.*]] = constant {{.*}} : f32 // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK: tensor<264x4xf32> -> tensor<8x33x4xf32> +// CHECK-SAME: tensor<264x4xf32> +// CHECK: -> tensor<8x33x4xf32> // CHECK-NOT: linalg.tensor_reshape // ----- @@ -261,23 +265,20 @@ func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32> { %0 = constant dense<42.0> : tensor<5xf32> - %1 = linalg.generic - {args_in = 2 : i64, args_out = 1 : i64, + %1 = linalg.generic { indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} - %0, %arg0 { + ins(%0, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>) { ^bb0(%arg1: f32, %arg2: f32): %2 = mulf %arg1, %arg2 : f32 linalg.yield %2 : f32 - }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32> 
+ } -> tensor<5x?x?xf32> return %1 : tensor<5x?x?xf32> } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @generic_op_constant_fusion // CHECK: %[[CST:.*]] = constant {{.*}} : f32 // CHECK: linalg.generic -// CHECK-SAME: args_in = 1 : i64 -// CHECK-SAME: args_out = 1 : i64 // CHECK: ^{{.*}}(%[[ARG1:.*]]: f32) // CHECK: mulf %[[CST]], %[[ARG1]] @@ -289,23 +290,20 @@ -> tensor<5x?x?xf32> { %0 = constant dense<42.0> : tensor<5xf32> - %1 = linalg.indexed_generic - {args_in = 2 : i64, args_out = 1 : i64, + %1 = linalg.indexed_generic { indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} - %0, %arg0 { + ins(%0, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>) { ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32, %arg5 : f32): %2 = mulf %arg4, %arg5 : f32 linalg.yield %2 : f32 - }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32> + } -> tensor<5x?x?xf32> return %1 : tensor<5x?x?xf32> } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @indexed_generic_op_constant_fusion // CHECK: %[[CST:.*]] = constant {{.*}} : f32 // CHECK: linalg.indexed_generic -// CHECK-SAME: args_in = 1 : i64 -// CHECK-SAME: args_out = 1 : i64 // CHECK: ^{{[a-zA-Z0-9_]*}} // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: index // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: index @@ -321,23 +319,20 @@ -> tensor<5x?x?xf32> { %0 = constant dense<42.0> : tensor - %1 = linalg.generic - {args_in = 2 : i64, args_out = 1 : i64, + %1 = linalg.generic { indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} - %0, %arg0 { + ins(%0, %arg0 : tensor, tensor<5x?x?xf32>) { ^bb0(%arg1: f32, %arg2: f32): %2 = mulf %arg1, %arg2 : f32 linalg.yield %2 : f32 - }: tensor, tensor<5x?x?xf32> -> tensor<5x?x?xf32> + } -> tensor<5x?x?xf32> return %1 : tensor<5x?x?xf32> } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @generic_op_zero_dim_constant_fusion // CHECK: %[[CST:.*]] = constant {{.*}} : f32 // CHECK: linalg.generic -// CHECK-SAME: args_in = 1 : i64 -// CHECK-SAME: args_out = 1 : i64 // CHECK: ^{{.*}}(%[[ARG1:.*]]: f32) // CHECK: mulf %[[CST]], %[[ARG1]] @@ -349,23 +344,20 @@ (%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32> { %0 = constant dense<42.0> : tensor - %1 = linalg.indexed_generic - {args_in = 2 : i64, args_out = 1 : i64, + %1 = linalg.indexed_generic { indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} - %0, %arg0 { + ins(%0, %arg0 : tensor, tensor<5x?x?xf32>) { ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4: f32, %arg5: f32): %2 = mulf %arg4, %arg5 : f32 linalg.yield %2 : f32 - }: tensor, tensor<5x?x?xf32> -> tensor<5x?x?xf32> + } -> tensor<5x?x?xf32> return %1 : tensor<5x?x?xf32> } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @indexed_generic_op_zero_dim_constant_fusion // CHECK: %[[CST:.*]] = constant {{.*}} : f32 // CHECK: linalg.indexed_generic -// CHECK-SAME: args_in = 1 : i64 -// CHECK-SAME: args_out = 1 : i64 // CHECK: ^{{[a-zA-Z0-9_]*}} // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: index // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: index @@ -379,34 +371,30 @@ func @generic_op_indexed_generic_op_fusion(%arg0: tensor, %arg1: tensor) { %0 = linalg.generic { - args_in = 2 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"] } %arg0, %arg1 { + iterator_types = ["parallel", "parallel"] } + 
ins(%arg0, %arg1 : tensor, tensor) { ^bb0(%arg2: i32, %arg3: i32): // no predecessors %10 = addi %arg2, %arg3 : i32 linalg.yield %10 : i32 - } : tensor, tensor -> tensor + } -> tensor %1 = linalg.indexed_generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel"] } %0 { + iterator_types = ["parallel", "parallel"] } + ins(%0 : tensor) { ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors %2 = index_cast %arg2 : index to i32 %3 = index_cast %arg3 : index to i32 %4 = addi %arg4, %2 : i32 %5 = subi %4, %3 : i32 linalg.yield %5 : i32 - }: tensor -> tensor + } -> tensor return } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @generic_op_indexed_generic_op_fusion // CHECK-NOT: linalg.generic // CHECK: linalg.indexed_generic -// CHECK-SAME: args_in = 2 -// CHECK-SAME: args_out = 1 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]] // CHECK: ^{{[a-zA-Z0-9_]*}} // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index @@ -426,33 +414,29 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor, %arg1: tensor) { %0 = linalg.indexed_generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel"] } %arg0 { + iterator_types = ["parallel", "parallel"] } + ins(%arg0 : tensor) { ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors %2 = index_cast %arg2 : index to i32 %3 = index_cast %arg3 : index to i32 %4 = addi %arg4, %2 : i32 %5 = subi %4, %3 : i32 linalg.yield %5 : i32 - }: tensor -> tensor + } -> tensor %1 = linalg.generic { - args_in = 2 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"] } %0, %arg1 { + iterator_types = ["parallel", "parallel"] } + ins(%0, %arg1 : tensor, tensor) { ^bb0(%arg2: i32, %arg3: i32): // no predecessors %10 = addi %arg2, %arg3 : i32 linalg.yield %10 : i32 - } : tensor, tensor -> tensor + } -> tensor return } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @indexed_generic_op_generic_op_fusion // CHECK: linalg.indexed_generic -// CHECK-SAME: args_in = 2 -// CHECK-SAME: args_out = 1 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]] // CHECK: ^{{[a-zA-Z0-9_]*}} // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index @@ -474,36 +458,32 @@ #map1 = affine_map<(d0, d1) -> (d0, d1)> func @indexed_generic_op_fusion(%arg0: tensor) { %0 = linalg.indexed_generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel"] } %arg0 { + iterator_types = ["parallel", "parallel"] } + ins(%arg0 : tensor) { ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors %2 = index_cast %arg2 : index to i32 %3 = index_cast %arg3 : index to i32 %4 = addi %arg4, %2 : i32 %5 = subi %4, %3 : i32 linalg.yield %5 : i32 - }: tensor -> tensor + } -> tensor %1 = linalg.indexed_generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map1, #map1], - iterator_types = ["parallel", "parallel"] } %0 { + iterator_types = ["parallel", "parallel"] } + ins(%0 : tensor) { ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors %2 = index_cast %arg2 : index to i32 %3 = index_cast %arg3 : index to i32 %4 = addi %arg4, %2 : i32 %5 = subi %4, %3 : i32 linalg.yield %5 : i32 - }: tensor -> tensor + } -> tensor return } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @indexed_generic_op_fusion // 
CHECK: linalg.indexed_generic -// CHECK-SAME: args_in = 1 -// CHECK-SAME: args_out = 1 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]] // CHECK: ^{{[a-zA-Z0-9_]*}} // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index @@ -533,23 +513,20 @@ affine_map<(i, j, k, l) -> (l)>] : tensor into tensor %1 = linalg.indexed_generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %0 { + iterator_types = ["parallel", "parallel", "parallel", "parallel"] } + ins(%0 : tensor) { ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors %2 = index_cast %arg2 : index to i32 %3 = addi %arg6, %2 : i32 linalg.yield %3 : i32 - }: tensor -> tensor + } -> tensor return %1 : tensor } // CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion // CHECK-NOT: linalg.tensor_reshape // CHECK: linalg.indexed_generic -// CHECK-SAME: args_in = 1 -// CHECK-SAME: args_out = 1 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] // CHECK-NOT: linalg.tensor_reshape @@ -562,15 +539,14 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor) -> tensor { %0 = linalg.indexed_generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %arg0 { + iterator_types = ["parallel", "parallel", "parallel", "parallel"] } + ins(%arg0 : tensor) { ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors %2 = index_cast %arg2 : index to i32 %3 = addi %arg6, %2 : i32 linalg.yield %3 : i32 - }: tensor -> tensor + } -> tensor %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k, l)>] : tensor into tensor @@ -580,7 +556,5 @@ // CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion // CHECK-NOT: linalg.tensor_reshape // CHECK: linalg.indexed_generic -// CHECK-SAME: args_in = 1 -// CHECK-SAME: args_out = 1 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] // CHECK-NOT: linalg.tensor_reshape diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -470,8 +470,6 @@ #id_2d = affine_map<(i, j) -> (i, j)> #pointwise_2d_trait = { - args_in = 2, - args_out = 1, indexing_maps = [#id_2d, #id_2d, #id_2d], iterator_types = ["parallel", "parallel"] } @@ -483,13 +481,14 @@ %c0 = constant 0 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.generic #pointwise_2d_trait %A, %A, %B { + linalg.generic #pointwise_2d_trait + ins(%A, %A: memref, + memref) + outs(%B : memref) { ^bb0(%E: f32, %arg5: f32, %arg6: f32): // no predecessors %2 = addf %E, %arg5 : f32 linalg.yield %2 : f32 - }: memref, - memref, - memref + } %0 = dim %B, %c0 : memref %1 = dim %B, %c1 : memref scf.for %arg4 = %c0 to %0 step %c2 { @@ -503,13 +502,14 @@ %6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref - linalg.generic #pointwise_2d_trait %4, %5, %6 { + linalg.generic #pointwise_2d_trait + ins(%4, %5: memref, + memref) + outs(%6 : memref) { ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors %7 = mulf %arg6, %arg7 : f32 linalg.yield %7 : f32 - }: memref, - memref, - memref + } } } return @@ -527,8 +527,6 @@ #id_2d = affine_map<(i, j) -> (i, j)> #pointwise_2d_trait = { - args_in = 2, - args_out = 1, indexing_maps = [#id_2d, #id_2d, #id_2d], iterator_types = ["parallel", "parallel"] } @@ -542,13 
+540,13 @@ %C = alloc (%M, %N): memref %D = alloc (%M, %N): memref %E = alloc (%M, %N): memref - linalg.generic #pointwise_2d_trait %A, %A, %B { + linalg.generic #pointwise_2d_trait + ins(%A, %A : memref, memref) + outs(%B : memref) { ^bb0(%e: f32, %arg5: f32, %arg6: f32): // no predecessors %2 = addf %e, %arg5 : f32 linalg.yield %2 : f32 - }: memref, - memref, - memref + } %0 = dim %B, %c0 : memref %1 = dim %B, %c1 : memref scf.for %arg4 = %c0 to %0 step %c2 { @@ -562,13 +560,14 @@ %6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref - linalg.generic #pointwise_2d_trait %4, %5, %6 { + linalg.generic #pointwise_2d_trait + ins(%4, %5: memref, + memref) + outs(%6 : memref) { ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors %7 = mulf %arg6, %arg7 : f32 linalg.yield %7 : f32 - }: memref, - memref, - memref + } } } return @@ -596,25 +595,23 @@ %c1 = constant 1 : index %0 = alloc() {temp = true} : memref<100x10xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "parallel"] - } %arg1, %0 { + iterator_types = ["parallel", "parallel"]} + ins(%arg1 : memref<100xf32>) + outs(%0 : memref<100x10xf32>) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors linalg.yield %arg3 : f32 - }: memref<100xf32>, memref<100x10xf32> + } %1 = alloc() {temp = true} : memref<100x10xf32> linalg.generic { - args_in = 2 : i64, - args_out = 1 : i64, indexing_maps = [#map1, #map1, #map1], - iterator_types = ["parallel", "parallel"] - } %arg0, %0, %1 { + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %0: memref<100x10xf32>, memref<100x10xf32>) + outs(%1 : memref<100x10xf32>) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors %2 = subf %arg3, %arg4 : f32 linalg.yield %2 : f32 - }: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32> + } dealloc %0 : memref<100x10xf32> %2 = dim %1, %c0 : memref<100x10xf32> %3 = dim %1, %c1 : memref<100x10xf32> @@ -627,16 +624,14 @@ %7 = std.subview %arg2[%i, %j][%c1, %c1][%c1, %c1] : memref<100x10xf32> to memref linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map1, #map1], - iterator_types = ["parallel", "parallel"] - } %6, %7 { + iterator_types = ["parallel", "parallel"]} + ins(%6 : memref) + outs(%7 : memref) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %8 = exp %arg3 : f32 linalg.yield %8 : f32 - }: memref, - memref + } } } dealloc %1 : memref<100x10xf32> diff --git a/mlir/test/Dialect/Linalg/fusion_indexed_generic.mlir b/mlir/test/Dialect/Linalg/fusion_indexed_generic.mlir --- a/mlir/test/Dialect/Linalg/fusion_indexed_generic.mlir +++ b/mlir/test/Dialect/Linalg/fusion_indexed_generic.mlir @@ -3,8 +3,6 @@ #map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> #id_2d = affine_map<(d0, d1) -> (d0, d1)> #pointwise_2d_trait = { - args_in = 2, - args_out = 1, indexing_maps = [#id_2d, #id_2d, #id_2d], iterator_types = ["parallel", "parallel"] } @@ -12,11 +10,13 @@ %B: memref, %C: memref, %D: memref) { - linalg.generic #pointwise_2d_trait %A, %B, %C { + linalg.generic #pointwise_2d_trait + ins(%A, %B: memref, memref) + outs(%C : memref) { ^bb0(%e: f32, %arg5: f32, %arg6: f32): // no predecessors %2 = addf %e, %arg5 : f32 linalg.yield %2 : f32 - }: memref, memref, memref + } %c1 = constant 1 : index %c0 = constant 0 : index %c25 = constant 25 : index @@ -33,10 +33,9 @@ memref to memref linalg.indexed_generic { indexing_maps = [#id_2d, #id_2d], - iterator_types = ["parallel", "parallel"], - args_in = 1, - args_out = 1 - 
} %4, %5 { + iterator_types = ["parallel", "parallel"]} + ins(%4 : memref) + outs(%5 : memref) { ^bb0(%arg4: index, %arg5: index, %arg6: f32, %arg7: f32): %6 = addi %arg4, %arg2 : index %7 = addi %arg5, %arg3 : index @@ -46,7 +45,7 @@ %11 = sitofp %10 : i32 to f32 %12 = addf %9, %11 : f32 linalg.yield %12 : f32 - }: memref, memref + } } } return @@ -66,8 +65,6 @@ #map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> #id_2d = affine_map<(d0, d1) -> (d0, d1)> #pointwise_2d_trait = { - args_in = 2, - args_out = 1, indexing_maps = [#id_2d, #id_2d, #id_2d], iterator_types = ["parallel", "parallel"] } @@ -79,14 +76,16 @@ %c0 = constant 0 : index %c25 = constant 25 : index %c10 = constant 10 : index - linalg.indexed_generic #pointwise_2d_trait %A, %B, %C { + linalg.indexed_generic #pointwise_2d_trait + ins(%A, %B : memref, memref) + outs(%C : memref) { ^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors %i_int = index_cast %i: index to i32 %i_float = sitofp %i_int : i32 to f32 %ab = addf %a, %b : f32 %out = addf %ab, %i_float : f32 linalg.yield %out : f32 - }: memref, memref, memref + } %C_X = dim %C, %c0 : memref %C_Y = dim %C, %c1 : memref %D_X = dim %D, %c0 : memref @@ -98,14 +97,13 @@ memref to memref linalg.generic { indexing_maps = [#id_2d, #id_2d], - iterator_types = ["parallel", "parallel"], - args_in = 1, - args_out = 1 - } %C_view, %D_view { + iterator_types = ["parallel", "parallel"]} + ins(%C_view : memref) + outs(%D_view : memref) { ^bb0( %a: f32, %b: f32): %ab = addf %a, %b : f32 linalg.yield %ab : f32 - }: memref, memref + } } return } @@ -125,8 +123,6 @@ #map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> #id_2d = affine_map<(d0, d1) -> (d0, d1)> #pointwise_2d_trait = { - args_in = 2, - args_out = 1, indexing_maps = [#id_2d, #id_2d, #id_2d], iterator_types = ["parallel", "parallel"] } @@ -137,14 +133,16 @@ %c1 = constant 1 : index %c3 = constant 3 : index %c0 = constant 0 : index - linalg.indexed_generic #pointwise_2d_trait %A, %B, %C { + linalg.indexed_generic #pointwise_2d_trait + ins(%A, %B: memref, memref) + outs(%C : memref) { ^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors %j_int = index_cast %j: index to i32 %j_float = sitofp %j_int : i32 to f32 %ab = addf %a, %b : f32 %out = addf %ab, %j_float : f32 linalg.yield %out : f32 - }: memref, memref, memref + } %C_X = dim %C, %c0 : memref %C_Y = dim %C, %c1 : memref %D_X = dim %D, %c0 : memref @@ -161,14 +159,13 @@ linalg.generic { indexing_maps = [#id_2d, #id_2d], - iterator_types = ["parallel", "parallel"], - args_in = 1, - args_out = 1 - } %C_view, %D_view { + iterator_types = ["parallel", "parallel"]} + ins(%C_view : memref) + outs(%D_view : memref) { ^bb0( %a: f32, %b: f32): %ab = addf %a, %b : f32 linalg.yield %ab : f32 - }: memref, memref + } scf.yield } return diff --git a/mlir/test/Dialect/Linalg/inlining.mlir b/mlir/test/Dialect/Linalg/inlining.mlir --- a/mlir/test/Dialect/Linalg/inlining.mlir +++ b/mlir/test/Dialect/Linalg/inlining.mlir @@ -9,8 +9,6 @@ ] #trait = { - args_in = 1, - args_out = 1, indexing_maps = #accesses, iterator_types = ["parallel"] } @@ -23,9 +21,11 @@ func @inlined_fn(%arg0: memref) { // CHECK: linalg.generic - linalg.generic #trait %arg0, %arg0 { + linalg.generic #trait + ins(%arg0 : memref) + outs(%arg0 : memref) { ^bb(%0 : f32, %1 : f32) : linalg.yield %0 : f32 - } : memref, memref + } return } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- 
a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -62,52 +62,24 @@ // ----- func @generic_no_region(%arg0: memref) { - // expected-error @+6 {{expected '{' to begin a region}} + // expected-error @+5 {{expected '{' to begin a region}} linalg.generic { - args_in = 1, - args_out = 1, indexing_maps = [ affine_map<() -> (0)> ], iterator_types = [] - } %arg0 : memref -} - -// ----- - -func @generic_at_least_2_operands(%arg0: memref) { - // expected-error @+1 {{op expected 2 or more operands}} - linalg.generic { - args_in = 1, - args_out = 1, - indexing_maps = [ affine_map<() -> (0)> ], - iterator_types = [] - } %arg0 {} : memref -} - -// ----- - -func @generic_exactly_2_views(%arg0: memref) { - // expected-error @+1 {{op expected exactly 2 inputs (tensor or buffer) and output buffer operands}} - linalg.generic { - args_in = 1, - args_out = 1, - indexing_maps = [ affine_map<() -> (0)> ], - iterator_types = [] - } %arg0, %arg0, %arg0 {}: memref, memref, memref + } ins(%arg0 : memref) } // ----- func @generic_mismatched_num_returns(%arg0: memref) { - // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (0)}} + // expected-error @+6 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (0)}} linalg.generic { - args_in = 0, - args_out = 1, - indexing_maps = [ affine_map<() -> ()> ], - iterator_types = [] - } %arg0 { + indexing_maps = [ affine_map<() -> ()> ], + iterator_types = []} + outs(%arg0 : memref) { ^bb(%0: f32): linalg.yield - }: memref + } } // ----- @@ -115,14 +87,12 @@ func @generic_symbol_in_map(%arg0: memref) { // expected-error @+1 {{expected the number of symbols in indexing_map #0 to match rank of operand `symbol_source`}} linalg.generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<()[N] -> (0)> ], - iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + outs(%arg0 : memref) { ^bb(%i : i32): linalg.yield %i : i32 - }: memref + } } // ----- @@ -130,15 +100,13 @@ func @generic_symbol_source_out_of_range(%arg0: memref) { // expected-error @+1 {{symbol_source index out of range}} linalg.generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<()[N] -> (0)> ], iterator_types = ["parallel"], - symbol_source = 1 - } %arg0 { + symbol_source = 1} + outs(%arg0 : memref) { ^bb(%i : i32): linalg.yield %i : i32 - }: memref + } } // ----- @@ -146,14 +114,12 @@ func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) { // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}} linalg.generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<() -> (0)> ], - iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + outs(%arg0 : memref<1xi32>) { ^bb(%i : i32): linalg.yield %i : i32 - }: memref<1xi32> + } } // ----- @@ -161,30 +127,26 @@ func @generic_one_d_view(%arg0: memref(off + i)>>) { // expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref (d0 + s0)>>'}} linalg.generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<() -> (0, 0)> ], - iterator_types = [] - } %arg0 { + iterator_types = []} + outs(%arg0 : memref(off + i)>>) { ^bb(%f : f32): linalg.yield %f: f32 - }: memref(off + i)>> + } } // ----- func @generic_result_0_element_type(%arg0: memref(off + i)>>) { - // expected-error @+9 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing 
linalg.generic op ('f32')}} + // expected-error @+7 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}} linalg.generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<(i) -> (i)> ], - iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + outs(%arg0 : memref(off + i)>>) { ^bb(%0: f32): %1 = constant 1: i4 linalg.yield %1: i4 - }: memref(off + i)>> + } } // ----- @@ -192,18 +154,16 @@ func @generic_singular_maps(%arg0: memref(off + i)>>, %arg1: memref(off + i)>>) { // expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}} linalg.generic { - args_in = 1, - args_out = 1, indexing_maps = [ affine_map<(i, j) -> (i + j)>, affine_map<(i, j) -> (i + j)> ], - iterator_types = ["parallel","parallel"] - } %arg0, %arg1 { - ^bb(%0: f32, %1: f32): + iterator_types = ["parallel","parallel"]} + ins(%arg0 : memref(off + i)>>) + outs(%arg1 : memref(off + i)>>) { + ^bb(%0: f32, %1: f32): linalg.yield %1: f32 - }: memref(off + i)>>, - memref(off + i)>> + } } //////////////////////////////////////////////////////////////////////////////// @@ -216,16 +176,15 @@ %f0 = constant 0.0: f32 // expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}} linalg.generic { - args_in = 1, - args_out = 1, indexing_maps = [ affine_map<() -> (0)> ], - iterator_types = [] - } %arg0, %arg0 { + iterator_types = []} + ins(%arg0 : memref) + outs(%arg0 : memref) { ^bb1: linalg.yield %f0: f32 ^bb2: linalg.yield %f0: f32 - }: memref, memref + } } // ----- @@ -234,12 +193,11 @@ %f0 = constant 0.0: f32 // expected-error @+1 {{linalg.generic' op expected region with 1 block}} linalg.generic { - args_in = 1, - args_out = 1, indexing_maps = [ affine_map<() -> (0)> ], - iterator_types = [] - } %arg0, %arg0 { - }: memref, memref + iterator_types = []} + ins(%arg0 : memref) + outs(%arg0 : memref) { + } } // ----- @@ -247,14 +205,12 @@ func @generic_mismatched_num_arguments(%arg0: memref) { // expected-error @+1 {{op expected number of block arguments to match number of operands}} linalg.generic { - args_in = 0, - args_out = 1, - indexing_maps = [ affine_map<() -> (0)> ], - iterator_types = [] - } %arg0 { + indexing_maps = [ affine_map<() -> (0)> ], + iterator_types = []} + outs(%arg0 : memref) { ^bb(%f: f32, %g: f32): linalg.yield %f: f32 - }: memref + } } // ----- @@ -262,14 +218,12 @@ func @generic_block_arg_type(%arg0: memref) { // expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output operand: 'memref'}} linalg.generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<() -> (0)> ], - iterator_types = [] - } %arg0 { + iterator_types = []} + outs(%arg0 : memref) { ^bb(%i: i1): linalg.yield %i : i1 - }: memref + } } // ----- @@ -277,14 +231,12 @@ func @indexed_generic_block_arg_count(%arg0: memref) { // expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}} linalg.indexed_generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<(d0) -> (d0)> ], - iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + outs(%arg0 : memref) { ^bb(%f: f32): linalg.yield %f : f32 - }: memref + } } // ----- @@ -292,14 +244,12 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref) { // expected-error @+1 {{op expected block argument 1 to be an index}} linalg.indexed_generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<(d0) -> (d0)> ], - 
iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + outs(%arg0 : memref) { ^bb(%i: f64, %f: f32): linalg.yield %f: f32 - }: memref + } } // ----- @@ -307,14 +257,12 @@ func @indexed_generic_block_arg_type(%arg0: memref) { // expected-error @+1 {{op expected block argument 2 of the same type as elemental type of output operand: 'memref'}} linalg.indexed_generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<(d0) -> (d0)> ], - iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + outs(%arg0 : memref) { ^bb(%i: index, %f: i1): linalg.yield %i: index - }: memref + } } // ----- @@ -322,14 +270,12 @@ func @indexed_generic_arg_count(%arg0: memref) { // expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}} linalg.indexed_generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<()[] -> ()> ], - iterator_types = [] - } %arg0 { + iterator_types = []} + outs(%arg0 : memref) { ^bb(%0: index, %1: f32): linalg.yield %1: f32 - } : memref + } return } @@ -338,60 +284,39 @@ func @indexed_generic_induction_var_arg_type(%arg0: memref) { // expected-error @+1 {{op expected block argument 1 to be an index}} linalg.indexed_generic { - args_in = 0, - args_out = 1, iterator_types = ["parallel"], - indexing_maps = [ affine_map<(i) -> (i)> ] - } %arg0 { + indexing_maps = [ affine_map<(i) -> (i)> ]} + outs(%arg0 : memref) { ^bb(%0: i32, %1: f32): linalg.yield %1: f32 - } : memref + } } // ----- func @indexed_generic_result_count(%arg0: memref) { - // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}} + // expected-error @+6 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}} linalg.indexed_generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<(d0) -> (d0)> ], - iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + outs(%arg0 : memref) { ^bb(%i: index, %val: f32): linalg.yield %val, %val: f32, f32 - }: memref + } } // ----- func @generic_result_0_element_type(%arg0: memref(off + i)>>) { - // expected-error @+9 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}} + // expected-error @+7 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}} linalg.generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<(i) -> (i)> ], - iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + outs(%arg0 : memref(off + i)>>) { ^bb(%i: f32): %0 = constant 0: i1 linalg.yield %0: i1 - }: memref(off + i)>> -} - -// ----- - -func @generic_result_tensor_type(%arg0: memref(off + i)>>) { - // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'f32'}} - %0 = linalg.generic { - args_in = 0, - args_out = 1, - indexing_maps = [ affine_map<(i) -> (i)> ], - iterator_types = ["parallel"] - } %arg0 { - ^bb(%i: f32): - linalg.yield %i: f32 - }: memref(off + i)>> -> f32 + } } // ----- @@ -399,14 +324,12 @@ func @generic_result_tensor_type(%arg0: memref(off + i)>>) { // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'f32'}} %0 = linalg.generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<(i) -> (i)> ], - iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + ins(%arg0 : memref(off + i)>>) { ^bb(%i: f32): linalg.yield 
%i: f32 - }: memref(off + i)>> -> f32 + } -> f32 } // ----- @@ -415,14 +338,12 @@ // expected-error @+2 {{op expects regions to end with 'linalg.yield', found 'std.addf'}} // expected-note @+1 {{in custom textual format, the absence of terminator implies 'linalg.yield'}} linalg.generic { - args_in = 0, - args_out = 1, indexing_maps = [ affine_map<(i) -> (i)> ], - iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + outs(%arg0 : memref) { ^bb(%0: i4) : %1 = std.addf %0, %0: i4 - } : memref + } return } @@ -511,23 +432,6 @@ // ----- -func @generic(%arg0: tensor) { - // expected-error @+1 {{unexpected #results > #outputs}} - linalg.generic { - args_in = 1, - args_out = 1, - indexing_maps = [ affine_map<(i) -> (i)> ], - iterator_types = ["parallel"] - } %arg0 { - ^bb(%0: i4) : - %1 = std.addi %0, %0: i4 - linalg.yield %1, %1: i4, i4 - } : tensor -> (tensor, tensor) - return -} - -// ----- - func @empty_init_expected(%m: memref, %t: tensor) { // expected-error @+1 {{expected empty `init` when op has no results or no reduction dims}} linalg.matmul ins(%m, %m: memref, memref) diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -557,12 +557,15 @@ doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))" } func @generic_region(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.generic #trait2 %arg0, %arg1, %arg2 { + linalg.generic #trait2 + ins(%arg0: memref) + outs(%arg1, %arg2 : memref, + memref) { ^bb0(%a: f32, %b: f32, %c: f32): %d = mulf %a, %b : f32 %e = addf %c, %d : f32 linalg.yield %d, %e : f32, f32 - }: memref, memref, memref + } return } // CHECKLOOP-LABEL: @generic_region @@ -599,7 +602,10 @@ %arg0: memref, %arg1: memref, %arg2: memref) { - linalg.indexed_generic #trait4 %arg0, %arg1, %arg2 { + linalg.indexed_generic #trait4 + ins(%arg0 : memref) + outs(%arg1, %arg2 : memref, + memref) { ^bb0(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32): %result_1 = mulf %a, %b : f32 @@ -610,9 +616,7 @@ %result_2 = addf %c, %ijk_float : f32 linalg.yield %result_1, %result_2 : f32, f32 - }: memref, - memref, - memref + } return } @@ -663,10 +667,12 @@ func @generic_op_zero_rank(%arg0: memref, %arg1: memref<3x4xf32>) { - linalg.generic #trait_broadcast %arg0, %arg1 { + linalg.generic #trait_broadcast + ins(%arg0 : memref) + outs(%arg1 : memref<3x4xf32>) { ^bb(%a: f32, %b: f32) : linalg.yield %a : f32 - } : memref, memref<3x4xf32> + } return } @@ -687,13 +693,15 @@ func @indexed_generic_op_zero_rank(%arg0: memref, %arg1: memref<3x4xi32>) { - linalg.indexed_generic #trait_broadcast %arg0, %arg1 { + linalg.indexed_generic #trait_broadcast + ins(%arg0 : memref) + outs(%arg1 : memref<3x4xi32>) { ^bb(%i: index, %j: index, %a: i32, %b: i32) : %ij = addi %i, %j : index %ij_int = index_cast %ij : index to i32 %result = addi %a, %ij_int : i32 linalg.yield %result : i32 - } : memref, memref<3x4xi32> + } return } @@ -733,11 +741,13 @@ func @generic_op_1D_reduce(%arg0: memref, %arg1: memref) { - linalg.generic #trait_reduce_1D %arg0, %arg1 { + linalg.generic #trait_reduce_1D + ins(%arg0 : memref) + outs(%arg1 : memref) { ^bb(%a: f32, %b: f32) : %0 = addf %a, %b : f32 linalg.yield %0 : f32 - } : memref, memref + } return } // CHECKLOOP-LABEL: @generic_op_1D_reduce @@ -777,14 +787,16 @@ %arg1: memref, %arg2: memref) { - linalg.indexed_generic #trait_reduce_init_1D %arg0, %arg1, %arg2 { + linalg.indexed_generic #trait_reduce_init_1D + ins(%arg0, %arg1 : 
memref, memref) + outs(%arg2 : memref) { ^bb(%i : index, %a: f32, %b: f32, %c: f32) : %0 = constant 0 : index %1 = cmpi "eq", %0, %i : index %2 = select %1, %b, %c : f32 %3 = addf %a, %2 : f32 linalg.yield %3 : f32 - } : memref, memref, memref + } return } // CHECKLOOP-LABEL: @indexed_generic_op_1D_reduce @@ -820,10 +832,10 @@ } func @generic_const_init(%arg0: memref) { %cst = constant 1.0 : f32 - linalg.generic #trait_const_fill %arg0 { + linalg.generic #trait_const_fill outs(%arg0 : memref) { ^bb0(%arg1: f32): // no predecessors linalg.yield %cst : f32 - }: memref + } return } // CHECKLOOP-LABEL: @generic_const_init @@ -852,11 +864,13 @@ } func @scalar_code(%arg0: memref, %arg1 : memref, %arg2 : memref) { - linalg.generic #scalar_trait %arg0, %arg1, %arg2 { + linalg.generic #scalar_trait + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) { ^bb(%a : f32, %b : f32, %c : f32) : %0 = addf %a, %b : f32 linalg.yield %0 : f32 - } : memref, memref, memref + } return } // CHECKLOOP-LABEL: @scalar_code @@ -941,14 +955,14 @@ } func @conv1d(%in : memref, %filter : memref, %out : memref) -> () { - linalg.generic #conv_1d_trait %in, %filter, %out { + linalg.generic #conv_1d_trait + ins(%in, %filter : memref, memref) + outs(%out : memref) { ^bb0(%a: f32, %b: f32, %c: f32) : %d = mulf %a, %b : f32 %e = addf %c, %d : f32 linalg.yield %e : f32 - } : memref, - memref, - memref + } return } @@ -1009,14 +1023,14 @@ } func @conv2d(%in : memref, %filter : memref, %out : memref) -> () { - linalg.generic #conv_2d_trait %in, %filter, %out { + linalg.generic #conv_2d_trait + ins(%in, %filter : memref, memref) + outs(%out : memref) { ^bb0(%a: f32, %b: f32, %c: f32) : %d = mulf %a, %b : f32 %e = addf %c, %d : f32 linalg.yield %e : f32 - } : memref, - memref, - memref + } return } @@ -1093,14 +1107,14 @@ } func @conv3d(%in : memref, %filter : memref, %out : memref) -> () { - linalg.generic #conv_3d_trait %in, %filter, %out { + linalg.generic #conv_3d_trait + ins(%in, %filter : memref, memref) + outs(%out : memref) { ^bb0(%a: f32, %b: f32, %c: f32) : %d = mulf %a, %b : f32 %e = addf %c, %d : f32 linalg.yield %e : f32 - } : memref, - memref, - memref + } return } @@ -1193,14 +1207,14 @@ } func @conv4d(%in : memref, %filter : memref, %out : memref) -> () { - linalg.generic #conv_4d_trait %in, %filter, %out { + linalg.generic #conv_4d_trait + ins(%in, %filter : memref, memref) + outs(%out : memref) { ^bb0(%a: f32, %b: f32, %c: f32) : %d = mulf %a, %b : f32 %e = addf %c, %d : f32 linalg.yield %e : f32 - } : memref, - memref, - memref + } return } diff --git a/mlir/test/Dialect/Linalg/parallel_loops.mlir b/mlir/test/Dialect/Linalg/parallel_loops.mlir --- a/mlir/test/Dialect/Linalg/parallel_loops.mlir +++ b/mlir/test/Dialect/Linalg/parallel_loops.mlir @@ -5,15 +5,14 @@ %rhs: memref<2x2xf32>, %sum: memref<2x2xf32>) { linalg.generic { - args_in = 2 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"] - } %lhs, %rhs, %sum { + iterator_types = ["parallel", "parallel"]} + ins(%lhs, %rhs : memref<2x2xf32>, memref<2x2xf32>) + outs(%sum : memref<2x2xf32>) { ^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32): // no predecessors %0 = addf %lhs_in, %rhs_in : f32 linalg.yield %0 : f32 - }: memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32> + } return } // CHECK-LABEL: @linalg_generic_sum @@ -35,17 +34,17 @@ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> ] #trait = { - args_in = 1, - args_out = 1, iterator_types = ["parallel", "parallel", "reduction", "parallel"], 
indexing_maps = #accesses } func @lower_outer_parallel(%A: memref, %B: memref) { - linalg.generic #trait %A, %B { + linalg.generic #trait + ins(%A : memref) + outs(%B : memref) { ^bb0(%a: f32, %b: f32): linalg.yield %a: f32 - } : memref, memref + } return } // CHECK-LABEL: @lower_outer_parallel @@ -68,17 +67,17 @@ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)> ] #trait = { - args_in = 1, - args_out = 1, iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], indexing_maps = #accesses } func @lower_mixed_parallel(%A: memref, %B: memref) { - linalg.generic #trait %A, %B { + linalg.generic #trait + ins(%A : memref) + outs(%B : memref) { ^bb0(%a: f32, %b: f32): linalg.yield %a: f32 - } : memref, memref + } return } // CHECK-LABEL: @lower_mixed_parallel diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -293,8 +293,6 @@ ] #trait = { - args_in = 1, - args_out = 1, indexing_maps = #accesses, iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1" @@ -302,37 +300,44 @@ func @generic(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { - linalg.generic #trait {foo = 1} %arg0, %arg1 { + linalg.generic #trait + ins(%arg0 : memref, offset: ?, strides: [?, 1]>) + outs(%arg1 : memref) + attrs = {foo = 1} { ^bb(%0: vector<3x4xi4>, %1: f32) : %f0 = constant 0.0 : f32 linalg.yield %f0 : f32 - } : memref, offset: ?, strides: [?, 1]>, - memref + } return } // CHECK-LABEL: func @generic -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, -// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], -// CHECK-SAME: library_call = "some_external_function_name_1" +// CHECK: linalg.generic { +// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"], +// CHECK-SAME: library_call = "some_external_function_name_1"} +// CHECK-SAME: ins({{.*}} : memref, #[[$strided2D]]>) +// CHECK-SAME: outs({{.*}} : memref) // CHECK-SAME: {foo = 1 : i64} -// CHECK: memref, #[[$strided2D]]>, memref func @generic_with_tensor_input(%arg0: tensor>, %arg1: memref) { - linalg.generic #trait {foo = 1} %arg0, %arg1 { + linalg.generic #trait + ins(%arg0 : tensor>) + outs(%arg1 : memref) + attrs = {foo = 1} { ^bb(%0: vector<3x4xi4>, %1: f32) : %f0 = constant 0.0 : f32 linalg.yield %f0 : f32 - } : tensor>, - memref + } return } // CHECK-LABEL: func @generic_with_tensor_input -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_1"} +// CHECK-SAME: ins({{.*}} : tensor>) +// CHECK-SAME: outs({{.*}} : memref) // CHECK-SAME: {foo = 1 : i64} -// CHECK: tensor>, memref // ----- @@ -342,8 +347,6 @@ ] #trait2 = { - args_in = 2, - args_out = 1, indexing_maps = #accesses, iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1" @@ -352,20 +355,22 @@ func @generic_with_tensor_input_and_output( %arg0: tensor>, %arg1: tensor) -> (tensor) { - %0 = linalg.generic #trait2 {foo = 1} %arg0, %arg1 { + %0 = linalg.generic #trait2 + ins(%arg0, %arg1 : tensor>, tensor) + attrs = {foo = 1} { ^bb(%0: vector<3x4xi4>, %1: f32) : %f0 = constant 0.0 : f32 
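// Sketch of the trailing attribute placement exercised nearby (names here are
// illustrative): extra attributes that used to sit in a dictionary right
// after the trait now attach through an `attrs = {...}` clause between the
// operand clauses and the region.
func @tagged_copy(%src: tensor<4xf32>) -> tensor<4xf32> {
  %res = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                          affine_map<(d0) -> (d0)>],
                         iterator_types = ["parallel"]}
    ins(%src : tensor<4xf32>)
    attrs = {foo = 1} {
  ^bb0(%a: f32):
    linalg.yield %a : f32
  } -> tensor<4xf32>
  return %res : tensor<4xf32>
}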
linalg.yield %f0 : f32 - } : tensor>, tensor -> tensor + } -> tensor return %0 : tensor } // CHECK-LABEL: func @generic_with_tensor_input_and_output -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_1"} +// CHECK-SAME: ins({{.*}} : tensor>, tensor) // CHECK-SAME: {foo = 1 : i64} -// CHECK-SAME: %{{.*}}, %{{.*}} -// CHECK: tensor>, tensor -> tensor +// CHECK: -> tensor // CHECK: return {{.*}} : tensor // ----- @@ -376,8 +381,6 @@ ] #trait2 = { - args_in = 2, - args_out = 1, indexing_maps = #accesses, iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1" @@ -386,20 +389,22 @@ func @indexed_generic_with_tensor_input_and_output( %arg0: tensor>, %arg1: tensor) -> (tensor) { - %0 = linalg.indexed_generic #trait2 {foo = 1} %arg0, %arg1 { + %0 = linalg.indexed_generic #trait2 + ins(%arg0, %arg1 : tensor>, tensor) + attrs = {foo = 1} { ^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) : %f0 = constant 0.0 : f32 linalg.yield %f0 : f32 - } : tensor>, tensor -> tensor + } -> tensor return %0 : tensor } // CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output -// CHECK: linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.indexed_generic { // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_1"} +// CHECK-SAME: ins({{.*}} : tensor>, tensor) // CHECK-SAME: {foo = 1 : i64} -// CHECK-SAME: %{{.*}}, %{{.*}} -// CHECK: tensor>, tensor -> tensor +// CHECK: -> tensor // CHECK: return {{.*}} : tensor // ----- @@ -410,8 +415,6 @@ ] #trait_broadcast = { - args_in = 1, - args_out = 1, indexing_maps = #broadcast_access, iterator_types = ["parallel", "parallel"], library_call = "some_broadcast_external_fn" @@ -419,19 +422,21 @@ func @generic_op_zero_rank(%arg0: tensor) -> (tensor<3x4xf32>) { - %0 = linalg.generic #trait_broadcast %arg0 { + %0 = linalg.generic #trait_broadcast + ins(%arg0 : tensor) { ^bb(%a: f32) : linalg.yield %a : f32 - } : tensor -> tensor<3x4xf32> + } -> tensor<3x4xf32> return %0 : tensor<3x4xf32> } func @indexed_generic_op_zero_rank(%arg0: tensor) -> (tensor<3x4xf32>) { - %0 = linalg.indexed_generic #trait_broadcast %arg0 { + %0 = linalg.indexed_generic #trait_broadcast + ins(%arg0 : tensor) { ^bb(%i: index, %j: index, %a: f32) : linalg.yield %a : f32 - } : tensor -> tensor<3x4xf32> + } -> tensor<3x4xf32> return %0 : tensor<3x4xf32> } @@ -446,8 +451,6 @@ ] #trait3 = { - args_in = 1, - args_out = 1, indexing_maps = #accesses, iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2" @@ -455,41 +458,48 @@ func @generic_region(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { - linalg.generic #trait3 {foo = 1} %arg0, %arg1 { + linalg.generic #trait3 + ins(%arg0 : memref, offset: ?, strides: [?, 1]>) + outs(%arg1 : memref) + attrs = {foo = 1} { ^bb(%a: vector<3x4xi4>, %b: f32) : linalg.yield %b : f32 - } : memref, offset: ?, strides: [?, 1]>, - memref + } return } // CHECK-LABEL: func @generic_region -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, -// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], +// CHECK: linalg.generic { +// CHECK-SAME: indexing_maps = 
[#{{[0-9a-z]*}}, #{{[0-9a-z]*}}], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_2" -// CHECK-SAME: {foo = 1 : i64} -// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): -// CHECK: linalg.yield %{{.*}} : f32 -// CHECK: memref, #[[$strided2D]]>, -// CHECK-SAME: memref +// CHECK-SAME: ins({{.*}} : memref, #[[$strided2D]]>) +// CHECK-SAME: outs({{.*}} : memref) +// CHECK-SAME: attrs = {foo = 1 : i64} { +// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): +// CHECK: linalg.yield %{{.*}} : f32 func @indexed_generic(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { - linalg.indexed_generic #trait3 {foo = 1} %arg0, %arg1 { - ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) : + linalg.indexed_generic #trait3 + ins(%arg0 : memref, offset: ?, strides: [?, 1]>) + outs(%arg1 : memref) + attrs = {foo = 1} { + ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) : linalg.yield %b : f32 - }: memref, offset: ?, strides: [?, 1]>, - memref + } return } // CHECK-LABEL: func @indexed_generic -// CHECK: linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64, -// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], +// CHECK: linalg.indexed_generic { +// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_2" +// CHECK-SAME: ins({{.*}} : memref, #[[$strided2D]]>) +// CHECK-SAME: outs({{.*}} : memref) // CHECK-SAME: {foo = 1 : i64} // CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // CHECK: linalg.yield %{{.*}} : f32 -// CHECK: }: memref, #[[$strided2D]]>, -// CHECK-SAME: memref +// CHECK: } // ----- diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir --- a/mlir/test/Dialect/Linalg/standard.mlir +++ b/mlir/test/Dialect/Linalg/standard.mlir @@ -72,8 +72,6 @@ affine_map<(m, n, k) -> (m, n)> ] #matmul_trait = { - args_in = 2, - args_out = 1, iterator_types = ["parallel", "parallel", "reduction"], indexing_maps = #matmul_accesses, library_call = "external_outerproduct_matmul" @@ -88,20 +86,19 @@ !matrix_type_C = type memref func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) { - linalg.generic #matmul_trait %A, %B, %C { + linalg.generic #matmul_trait + ins(%A, %B : !matrix_type_A, !matrix_type_B) + outs(%C : !matrix_type_C) { ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C): %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B linalg.yield %d: !vector_type_C - } : !matrix_type_A, !matrix_type_B, !matrix_type_C - + } return } // CHECK-LABEL: func @matmul_vec_impl( // CHECK: call @external_outerproduct_matmul(%{{.*}}) : #indexed_matmul_trait = { - args_in = 2, - args_out = 1, iterator_types = ["parallel", "parallel", "reduction"], indexing_maps = #matmul_accesses, library_call = "external_indexed_outerproduct_matmul" @@ -109,12 +106,14 @@ func @matmul_vec_indexed(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) { - linalg.indexed_generic #indexed_matmul_trait %A, %B, %C { + linalg.indexed_generic #indexed_matmul_trait + ins(%A, %B : !matrix_type_A, !matrix_type_B) + outs(%C : !matrix_type_C) { ^bb0(%i: index, %j: index, %k: index, %a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C): %d = vector.outerproduct %a, %b, %c: 
!vector_type_A, !vector_type_B linalg.yield %d: !vector_type_C - } : !matrix_type_A, !matrix_type_B, !matrix_type_C + } return } // CHECK-LABEL: func @matmul_vec_indexed( diff --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir --- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir +++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir @@ -4,23 +4,24 @@ // CHECK-LABEL: func @multiple_results_generic_op func @multiple_results_generic_op(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0, %1 = linalg.generic {args_in = 1 : i64, args_out = 2 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %arg0 { + %0, %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} + ins(%arg0 : tensor<4xf32>) { ^bb0(%gen_arg1: f32): %tmp1 = exp %gen_arg1 : f32 linalg.yield %tmp1, %tmp1 : f32, f32 - }: tensor<4xf32> -> (tensor<4xf32>, tensor<4xf32>) + } -> tensor<4xf32>, tensor<4xf32> return %0, %1 : tensor<4xf32>, tensor<4xf32> } // CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]], %[[ARG2_RESULT:.*]]: [[TYPE]]) // CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]] // CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]] // CHECK: linalg.generic -// CHECK-SAME: %[[NEW_ARG0]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]] +// CHECK-SAME: ins(%[[NEW_ARG0]] : [[TYPE]] +// CHECK-SAME: outs(%[[FIRST_ALLOC]], %[[SECOND_ALLOC]] : [[TYPE]], [[TYPE]] // CHECK-NEXT: ^{{[a-z0-9_]*}} // CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32 // CHECK-NEXT: %{{.*}} = exp // CHECK-NEXT: linalg.yield -// CHECK-NEXT: [[TYPE]], [[TYPE]], [[TYPE]] // CHECK: linalg.copy(%[[FIRST_ALLOC]], %[[ARG1_RESULT]]) // CHECK: dealloc %[[FIRST_ALLOC]] // CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[ARG2_RESULT]]) @@ -33,31 +34,33 @@ // CHECK-LABEL: func @chained_operations func @chained_operations(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + %0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg0 : tensor<4xf32>) { ^bb0(%gen_arg1: f32): %tmp1 = exp %gen_arg1 : f32 linalg.yield %tmp1 : f32 - }: tensor<4xf32> -> tensor<4xf32> - %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 { + } -> tensor<4xf32> + %1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%0 : tensor<4xf32>) { ^bb0(%gen_arg2: f32): %tmp2 = exp %gen_arg2 : f32 linalg.yield %tmp2 : f32 - }: tensor<4xf32> -> tensor<4xf32> + } -> tensor<4xf32> return %1 : tensor<4xf32> } // CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]]) // CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]] // CHECK: linalg.generic -// CHECK-SAME: %[[NEW_ARG0]], %[[FIRST_ALLOC]] +// CHECK-SAME: ins(%[[NEW_ARG0]] : [[TYPE]] +// CHECK-SAME: outs(%[[FIRST_ALLOC]] : [[TYPE]] // CHECK: ^{{[a-z0-9_]*}} // CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32 -// CHECK: [[TYPE]], [[TYPE]] // CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]] // CHECK: linalg.generic -// CHECK-SAME: %[[FIRST_ALLOC]], %[[SECOND_ALLOC]] +// CHECK-SAME: ins(%[[FIRST_ALLOC]] : [[TYPE]] +// CHECK-SAME: outs(%[[SECOND_ALLOC]] : [[TYPE]] // CHECK: ^{{[a-z0-9_]*}} // CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32 -// CHECK: [[TYPE]], [[TYPE]] // CHECK: dealloc %[[FIRST_ALLOC]] // CHECK: linalg.copy(%[[SECOND_ALLOC]],
%[[ARG1_RESULT]]) // CHECK: dealloc %[[SECOND_ALLOC]] diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -349,11 +349,13 @@ func @pointwise(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.generic #pointwise_2d_trait %arg0, %arg1, %arg2 { + linalg.generic #pointwise_2d_trait + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) { ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): // no predecessors %4 = addf %arg4, %arg5 : f32 linalg.yield %4 : f32 - }: memref, memref, memref + } return } // TILE-2-LABEL: func @pointwise diff --git a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir --- a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir +++ b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir @@ -10,13 +10,15 @@ iterator_types = ["parallel"] } func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) { - linalg.indexed_generic #pointwise_1d_trait %operand, %result { + linalg.indexed_generic #pointwise_1d_trait + ins(%operand :memref<50xf32>) + outs(%result : memref<50xf32>) { ^bb0(%i: index, %operand_in: f32, %result_in: f32): %i_int = index_cast %i: index to i32 %i_float = sitofp %i_int : i32 to f32 %out = addf %operand_in, %i_float : f32 linalg.yield %out : f32 - }: memref<50xf32>, memref<50xf32> + } return } // TILE-10n25-LABEL: func @indexed_generic_vector @@ -53,7 +55,9 @@ iterator_types = ["parallel", "parallel"] } func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) { - linalg.indexed_generic #combined_indices_trait %operand, %result { + linalg.indexed_generic #combined_indices_trait + ins(%operand : memref<50x100xf32>) + outs(%result : memref<50x100xf32>) { ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32): %i_int = index_cast %i: index to i32 %i_float = sitofp %i_int : i32 to f32 @@ -61,7 +65,7 @@ %j_float = sitofp %j_int : i32 to f32 %out = addf %i_float, %j_float : f32 linalg.yield %out : f32 - }: memref<50x100xf32>, memref<50x100xf32> + } return } // TILE-10n25-LABEL: func @indexed_generic_matrix diff --git a/mlir/test/Dialect/Linalg/tile_parallel.mlir b/mlir/test/Dialect/Linalg/tile_parallel.mlir --- a/mlir/test/Dialect/Linalg/tile_parallel.mlir +++ b/mlir/test/Dialect/Linalg/tile_parallel.mlir @@ -14,13 +14,14 @@ func @sum(%lhs: memref, %rhs: memref, %sum: memref) { - linalg.generic #pointwise_2d_trait %lhs, %rhs, %sum { + linalg.generic #pointwise_2d_trait + ins(%lhs, %rhs: memref, + memref) + outs(%sum : memref) { ^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32): %result = addf %lhs_in, %rhs_in : f32 linalg.yield %result : f32 - }: memref, - memref, - memref + } return } // TILE-2-LABEL: func @sum( @@ -33,7 +34,7 @@ // TILE-2: [[LHS_SUBVIEW:%.*]] = subview [[LHS]] // TILE-2: [[RHS_SUBVIEW:%.*]] = subview [[RHS]] // TILE-2: [[SUM_SUBVIEW:%.*]] = subview [[SUM]] -// TILE-2: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] { +// TILE-2: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]] // TILE-02-LABEL: func @sum( // TILE-02-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { @@ -45,12 +46,12 @@ // TILE-02: [[LHS_SUBVIEW:%.*]] = subview [[LHS]] // TILE-02: [[RHS_SUBVIEW:%.*]] = subview [[RHS]] // TILE-02: [[SUM_SUBVIEW:%.*]] = subview [[SUM]] -// TILE-02: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] { +// TILE-02: linalg.generic 
{{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]] // TILE-002-LABEL: func @sum( // TILE-002-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { // TILE-002-NO: scf.parallel -// TILE-002: linalg.generic {{.*}} [[LHS]], [[RHS]], [[SUM]] { +// TILE-002: linalg.generic {{.*}} ins([[LHS]], [[RHS]]{{.*}} outs([[SUM]] // TILE-234-LABEL: func @sum( // TILE-234-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { @@ -64,4 +65,4 @@ // TILE-234: [[LHS_SUBVIEW:%.*]] = subview [[LHS]] // TILE-234: [[RHS_SUBVIEW:%.*]] = subview [[RHS]] // TILE-234: [[SUM_SUBVIEW:%.*]] = subview [[SUM]] -// TILE-234: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] { +// TILE-234: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]] diff --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir --- a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir +++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir @@ -59,12 +59,14 @@ %arg1 : memref, %arg2 : memref) { - linalg.generic #trait %arg0, %arg1, %arg2 { + linalg.generic #trait + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): %0 = addf %arg3, %arg4 : f32 %1 = addf %0, %arg5 : f32 linalg.yield %1 : f32 - } : memref, memref, memref + } return } @@ -82,7 +84,8 @@ // CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]] // CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]] // CHECK: linalg.generic -// CHECK-SAME: %[[SV1]], %[[SV2]], %[[SV3]] +// CHECK-SAME: ins(%[[SV1]], %[[SV2]] +// CHECK-SAME: outs(%[[SV3]] // TILE1-LABEL: func @reduction // TILE1-DAG: %[[C2:.*]] = constant 2 : index @@ -92,7 +95,8 @@ // TILE1: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0] // TILE1-NOT: subview // TILE1: linalg.generic -// TILE1-SAME: %[[SV1]], %[[SV2]], %{{.*}} +// TILE1-SAME: ins(%[[SV1]], %[[SV2]] +// TILE1-SAME: outs(%{{.*}} // TILE2-LABEL: func @reduction // TILE2-DAG: %[[C2:.*]] = constant 2 : index @@ -105,4 +109,5 @@ // TILE2: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0] // TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]] // TILE2: linalg.generic -// TILE2-SAME: %[[SV1]], %[[SV2]], %[[SV3]] +// TILE2-SAME: ins(%[[SV1]], %[[SV2]] +// TILE2-SAME: outs(%[[SV3]] diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -105,12 +105,14 @@ } func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { - linalg.generic #matmul_trait %A, %B, %C { + linalg.generic #matmul_trait + ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>) + outs(%C : memref<8x32xf32>) { ^bb(%a: f32, %b: f32, %c: f32) : %d = mulf %a, %b: f32 %e = addf %c, %d: f32 linalg.yield %e : f32 - } : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32> + } return } // CHECK-LABEL: func @vectorization_test @@ -122,12 +124,14 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>, %C: memref<8x32xi32>) { - linalg.generic #matmul_trait %A, %B, %C { + linalg.generic #matmul_trait + ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>) + outs(%C : memref<8x32xi32>) { ^bb(%a: i32, %b: i32, %c: i32) : %d = muli %a, %b: i32 %e = addi %c, %d: i32 linalg.yield %e : i32 - } : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> + } return } // CHECK-LABEL: func @vectorization_test_integer @@ -187,23 
+191,24 @@ func @permute_generic(%A: memref, %B: memref, %C: memref) { - linalg.generic #generic_matmul_trait %A, %B, %C { + linalg.generic #generic_matmul_trait + ins(%A, %B : memref, + memref) + outs(%C : memref) { ^bb(%a: f32, %b: f32, %c: f32): %d = mulf %a, %b: f32 %e = addf %c, %d: f32 linalg.yield %e: f32 - }: memref, - memref, - memref + } return } // CHECK-LABEL: func @permute_generic -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]], // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], -// CHECK-SAME: library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}} +// CHECK-SAME: library_call = "linalg_matmul"} // CHECK: memref, -// CHECK-SAME: memref, +// CHECK-SAME: memref // CHECK-SAME: memref #indexed_matmul_trait = { @@ -217,23 +222,24 @@ %A: memref, %B: memref, %C: memref) { - linalg.indexed_generic #indexed_matmul_trait %A, %B, %C { + linalg.indexed_generic #indexed_matmul_trait + ins(%A, %B : memref, + memref) + outs(%C : memref) { ^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32): %d = mulf %a, %b: f32 %e = addf %c, %d: f32 linalg.yield %e: f32 - } : memref, - memref, - memref + } return } // CHECK-LABEL: func @permute_generic_indexed -// CHECK: linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.indexed_generic { // CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]], // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], -// CHECK-SAME: library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} +// CHECK-SAME: library_call = "linalg_matmul_indexed"} // CHECK: memref, -// CHECK-SAME: memref, +// CHECK-SAME: memref // CHECK-SAME: memref func @matvec_perm(%A: memref, diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -886,22 +886,25 @@ // clang-format off // CHECK-LABEL: func @linalg_generic_pointwise -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: ins({{.*}}memref, memref) +// CHECK-SAME: outs({{.*}}memref) // CHECK: addf -// CHECK: }: memref, memref, memref -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: ins({{.*}}memref, memref) +// CHECK-SAME: outs({{.*}}memref) // CHECK: cmpf "ogt" // CHECK: select -// CHECK: }: memref, memref, memref -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: ins({{.*}}memref) +// CHECK-SAME: outs({{.*}}memref) // CHECK: tanh -// CHECK: }: memref, memref // clang-format on TEST_FUNC(linalg_generic_pointwise_test) { using namespace edsc; @@ -929,7 +932,7 @@ // clang-format off // CHECK-LABEL: func @linalg_generic_matmul -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // 
CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} /// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): @@ -958,7 +961,7 @@ // clang-format off // CHECK-LABEL: func @linalg_generic_conv_nhwc -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 * 3 + d4 * 5, d3 * 4 + d5 * 6, d6)>, // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>, // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)>], @@ -992,7 +995,7 @@ // clang-format off // CHECK-LABEL: func @linalg_generic_dilated_conv_nhwc -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3 * 3 + d5 * 5, d4 * 4 + d6 * 6, d2)>, // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>, // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)>], @@ -1053,30 +1056,30 @@ // clang-format off // CHECK-LABEL: func @linalg_tensors -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} // CHECK: addf // CHECK: }: tensor, memref -> tensor -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} // CHECK: cmpf "ogt" // CHECK: select // CHECK: }: tensor, memref -> tensor -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} // CHECK: tanh // CHECK: }: tensor -> tensor -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, // CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>, // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} // CHECK: mulf // CHECK: }: tensor, memref -> tensor -// CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64, +// CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, // CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>, // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>, @@ -1103,10 +1106,15 @@ AffineExpr i, j; bindDims(&globalContext(), i, j); StructuredIndexed SA(A), SB(B), SC(tensorType); - linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j})); - linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j})); - linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j})); - Value o1 = linalg_generic_matmul(A, B, tensorType)->getResult(0); + Value added = linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j})) + ->getResult(0); + Value maxed = linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), + StructuredIndexed(added)({i, j})) + 
->getResult(0); + Value tanhed = linalg_generic_pointwise_tanh(SA({i, j}), + StructuredIndexed(maxed)({i, j})) + ->getResult(0); + Value o1 = linalg_generic_matmul(A, B, tanhed, tensorType)->getResult(0); linalg_generic_matmul(A, B, o1, tensorType); f.print(llvm::outs()); diff --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir --- a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir @@ -19,15 +19,13 @@ func @complex_signature_conversion(%arg0: tensor<5xf32>, %arg1: memref<10xf32>, %arg2: i1, %arg3: f16) -> (i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16) { %0 = alloc() : memref<15xf32> %1 = linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"] - } %arg0 { + iterator_types = ["parallel"]} + ins(%arg0 : tensor<5xf32>) { ^bb0(%gen1_arg0: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: tensor<5xf32> -> tensor<5xf32> + } -> tensor<5xf32> return %arg2, %1, %arg1, %0, %arg3 : i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16 } // CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: f16) diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir --- a/mlir/test/Transforms/buffer-placement-preparation.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation.mlir @@ -17,11 +17,12 @@ // CHECK-LABEL: func @memref_in_function_results func @memref_in_function_results(%arg0: tensor<5xf32>, %arg1: memref<10xf32>) -> (tensor<5xf32>, memref<10xf32>, memref<15xf32>) { %0 = alloc() : memref<15xf32> - %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + %1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg0 : tensor<5xf32>) { ^bb0(%gen1_arg0: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: tensor<5xf32> -> tensor<5xf32> + } -> tensor<5xf32> return %1, %arg1, %0 : tensor<5xf32>, memref<10xf32>, memref<15xf32> } // CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[RESULT:.*]]: memref<5xf32>) @@ -97,23 +98,25 @@ // CHECK-LABEL: func @compute_allocs_position_simple func @compute_allocs_position_simple(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{ - %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + %0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg0 : tensor<2xf32>) { ^bb0(%gen1_arg0: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: tensor<2xf32> -> tensor<2xf32> - %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 { + } -> tensor<2xf32> + %1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%0 : tensor<2xf32>) { ^bb0(%gen2_arg0: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: tensor<2xf32> -> tensor<2xf32> + } -> tensor<2xf32> return %1 : tensor<2xf32> } // CHECK: (%{{.*}}: {{.*}}, %[[ARG0:.*]]: memref<2xf32>, // CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[FIRST_ALLOC]] +// CHECK-NEXT: linalg.generic 
{{.*}} ins(%[[ARG0]]{{.*}} outs(%[[FIRST_ALLOC]] // CHECK: %[[SECOND_ALLOC:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[FIRST_ALLOC]], %[[SECOND_ALLOC]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[FIRST_ALLOC]]{{.*}} outs(%[[SECOND_ALLOC]] // ----- @@ -123,78 +126,86 @@ // CHECK-LABEL: func @compute_allocs_position func @compute_allocs_position(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{ - %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + %0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg0 : tensor<2xf32>) { ^bb0(%gen1_arg0: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: tensor<2xf32> -> tensor<2xf32> - %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 { + } -> tensor<2xf32> + %1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%0 : tensor<2xf32>) { ^bb0(%gen2_arg0: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: tensor<2xf32> -> tensor<2xf32> + } -> tensor<2xf32> cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): - %2 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + %2 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg0 : tensor<2xf32>) { ^bb0(%gen3_arg0: f32): %tmp3 = exp %gen3_arg0 : f32 linalg.yield %tmp3 : f32 - }: tensor<2xf32> -> tensor<2xf32> - %3 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %2 { + } -> tensor<2xf32> + %3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%2 : tensor<2xf32>) { ^bb0(%gen4_arg0: f32): %tmp4 = exp %gen4_arg0 : f32 linalg.yield %tmp4 : f32 - }: tensor<2xf32> -> tensor<2xf32> + } -> tensor<2xf32> br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): - %4 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + %4 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg0 : tensor<2xf32>) { ^bb0(%gen5_arg0: f32): %tmp5 = exp %gen5_arg0 : f32 linalg.yield %tmp5 : f32 - }: tensor<2xf32> -> tensor<2xf32> - %5 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %4 { + } -> tensor<2xf32> + %5 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%4 : tensor<2xf32>) { ^bb0(%gen6_arg0: f32): %tmp6 = exp %gen6_arg0 : f32 linalg.yield %tmp6 : f32 - }: tensor<2xf32> -> tensor<2xf32> + } -> tensor<2xf32> br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>) ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): - %6 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 { + %6 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg0 : tensor<2xf32>) { ^bb0(%gen7_arg0: f32): %tmp7 = exp %gen7_arg0 : f32 linalg.yield %tmp7 : f32 - }: tensor<2xf32> -> tensor<2xf32> - %7 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = 
["parallel"]} %6 { + } -> tensor<2xf32> + %7 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%6 : tensor<2xf32>) { ^bb0(%gen8_arg0: f32): %tmp8 = exp %gen8_arg0 : f32 linalg.yield %tmp8 : f32 - }: tensor<2xf32> -> tensor<2xf32> + } -> tensor<2xf32> return %7 : tensor<2xf32> } // CHECK: (%{{.*}}: {{.*}}, %[[ARG0:.*]]: memref<2xf32>, // CHECK-NEXT: %[[ALLOC0:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC0]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC0]] // CHECK: %[[ALLOC1:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC0]], %[[ALLOC1]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC0]]{{.*}} outs(%[[ALLOC1]] // CHECK: cond_br %{{.*}}, ^[[BB0:.*]]({{.*}}), ^[[BB1:.*]]( // CHECK-NEXT: ^[[BB0]] // CHECK-NEXT: %[[ALLOC2:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC2]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC2]] // CHECK: %[[ALLOC3:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC2]], %[[ALLOC3]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC2]]{{.*}} outs(%[[ALLOC3]] // CHECK: br ^[[EXIT:.*]]({{.*}}) // CHECK-NEXT: ^[[BB1]] // CHECK-NEXT: %[[ALLOC4:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC4]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC4]] // CHECK: %[[ALLOC5:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC4]], %[[ALLOC5]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC4]]{{.*}} outs(%[[ALLOC5]] // CHECK: br ^[[EXIT]] // CHECK-NEXT: ^[[EXIT]] // CHECK-NEXT: %[[ALLOC6:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC6]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC6]] // CHECK: %[[ALLOC7:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC6]], %[[ALLOC7]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC6]]{{.*}} outs(%[[ALLOC7]] // ----- @@ -211,16 +222,12 @@ // CHECK-LABEL: func @callee func @callee(%arg1: tensor<5xf32>) -> tensor<5xf32> { - %0 = linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, - indexing_maps = [#map0, #map0], - iterator_types = ["parallel"] - } %arg1 { + %0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg1 : tensor<5xf32>) { ^bb0(%gen1_arg0: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: tensor<5xf32> -> tensor<5xf32> + } -> tensor<5xf32> return %0 : tensor<5xf32> } // CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>) diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir --- a/mlir/test/Transforms/buffer-placement.mlir +++ b/mlir/test/Transforms/buffer-placement.mlir @@ -24,14 +24,14 @@ ^bb2: %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %0 { + iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () @@ -73,14 +73,14 @@ ^bb2(%0: index): %1 = alloc(%0) : memref linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %1 { + iterator_types = ["parallel"]} + 
ins(%arg1: memref) + outs(%1: memref) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref, memref + } br ^bb3(%1 : memref) ^bb3(%2: memref): "linalg.copy"(%2, %arg2) : (memref, memref) -> () @@ -141,14 +141,14 @@ ^bb2(%0: index): %1 = alloc(%0) : memref linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %1 { + iterator_types = ["parallel"]} + ins(%arg1: memref) + outs(%1: memref) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref, memref + } cond_br %arg0, ^bb3, ^bb4 ^bb3: br ^bb5(%1 : memref) @@ -224,14 +224,14 @@ ^bb1: %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %0 { + iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } br ^bb2(%0 : memref<2xf32>) ^bb2(%1: memref<2xf32>): "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () @@ -262,14 +262,14 @@ func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %0 { + iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>) ^bb1: br ^bb2(%0 : memref<2xf32>) @@ -300,14 +300,14 @@ func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %0 { + iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } cond_br %arg0, ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) @@ -318,14 +318,14 @@ ^bb3(%5: memref<2xf32>, %6: memref<2xf32>): %7 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %5, %7 { + iterator_types = ["parallel"]} + ins(%5: memref<2xf32>) + outs(%7: memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: memref<2xf32>, memref<2xf32> + } "linalg.copy"(%7, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return } @@ -357,14 +357,14 @@ func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %0 { + iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } cond_br %arg0, ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) @@ -401,14 +401,14 @@ func @ifElseNested(%arg0: i1, 
%arg1: memref<2xf32>, %arg2: memref<2xf32>) { %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %0 { + iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } cond_br %arg0, ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) @@ -423,14 +423,14 @@ ^bb5(%7: memref<2xf32>, %8: memref<2xf32>): %9 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %7, %9 { + iterator_types = ["parallel"]} + ins(%7: memref<2xf32>) + outs(%9: memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: memref<2xf32>, memref<2xf32> + } "linalg.copy"(%9, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return } @@ -456,32 +456,32 @@ func @redundantOperations(%arg0: memref<2xf32>) { %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg0, %0 { + iterator_types = ["parallel"]} + ins(%arg0: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } %1 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %0, %1 { + iterator_types = ["parallel"]} + ins(%0: memref<2xf32>) + outs(%1: memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: memref<2xf32>, memref<2xf32> + } return } // CHECK: (%[[ARG0:.*]]: {{.*}}) // CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[FIRST_ALLOC]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}}outs(%[[FIRST_ALLOC]] // CHECK: %[[SECOND_ALLOC:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{.*}} %[[FIRST_ALLOC]], %[[SECOND_ALLOC]] +// CHECK-NEXT: linalg.generic {{.*}} ins(%[[FIRST_ALLOC]]{{.*}}outs(%[[SECOND_ALLOC]] // CHECK: dealloc // CHECK-NEXT: dealloc // CHECK-NEXT: return @@ -509,26 +509,26 @@ ^bb1: %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg0, %0 { + iterator_types = ["parallel"]} + ins(%arg0: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } br ^exit(%0 : memref<2xf32>) ^bb2: %1 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg0, %1 { + iterator_types = ["parallel"]} + ins(%arg0: memref<2xf32>) + outs(%1: memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: memref<2xf32>, memref<2xf32> + } br ^exit(%1 : memref<2xf32>) ^exit(%arg2: memref<2xf32>): "linalg.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () @@ -567,14 +567,14 @@ ^bb2: %1 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg0, %1 
{ + iterator_types = ["parallel"]} + ins(%arg0: memref<2xf32>) + outs(%1: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } dealloc %1 : memref<2xf32> br ^exit(%1 : memref<2xf32>) ^exit(%arg2: memref<2xf32>): @@ -599,14 +599,14 @@ %arg1: memref<2xf32>) { %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg0, %0 { + iterator_types = ["parallel"]} + ins(%arg0: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } "linalg.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () return } @@ -625,14 +625,14 @@ func @moving_invalid_dealloc_op(%arg0 : memref<2xf32>, %arg1: memref<2xf32>) { %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg0, %0 { + iterator_types = ["parallel"]} + ins(%arg0: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } dealloc %0 : memref<2xf32> "linalg.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () return @@ -659,17 +659,21 @@ br ^bb3(%arg1 : memref<2xf32>) ^bb2: %0 = alloc() : memref<2xf32> - linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 { + linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %1 = alloc() : memref<2xf32> - linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %1 { + linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%1: memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: memref<2xf32>, memref<2xf32> + } %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () @@ -679,9 +683,9 @@ // CHECK-NEXT: %[[GENERIC1_ALLOC:.*]] = alloc() // CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]] // CHECK: ^[[BB2]]: -// CHECK-NEXT: linalg.generic {{{.*}}} %[[ARG1]], %[[GENERIC1_ALLOC]] +// CHECK-NEXT: linalg.generic {{{.*}}} ins(%[[ARG1]]{{.*}}outs(%[[GENERIC1_ALLOC]] // CHECK: %[[GENERIC2_ALLOC:.*]] = alloc() -// CHECK-NEXT: linalg.generic {{{.*}}} %[[ARG1]], %[[GENERIC2_ALLOC]] +// CHECK-NEXT: linalg.generic {{{.*}}} ins(%[[ARG1]]{{.*}}outs(%[[GENERIC2_ALLOC]] // CHECK: dealloc %[[GENERIC2_ALLOC]] // CHECK-NEXT: %{{.*}} = exp // CHECK: ^[[BB3:.*]]({{.*}}): @@ -701,11 +705,13 @@ func @memref_in_function_results(%arg0: memref<5xf32>, %arg1: memref<10xf32>, %arg2: memref<5xf32>) -> (memref<10xf32>, memref<15xf32>) { %x = alloc() : memref<15xf32> %y = alloc() : memref<5xf32> - linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %y { + linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg0: memref<5xf32>) + outs(%y: memref<5xf32>) { ^bb0(%arg3: f32, %arg4: f32): %2 
= exp %arg3 : f32 linalg.yield %2 : f32 - }: memref<5xf32>, memref<5xf32> + } linalg.copy(%y, %arg2) : memref<5xf32>, memref<5xf32> return %arg1, %x : memref<10xf32>, memref<15xf32> } @@ -946,14 +952,14 @@ ^bb2: %0 = alloca() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %0 { + iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () @@ -975,14 +981,14 @@ func @ifElseAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { %0 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %0 { + iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } cond_br %arg0, ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) @@ -993,14 +999,14 @@ ^bb3(%5: memref<2xf32>, %6: memref<2xf32>): %7 = alloca() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %5, %7 { + iterator_types = ["parallel"]} + ins(%5: memref<2xf32>) + outs(%7: memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: memref<2xf32>, memref<2xf32> + } "linalg.copy"(%7, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return } @@ -1021,14 +1027,14 @@ func @ifElseNestedAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { %0 = alloca() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg1, %0 { + iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } cond_br %arg0, ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) @@ -1043,14 +1049,14 @@ ^bb5(%7: memref<2xf32>, %8: memref<2xf32>): %9 = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %7, %9 { + iterator_types = ["parallel"]} + ins(%7: memref<2xf32>) + outs(%9: memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: memref<2xf32>, memref<2xf32> + } "linalg.copy"(%9, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return } @@ -1074,17 +1080,21 @@ br ^bb3(%arg1 : memref<2xf32>) ^bb2: %0 = alloc() : memref<2xf32> - linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 { + linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%0: memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %1 = alloca() : memref<2xf32> - linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %1 { + linalg.generic {indexing_maps = 
[#map0, #map0], iterator_types = ["parallel"]} + ins(%arg1: memref<2xf32>) + outs(%1: memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: memref<2xf32>, memref<2xf32> + } %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () @@ -1094,9 +1104,9 @@ // CHECK-NEXT: %[[ALLOC:.*]] = alloc() // CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]] // CHECK: ^[[BB2]]: -// CHECK-NEXT: linalg.generic {{{.*}}} %[[ARG1]], %[[ALLOC]] +// CHECK-NEXT: linalg.generic {{{.*}}} ins(%[[ARG1]]{{.*}}outs(%[[ALLOC]] // CHECK: %[[ALLOCA:.*]] = alloca() -// CHECK-NEXT: linalg.generic {{{.*}}} %[[ARG1]], %[[ALLOCA]] +// CHECK-NEXT: linalg.generic {{{.*}}} ins(%[[ARG1]]{{.*}}outs(%[[ALLOCA]] // CHECK: %{{.*}} = exp // CHECK: ^[[BB3:.*]]({{.*}}): // CHECK: linalg.copy diff --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir --- a/mlir/test/Transforms/copy-removal.mlir +++ b/mlir/test/Transforms/copy-removal.mlir @@ -157,14 +157,14 @@ %temp = alloc() : memref<5xf32> linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %temp, %res { + iterator_types = ["parallel"]} + ins(%temp : memref<5xf32>) + outs(%res : memref<5xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<5xf32>, memref<5xf32> + } dealloc %ret : memref<5xf32> return %temp : memref<5xf32> } @@ -231,18 +231,18 @@ // CHECK-NOT: %{{.*}} = alloc %temp = alloc() : memref<2xf32> // CHECK-NEXT: linalg.generic - // CHECK-SAME: %[[ARG0]], %[[RES]] + // CHECK-SAME: ins(%[[ARG0]]{{.*}}outs(%[[RES]] // CHECK-NOT: linalg.copy(%{{.*}}, %[[RES]]) // CHECK-NOT: dealloc %{{.*}} linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg0, %temp { + iterator_types = ["parallel"]} + ins(%arg0 : memref<2xf32>) + outs(%temp : memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: memref<2xf32>, memref<2xf32> + } "linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> () dealloc %temp : memref<2xf32> // CHECK: return @@ -261,23 +261,23 @@ %to = alloc() : memref<2xf32> %temp = alloc() : memref<2xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg0, %temp { + iterator_types = ["parallel"]} + ins(%arg0 : memref<2xf32>) + outs(%temp : memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 - }: memref<2xf32>, memref<2xf32> + } linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} %arg0, %to { + iterator_types = ["parallel"]} + ins(%arg0 : memref<2xf32>) + outs(%to : memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 - }: memref<2xf32>, memref<2xf32> + } // CHECK: linalg.copy "linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> () dealloc %temp : memref<2xf32> diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ 
b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -39,6 +39,9 @@ /// Converts tensor-type generic linalg operations to memref ones using /// buffer assignment. + /// TODO: Avoid the copy-pasta by exposing the pattern from BufferPlacement.h + /// This probably requires an OpConversionPattern working on generic + /// Operation*. For now only RewritePattern allow this. class GenericOpConverter : public BufferAssignmentOpConversionPattern { public: @@ -48,34 +51,47 @@ LogicalResult matchAndRewrite(linalg::GenericOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { + linalg::GenericOpAdaptor adaptor(operands, + op.getOperation()->getAttrDictionary()); + + // TODO: support ops with reduction. + if (!op.init_tensors().empty()) + return failure(); + + // All inputs need to be turned into buffers first. Until then, bail out. + if (llvm::any_of(adaptor.inputs(), [](Value in) { + return !in.getType().isa(); + })) + return failure(); + Location loc = op.getLoc(); - ResultRange results = op.getOperation()->getResults(); - SmallVector newArgs, newResults; - newArgs.reserve(operands.size() + results.size()); - newArgs.append(operands.begin(), operands.end()); - newResults.reserve(results.size()); + SmallVector outputBuffers, newOutputBuffers; + outputBuffers.assign(adaptor.output_buffers().begin(), + adaptor.output_buffers().end()); + newOutputBuffers.reserve(op.getNumOutputs()); + newOutputBuffers.append(adaptor.output_buffers().begin(), + adaptor.output_buffers().end()); // Update all types to memref types. - for (auto result : results) { - ShapedType type = result.getType().cast(); - assert(type && "Generic operations with non-shaped typed results are " - "not currently supported."); + for (Type t : op.getResultTypes()) { + auto type = t.cast(); if (!type.hasStaticShape()) return rewriter.notifyMatchFailure( op, "dynamic shapes not currently supported"); auto memrefType = MemRefType::get(type.getShape(), type.getElementType()); auto alloc = rewriter.create(loc, memrefType); - newArgs.push_back(alloc); - newResults.push_back(alloc); + newOutputBuffers.push_back(alloc); } // Generate a new linalg operation that works on buffers. auto linalgOp = rewriter.create( - loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()), - rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(), - op.iterator_types(), op.docAttr(), op.library_callAttr(), - op.symbol_sourceAttr()); + loc, + /*resultTensorTypes=*/ArrayRef{}, + /*inputs=*/adaptor.inputs(), + /*outputBuffers=*/newOutputBuffers, + /*initTensors=*/ValueRange{}, op.indexing_maps(), op.iterator_types(), + op.docAttr(), op.library_callAttr(), op.symbol_sourceAttr()); // Create a new block in the region of the new Generic Op. Block &oldBlock = op.getRegion().front(); @@ -83,23 +99,24 @@ Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(), oldBlock.getArgumentTypes()); - // Map the old block arguments to the new ones. - BlockAndValueMapping mapping; - mapping.map(oldBlock.getArguments(), newBlock->getArguments()); - // Add the result arguments to the new block. - for (auto result : newResults) - newBlock->addArgument( - result.getType().cast().getElementType()); + for (Value v : newOutputBuffers) + newBlock->addArgument(v.getType().cast().getElementType()); // Clone the body of the old block to the new block. 
+ BlockAndValueMapping mapping; + for (unsigned i = 0; i < oldBlock.getNumArguments(); i++) + mapping.map(oldBlock.getArgument(i), newBlock->getArgument(i)); + + OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToEnd(newBlock); - for (auto &op : oldBlock.getOperations()) - rewriter.clone(op, mapping); + for (auto &op : oldBlock.getOperations()) { + Operation *clonedOp = rewriter.clone(op, mapping); + mapping.map(op.getResults(), clonedOp->getResults()); + } - // Replace the results of the old Generic Op with the results of the new - // one. - rewriter.replaceOp(op, newResults); + // Replace the results of the old op with the new output buffers. + rewriter.replaceOp(op, newOutputBuffers); return success(); } }; diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1449,7 +1449,7 @@ let arguments = (ins Variadic:$inputs, Variadic:$output_buffers, Variadic:$init_tensors); - let results = (outs Variadic:$output_tensors); + let results = (outs Variadic:$result_tensors); let regions = (region AnyRegion:$region); let builders = [ OpBuilder<
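For reference, a minimal standalone sketch of the ins/outs assembly form that the tests above migrate to. This is illustrative only and is not taken from the patch: the function names, #map, and the 4xf32 shapes are made up. The buffer form writes through outs and produces no SSA results, while the tensor form lists no output buffer and carries its result type after ->.

```
#map = affine_map<(d0) -> (d0)>

// Buffer form: one region argument per input and per output buffer.
func @pointwise_exp_buffers(%in: memref<4xf32>, %out: memref<4xf32>) {
  linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
    ins(%in : memref<4xf32>)
    outs(%out : memref<4xf32>) {
  ^bb0(%a: f32, %b: f32):
    %0 = exp %a : f32
    linalg.yield %0 : f32
  }
  return
}

// Tensor form: the result is an SSA value typed after `->`.
func @pointwise_exp_tensors(%in: tensor<4xf32>) -> tensor<4xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
    ins(%in : tensor<4xf32>) {
  ^bb0(%a: f32):
    %1 = exp %a : f32
    linalg.yield %1 : f32
  } -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
```

The linalg.indexed_generic variant used in several of the tests is the same shape, except that the region takes one leading index argument per loop dimension before the element arguments.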