diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -35,38 +35,9 @@
 // interface.
 def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let methods = [
-    InterfaceMethod<
-      "Query the number of inputs from the current operation.",
-      "unsigned", "getNumInputs"
-    >,
-    InterfaceMethod<
-      "Query the number of outputs from the current operation.",
-      "unsigned", "getNumOutputs"
-    >,
-    InterfaceMethod<
-      "Query the number of inputs and outputs from the current operation.",
-      "unsigned", "getNumInputsAndOutputs"
-    >,
-    InterfaceMethod<
-      "Query the input operands from the current operation.",
-      "Operation::operand_range", "getInputs"
-    >,
-    InterfaceMethod<
-      "Query the output operands from the current operation.",
-      "Operation::operand_range", "getOutputs"
-    >,
-    InterfaceMethod<
-      "Query the input and output operands from the current operation.",
-      "Operation::operand_range", "getInputsAndOutputs"
-    >,
-    InterfaceMethod<
-      "Query the iterator types attribute within the current operation.",
-      "ArrayAttr", "iterator_types"
-    >,
-    InterfaceMethod<
-      "Query the indexing maps attribute within the current operation.",
-      "ArrayAttr", "indexing_maps"
-    >,
+    //========================================================================//
+    // Loop types handling.
+    //========================================================================//
     InterfaceMethod<
       "Query the number of parallel loops within the current operation.",
       "unsigned", "getNumParallelLoops"
@@ -82,40 +53,98 @@
     InterfaceMethod<
      "Query the number of loops within the current operation.",
      "unsigned", "getNumLoops">,
+
+    //========================================================================//
+    // Input arguments handling.
+    //========================================================================//
+    InterfaceMethod<
+      "Query the number of inputs from the current operation.",
+      "unsigned", "getNumInputs"
+    >,
     InterfaceMethod<"Query the input view at the given index.",
       "Value ", "getInput", (ins "unsigned":$i)
     >,
-    InterfaceMethod<"Query the output view at the given index.",
-      "Value ", "getOutput", (ins "unsigned":$i)
-    >,
     InterfaceMethod<[{
       Return the index of the given input value `v`, or `None` if the value is
       not an input.
    }],
      "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$v)
    >,
-    InterfaceMethod<[{
-      Query the index of the given view value, or `None` if the value is not
-      a view.
-    }],
-      "llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value ":$view)
+    InterfaceMethod<
+      "Query the input operands from the current operation.",
+      "Operation::operand_range", "getInputs"
     >,
     InterfaceMethod<[{
       Query the type of the input shape at the given index.
    }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
-    InterfaceMethod<[{
-      Query the type of the output view at the given index.
-    }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
-    InterfaceMethod<[{
-      Query whether the op has only MemRef input and outputs.
-    }], "bool", "hasBufferSemantics">,
     InterfaceMethod<[{
       Query the subset of input operands that are of ranked tensor type.
     }], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
+
+
+    //========================================================================//
+    // Output arguments handling.
+    //========================================================================//
+    InterfaceMethod<
+      "Query the number of outputs from the current operation.",
+      "unsigned", "getNumOutputs"
+    >,
+    InterfaceMethod<"Query the output buffer at the given index.",
+      "Value ", "getOutputBuffer", (ins "unsigned":$i)
+    >,
     InterfaceMethod<[{
-      Query the subset of output operands that are of ranked tensor type.
+      Query the index of the given buffer value, or `None` if the value is not
+      part of the output buffers.
+    }],
+      "llvm::Optional<unsigned>", "getIndexOfOutputBuffer", (ins "Value ":$view)
+    >,
+    InterfaceMethod<[{
+      Query the type of the output buffer at the given index.
+    }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>,
+    InterfaceMethod<[{
+      Query the results that are of ranked tensor type.
     }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,
+    InterfaceMethod<
+      "Query the output buffers (operands) from the current operation.",
+      "Operation::operand_range", "getOutputBuffers"
+    >,
+    //========================================================================//
+    // Input and Output arguments handling.
+    //========================================================================//
+    InterfaceMethod<
+      "Return the number of inputs and outputs, irrespective of their buffer "
+      "or tensor type.",
+      "unsigned", "getNumInputsAndOutputs"
+    >,
+    InterfaceMethod<
+      "Return the number of inputs, irrespective of their buffer or tensor "
+      "type, plus the number of output buffers.",
+      "unsigned", "getNumInputsAndOutputBuffers"
+    >,
+    InterfaceMethod<
+      "Return the range over inputs (irrespective of type) and output buffers.",
+      "Operation::operand_range", "getInputsAndOutputBuffers"
+    >,
+
+    //========================================================================//
+    // Other interface methods.
+    //========================================================================//
+    InterfaceMethod<
+      "Query the iterator types attribute within the current operation.",
+      "ArrayAttr", "iterator_types"
+    >,
+    InterfaceMethod<
+      "Query the indexing maps attribute within the current operation.",
+      "ArrayAttr", "indexing_maps"
+    >,
+    InterfaceMethod<[{
+      Query whether the op has only MemRef input and outputs.
+    }], "bool", "hasBufferSemantics">,
+
+    //========================================================================//
+    // Other static interface methods.
+    //========================================================================//
     StaticInterfaceMethod<[{
       Create an operation of the current type with the given location,
       operands, and attributes.
@@ -128,9 +157,6 @@
         attributes);
       }]
     >,
-
-    /// Clone an operation with the given location and operands. This is used to
-    /// abstract away the optional underlying region creation.
     InterfaceMethod<[{
       Clone the current operation with the given location and operands. This
      is used to abstract away the optional underlying region creation.
@@ -543,15 +569,20 @@
        -> (tensor<?x?xf32>)
    ```

-    In this case, the number of return values must match the number of output
-    tensor arguments. The semantics is that the `linalg.generic` op
-    produces (i.e. allocates and fills) its return values.
+    In this case, the number of outputs (args_out) must match the sum of (1) the
+    number of output buffer operands and (2) the number of tensor return values.
+    The semantics is that the `linalg.generic` op produces (i.e.
+    allocates and fills) its return values.
+
     Tensor values must be legalized by a buffer allocation pass before most
-    transformations can be applied. In particular, transformations that create
-    control flow around linalg.generic operations are not expected to mix with
-    tensors because SSA values do not escape naturally. Still, transformations
-    and rewrites that take advantage of tensor SSA values are expected to be
-    useful and will be added in the near future.
+    transformations can be applied. Such legalization moves tensor return values
+    into output buffer operands and updates the region arguments accordingly.
+
+    Transformations that create control flow around linalg.generic operations
+    are not expected to mix with tensors because SSA values do not escape
+    naturally. Still, transformations and rewrites that take advantage of
+    tensor SSA values are expected to be useful and will be added in the near
+    future.
  }];
  let verifier = [{ return ::verify(*this); }];
}
@@ -666,15 +697,20 @@
        -> (tensor<?x?xf32>)
    ```

-    In this case, the number of return values must match the number of output
-    tensor arguments. The semantics is that the `linalg.indexed_generic` op
-    produces (i.e. allocates and fills) its return values.
+    In this case, the number of outputs (args_out) must match the sum of (1) the
+    number of output buffer operands and (2) the number of tensor return values.
+    The semantics is that the `linalg.indexed_generic` op produces (i.e.
+    allocates and fills) its return values.
+
     Tensor values must be legalized by a buffer allocation pass before most
-    transformations can be applied. In particular, transformations that create
-    control flow around linalg.generic operations are not expected to mix with
-    tensors because SSA values do not escape naturally. Still, transformations
-    and rewrites that take advantage of tensor SSA values are expected to be
-    useful and will be added in the near future.
+    transformations can be applied. Such legalization moves tensor return values
+    into output buffer operands and updates the region arguments accordingly.
+
+    Transformations that create control flow around linalg.indexed_generic
+    operations are not expected to mix with tensors because SSA values do not
+    escape naturally. Still, transformations and rewrites that take advantage of
+    tensor SSA values are expected to be useful and will be added in the near
+    future.
  }];
  let verifier = [{ return ::verify(*this); }];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -58,16 +58,47 @@
 class StructuredOpTraits
     : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
 private:
-  /// Return the number of inputs. For internal use only.
+  /// Return the number of inputs, irrespective of their buffer or tensor type.
+  /// For internal use only.
   unsigned nInputs() {
     return cast<ConcreteType>(this->getOperation()).getNumInputs();
   }
-  /// Return the number of outputs. For internal use only.
+  /// Return the number of outputs, irrespective of their buffer or tensor type.
+  /// For internal use only.
   unsigned nOutputs() {
     return cast<ConcreteType>(this->getOperation()).getNumOutputs();
  }

 public:
+  //==========================================================================//
+  // Loop types handling.
+  //==========================================================================//
+  unsigned getNumParallelLoops() {
+    return getNumIterators(
+        getParallelIteratorTypeName(),
+        cast<ConcreteType>(this->getOperation()).iterator_types());
+  }
+  unsigned getNumReductionLoops() {
+    return getNumIterators(
+        getReductionIteratorTypeName(),
+        cast<ConcreteType>(this->getOperation()).iterator_types());
+  }
+  unsigned getNumWindowLoops() {
+    return getNumIterators(
+        getWindowIteratorTypeName(),
+        cast<ConcreteType>(this->getOperation()).iterator_types());
+  }
+  unsigned getNumLoops() {
+    return getNumIterators(
+        cast<ConcreteType>(this->getOperation()).iterator_types());
+  }
+
+  //==========================================================================//
+  // Input arguments handling.
+  //==========================================================================//
+  // The `i^th` input argument is always the `i^th` operand regardless of
+  // whether we have tensors or buffers.
+  //
   /// Return the `i`-th input value.
   Value getInput(unsigned i) {
     assert(i < nInputs());
@@ -90,83 +121,107 @@
     auto range = this->getOperation()->getOperands();
     return {range.begin(), range.begin() + nInputs()};
   }
-  /// Return the `i`-th output.
-  Value getOutput(unsigned i) {
+  /// Query the subset of input operands that are of ranked tensor type.
+  SmallVector<RankedTensorType, 4> getInputTensorTypes() {
+    SmallVector<RankedTensorType, 4> res;
+    for (Type type : getInputs().getTypes())
+      if (auto t = type.template dyn_cast<RankedTensorType>())
+        res.push_back(t);
+    return res;
+  }
+
+  //==========================================================================//
+  // Output arguments handling.
+  //==========================================================================//
+  // The `i^th` output argument is either an operand or a result, depending on
+  // whether the corresponding output has buffer (operand) or tensor (result)
+  // type.
+  //
+  /// Return the `i`-th output buffer; asserts that this output is a buffer
+  /// operand and not a tensor result.
+  Value getOutputBuffer(unsigned i) {
+    assert(i + this->getOperation()->getNumResults() < nOutputs() &&
+           "overflowing output buffer index");
     return this->getOperation()->getOperand(nInputs() + i);
   }
   /// Return the index of `value` in the list of output values if found,
   /// llvm::None otherwise.
-  Optional<unsigned> getIndexOfOutput(Value value) {
-    auto it = llvm::find(getOutputs(), value);
-    if (it != getOutputs().end())
-      return it - getOutputs().begin();
+  Optional<unsigned> getIndexOfOutputBuffer(Value value) {
+    auto it = llvm::find(getOutputBuffers(), value);
+    if (it != getOutputBuffers().end())
+      return it - getOutputBuffers().begin();
     return llvm::None;
   }
   /// Return the `i`-th output buffer type.
-  ShapedType getOutputShapedType(unsigned i) {
-    return getOutput(i)->getType().template cast<ShapedType>();
-  }
-  /// Query whether the op has only MemRef input and outputs.
-  bool hasBufferSemantics() {
-    return this->getOperation()->getNumResults() == 0 &&
-           llvm::all_of(getInputsAndOutputs(),
-                        [](Value v) { return v.getType().isa<MemRefType>(); });
+  MemRefType getOutputBufferType(unsigned i) {
+    return getOutputBuffer(i)->getType().template cast<MemRefType>();
   }
-  /// Query the subset of input operands that are of ranked tensor type.
-  SmallVector<RankedTensorType, 4> getInputTensorTypes() {
-    SmallVector<RankedTensorType, 4> res;
-    for (Type type : getInputs().getTypes())
-      if (auto t = type.template dyn_cast<RankedTensorType>())
-        res.push_back(t);
-    return res;
+  /// Return the `i`-th output shaped type, irrespective of buffer or tensor
+  /// type.
+  ShapedType getOutputShapedType(unsigned i) {
+    return getShapedType(i + nInputs());
   }
-  /// Query the subset of output operands that are of ranked tensor type.
+  /// Query the subset of results that are of ranked tensor type.
   SmallVector<RankedTensorType, 4> getOutputTensorTypes() {
     SmallVector<RankedTensorType, 4> res;
-    for (Type type : getOutputs().getTypes())
-      if (auto t = type.template dyn_cast<RankedTensorType>())
-        res.push_back(t);
+    for (Type type : this->getOperation()->getResults().getTypes())
+      res.push_back(type.template cast<RankedTensorType>());
     return res;
   }
   /// Return the range over outputs.
-  Operation::operand_range getOutputs() {
+  Operation::operand_range getOutputBuffers() {
     auto range = this->getOperation()->getOperands();
     return {range.begin() + nInputs(),
-            range.begin() + getNumInputsAndOutputs()};
+            range.begin() + getNumInputsAndOutputBuffers()};
   }
-  /// Return the number of inputs and outputs.
+
+  //==========================================================================//
+  // Input and Output arguments handling.
+  //==========================================================================//
+  /// Return the number of inputs and outputs, irrespective of their buffer or
+  /// tensor type.
   unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
-  /// Return the `i`-th buffer type.
-  ShapedType getShapedType(unsigned i) {
-    return (i < nInputs()) ? getInputShapedType(i)
-                           : getOutputShapedType(i - nInputs());
-  }
-  /// Return the range over inputs and outputs.
-  Operation::operand_range getInputsAndOutputs() {
+  /// Return the number of inputs, irrespective of their buffer or tensor type,
+  /// plus the number of output buffers.
+  unsigned getNumInputsAndOutputBuffers() {
+    assert(this->getOperation()->getNumResults() <= nInputs() + nOutputs());
+    return nInputs() + nOutputs() - this->getOperation()->getNumResults();
+  }
+  /// Return the range over inputs (irrespective of type) and output buffers.
+  Operation::operand_range getInputsAndOutputBuffers() {
    auto range = this->getOperation()->getOperands();
-    return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+    return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
  }
-  unsigned getNumParallelLoops() {
-    return getNumIterators(
-        getParallelIteratorTypeName(),
-        cast<ConcreteType>(this->getOperation()).iterator_types());
-  }
-  unsigned getNumReductionLoops() {
-    return getNumIterators(
-        getReductionIteratorTypeName(),
-        cast<ConcreteType>(this->getOperation()).iterator_types());
-  }
-  unsigned getNumWindowLoops() {
-    return getNumIterators(
-        getWindowIteratorTypeName(),
-        cast<ConcreteType>(this->getOperation()).iterator_types());
+  /// Return the `i`-th shaped type; there are 3 cases:
+  ///   1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise
+  ///   2. if `i < getNumInputsAndOutputBuffers()` then return the
+  ///      `getOutputBufferType(i - nInputs())`; otherwise
+  ///   3. return the `i - getNumInputsAndOutputBuffers()` result type.
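+  /// For example, with 2 inputs, 1 output buffer and 1 tensor result (i.e.
+  /// `nInputs()` == 2 and `nOutputs()` == 2): `i` == 0, 1 return the input
+  /// shaped types, `i` == 2 returns the output buffer MemRefType and `i` == 3
+  /// returns the RankedTensorType of the single result.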
+  ShapedType getShapedType(unsigned i) {
+    if (i < nInputs())
+      return getInputShapedType(i);
+    if (i < getNumInputsAndOutputBuffers())
+      return getOutputBufferType(i - nInputs()).template cast<ShapedType>();
+    return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()]
+        .template cast<ShapedType>();
   }
-  unsigned getNumLoops() {
-    return getNumIterators(
-        cast<ConcreteType>(this->getOperation()).iterator_types());
+
+  //==========================================================================//
+  // Other interface methods.
+  //==========================================================================//
+  /// Query whether the op has only buffer inputs and no returns.
+  bool hasBufferSemantics() {
+    return this->getOperation()->getNumResults() == 0 &&
+           llvm::all_of(getInputs(),
+                        [](Value v) { return v.getType().isa<MemRefType>(); });
   }
+
+  //==========================================================================//
+  // Other static interface methods.
+  //==========================================================================//
   static LogicalResult verifyTrait(Operation *op) {
-    auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputs();
+    auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputBuffers();
     if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
       return failure();
     return success();
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -119,7 +119,7 @@
 template <typename ConcreteOp>
 SmallVector<Value, 8> getViewSizes(ConcreteOp linalgOp) {
   SmallVector<Value, 8> res;
-  for (auto v : linalgOp.getInputsAndOutputs()) {
+  for (auto v : linalgOp.getInputsAndOutputBuffers()) {
     MemRefType t = v->getType().template cast<MemRefType>();
     for (unsigned i = 0; i < t.getRank(); ++i)
       res.push_back(edsc::intrinsics::dim(v, i));
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -139,7 +139,7 @@
 }

 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
-  for (auto srcView : src.getOutputs()) { // W
+  for (auto srcView : src.getOutputBuffers()) { // W
     // RAW graph
     for (auto dstView : dst.getInputs()) {   // R
       if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
@@ -149,7 +149,7 @@
       }
     }
     // WAW graph
-    for (auto dstView : dst.getOutputs()) {  // W
+    for (auto dstView : dst.getOutputBuffers()) { // W
       if (aliases.alias(srcView, dstView)) { // if alias, fill WAW
         addDependenceElem(DependenceType::WAW,
                           LinalgOpView{src.getOperation(), srcView},
@@ -167,7 +167,7 @@
       }
     }
     // WAR graph
-    for (auto dstView : dst.getOutputs()) {  // W
+    for (auto dstView : dst.getOutputBuffers()) { // W
       if (aliases.alias(srcView, dstView)) { // if alias, fill WAR
         addDependenceElem(DependenceType::WAR,
                           LinalgOpView{src.getOperation(), srcView},
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -112,19 +112,21 @@
 LogicalResult verifyBlockArgs(GenericOpType op, Block &block);

 template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
-  auto nViews = op.getNumInputsAndOutputs();
-  auto nInputViews = op.getNumInputs();
-  if (block.getNumArguments() != nViews)
-    return op.emitOpError(
-        "expected number of block arguments to match number of views");
+  auto nOperands = op.getNumOperands();
+  if (block.getNumArguments() != nOperands)
+    return op.emitOpError("expected number of block arguments to match number "
+                          "of operands");

-  for (unsigned i = 0; i < nViews; ++i) {
+  // Note: the number and type of yield values are checked in the YieldOp.
+
+  auto nInputViews = op.getNumInputs();
+  for (unsigned i = 0; i < nOperands; ++i) {
     auto viewType = op.getShapedType(i);
     if (viewType.getElementType() != block.getArgument(i)->getType())
       return op.emitOpError("expected block argument ")
              << i << " of the same type as elemental type of "
              << ((i < nInputViews) ? "input " : "output ")
-             << "view: " << viewType;
+             << "operand: " << viewType;
   }
   return success();
 }
@@ -132,19 +134,21 @@
 template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
   auto nInputViews = op.getNumInputs();
   auto nLoops = op.getNumLoops();
-  auto nViews = op.getNumInputsAndOutputs();
-  if (block.getNumArguments() != nViews + nLoops)
+  auto nOperands = op.getNumOperands();
+  if (block.getNumArguments() != nOperands + nLoops)
     return op.emitOpError(
-        "expected number of block arguments to match number of views + "
+        "expected number of block arguments to match number of operands + "
         "number of loops");

+  // Note: the number and type of yield values are checked in the YieldOp.
+
   for (unsigned i = 0; i < nLoops; ++i) {
     if (!block.getArgument(i)->getType().isIndex())
       return op.emitOpError("expected block argument ")
-             << i << " to be of IndexType";
+             << i << " to be an index";
   }

-  for (unsigned i = 0; i < nViews; ++i) {
+  for (unsigned i = 0; i < nOperands; ++i) {
     unsigned memrefArgIndex = i + nLoops;
     auto viewType = op.getShapedType(i);
     if (viewType.getElementType() !=
@@ -152,7 +156,7 @@
       return op.emitOpError("expected block argument ")
              << memrefArgIndex << " of the same type as elemental type of "
              << ((i < nInputViews) ? "input " : "output ")
-             << "view: " << viewType;
+             << "operand: " << viewType;
   }
   return success();
 }
@@ -161,30 +165,33 @@
 LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);

 template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
-  auto nViews = op.getNumInputsAndOutputs();
-  auto nInputViews = op.getNumInputs();
-  if (funType.getNumInputs() != nViews)
-    return op.emitOpError("expected fun arguments to match number of views");
+  auto nOperands = op.getNumOperands();
+  if (funType.getNumInputs() != nOperands)
+    return op.emitOpError("expected fun arguments to match number of operands");
   if (funType.getNumResults() != op.getNumOutputs())
-    return op.emitOpError(
-        "expected fun results to match number of output views");
-
-  for (auto en : llvm::enumerate(op.indexing_maps())) {
-    auto idx = en.index();
-    auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
-                                    : op.getOutputShapedType(idx - nInputViews);
-    if (funType.getInput(idx) != view.getElementType())
+    return op.emitOpError("expected fun results(")
+           << funType.getNumResults() << ") to match number of outputs("
+           << op.getNumOutputs() << ")";
+
+  auto nInputs = op.getNumInputs();
+  // The element types of the linalg.generic operands must exactly match the
+  // function argument types, in order.
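+  // For example, for operand types (memref<?xf32>, memref<?xi32>) the fun
+  // must take (f32, i32) as its arguments, in that order.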
+  for (unsigned idx = 0; idx < nOperands; ++idx) {
+    ShapedType shapedType = op.getShapedType(idx);
+    if (funType.getInput(idx) != shapedType.getElementType())
       return op.emitOpError("expected fun argument ")
-             << idx << " of the same type as elemental type "
-             << view.getElementType() << " of view " << idx;
-
-    if (idx >= nInputViews) {
-      auto resultIdx = idx - nInputViews;
-      if (funType.getResult(resultIdx) != view.getElementType())
-        return op.emitOpError("expected fun result ")
-               << resultIdx << " of the same type as elemental type "
-               << view.getElementType() << " of view " << idx;
-    }
+             << (idx + 1) << " of the same type as elemental type "
+             << shapedType.getElementType() << " of input " << (idx + 1);
+  }
+
+  auto nOutputs = op.getNumOutputs();
+  // The element types of the linalg.generic outputs must exactly match the
+  // function result types.
+  for (unsigned idx = 0; idx < nOutputs; ++idx) {
+    ShapedType shapedType = op.getShapedType(nInputs + idx);
+    if (funType.getResult(idx) != shapedType.getElementType())
+      return op.emitOpError("expected fun result ")
+             << (idx + 1) << " of the same type as elemental type "
+             << shapedType.getElementType() << " of output " << (idx + 1);
   }
   return success();
 }
@@ -192,37 +199,35 @@
 template <>
 LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
   auto nLoops = op.getNumLoops();
-  auto nInputViews = op.getNumInputs();
   auto nOutputs = op.getNumOutputs();
-  auto nViews = op.getNumInputsAndOutputs();
-  if (funType.getNumInputs() != nViews + nLoops)
+  auto nOperands = op.getNumOperands();
+  if (funType.getNumInputs() != nOperands + nLoops)
     return op.emitOpError(
-        "expected fun arguments to match number of views + number of loops");
+        "expected fun arguments to match number of loops + number of operands");
   if (funType.getNumResults() != nOutputs)
-    return op.emitOpError(
-        "expected fun results to match number of output views");
-  for (unsigned i = 0; i < nLoops; ++i) {
+    return op.emitOpError("expected fun results to match number of outputs");
+  for (unsigned i = 0; i < nLoops; ++i)
     if (!funType.getInput(i).isIndex())
+      return op.emitOpError("expected fun argument ") << i << " to be an index";
+
+  auto nInputs = op.getNumInputs();
+  // The element types of the linalg.indexed_generic operands must exactly
+  // match the function argument types that follow the nLoops leading index
+  // arguments.
+  for (unsigned idx = 0; idx < nOperands; ++idx) {
+    ShapedType shapedType = op.getShapedType(idx);
+    if (funType.getInput(idx + nLoops) != shapedType.getElementType())
       return op.emitOpError("expected fun argument ")
-             << i << " to be of IndexType";
+             << (idx + nLoops + 1) << " of the same type as elemental type "
+             << shapedType.getElementType() << " of input " << (idx + 1);
   }
-  for (auto en : llvm::enumerate(op.indexing_maps())) {
-    auto idx = en.index();
-    auto funIdx = nLoops + idx;
-    auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
-                                    : op.getOutputShapedType(idx - nInputViews);
-    if (funType.getInput(funIdx) != view.getElementType())
-      return op.emitOpError("expected fun argument ")
-             << funIdx << " of the same type as elemental type "
-             << view.getElementType() << " of view " << idx;
-
-    if (idx >= nInputViews) {
-      auto resultIdx = idx - nInputViews;
-      if (funType.getResult(resultIdx) != view.getElementType())
-        return op.emitOpError("expected fun result ")
-               << resultIdx << " of the same type as elemental type "
-               << view.getElementType() << " of view " << idx;
-    }
+
+  // The element types of the linalg.indexed_generic outputs must exactly
+  // match the function result types.
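+  // For example, for an output of type memref<?x?xf32>, the fun must return
+  // f32 at the corresponding result position.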
+  for (unsigned idx = 0; idx < nOutputs; ++idx) {
+    ShapedType shapedType = op.getShapedType(nInputs + idx);
+    if (funType.getResult(idx) != shapedType.getElementType())
+      return op.emitOpError("expected fun result ")
+             << (idx + 1) << " of the same type as elemental type "
+             << shapedType.getElementType() << " of output " << (idx + 1);
   }
   return success();
 }
@@ -231,9 +236,10 @@
 LogicalResult verifyGenericOp(GenericOpType op) {
   auto nInputViews = op.getNumInputs();
   auto nLoops = op.getNumLoops();
-  auto nViews = op.getNumInputsAndOutputs();
-  if (nViews != llvm::size(op.views()))
-    return op.emitOpError("expected exactly ") << nViews << " view operands";
+  auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
+  if (nInputsAndOutputBuffers != llvm::size(op.views()))
+    return op.emitOpError("expected exactly ")
+           << nInputsAndOutputBuffers << " inputs and buffer operands";

   auto &region = op.region();
   auto funOp = op.getFunction();
@@ -287,22 +293,6 @@
     return op.emitOpError("expected the concatenation of maps in indexing_map "
                           "to be invertible");

-  auto outputTensorTypes = op.getOutputTensorTypes();
-  if (outputTensorTypes.size() != op.getNumResults())
-    return op.emitOpError("expected #output tensor operands (")
-           << outputTensorTypes.size() << ") to match #results ("
-           << op.getNumResults() << ")";
-
-  unsigned index = 0;
-  for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) {
-    auto resTy = std::get<0>(it);
-    auto outOpTy = std::get<1>(it);
-    if (resTy != outOpTy)
-      return op.emitOpError("result #")
-             << index << " must be " << outOpTy << ", but got " << resTy;
-    ++index;
-  }
-
   return success();
 }

@@ -691,17 +681,19 @@
 template <typename GenericOpType>
 LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
   // The operand number and types must match the view element types.
-  auto nOutputViews = genericOp.getNumOutputs();
-  if (op.getNumOperands() != nOutputViews)
-    return op.emitOpError("expected ")
-           << nOutputViews << " operand to match enclosing linalg.generic op";
+  auto nOutputs = genericOp.getNumOutputs();
+  if (op.getNumOperands() != nOutputs)
+    return op.emitOpError("expected number of yield values(")
+           << op.getNumOperands() << ") to match the number of outputs of the "
+           << "enclosing linalg.generic op(" << nOutputs << ")";

-  for (unsigned i = 0; i != nOutputViews; ++i) {
+  for (unsigned i = 0; i != nOutputs; ++i) {
     auto elementType = genericOp.getOutputShapedType(i).getElementType();
     if (op.getOperand(i)->getType() != elementType)
-      return op.emitOpError("type of return operand ")
-             << i << " (" << op.getOperand(i)->getType()
-             << ") doesn't match view element type (" << elementType << ")";
+      return op.emitOpError("type of yield operand ")
+             << i << " (" << op.getOperand(i)->getType() << ") doesn't match "
+             << "the element type of the enclosing linalg.generic op("
+             << elementType << ")";
   }
   return success();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -72,7 +72,7 @@
   clonedViews.reserve(op.getNumInputsAndOutputs());
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
-  SmallVector<Value, 8> ios(op.getInputsAndOutputs());
+  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
   for (auto en : llvm::enumerate(ios)) {
     unsigned idx = en.index();
     auto map = maps[idx];
@@ -121,7 +121,7 @@
   auto maps = loopToOperandRangesMaps(op);
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
-  SmallVector<Value, 8> ios(op.getInputsAndOutputs());
+  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
   for (auto en : llvm::enumerate(ios)) {
     unsigned idx = en.index();
     auto map = maps[idx];
@@ -269,7 +269,7 @@
   // Consumer consumes this view, `isStructurallyFusableProducer` also checks
   // whether it is a strict subview of the producer view.
   auto producedView = dependence.dependentOpView.view;
-  auto producerIdx = producer.getIndexOfOutput(producedView).getValue();
+  auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
   // `consumerIdx` and `producerIdx` exist by construction.
   LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
                     << " view: " << *producedView
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -96,7 +96,7 @@
         permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
     SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
     SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
-    IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0));
+    IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
     // Emit the proper scalar assignment, whether we are dealing with a 0-D or
     // an n-D loop nest; with or without permutations.
     // clang-format off
@@ -114,7 +114,7 @@
     assert(nPar == allIvs.size());
     auto ivs =
         SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
-    IndexedValueType O(fillOp.getOutput(0));
+    IndexedValueType O(fillOp.getOutputBuffer(0));
     // Emit the proper scalar assignment, whether we are dealing with a 0-D or
     // an n-D loop nest; with or without permutations.
     nPar > 0 ? O(ivs) = ValueHandle(fillOp.value())
@@ -129,7 +129,7 @@
     assert(allIvs.size() == 1);
     IndexHandle r_i(allIvs[0]);
     IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
-        C(dotOp.getOutput(0));
+        C(dotOp.getOutputBuffer(0));
     // Emit scalar form.
     C() = C() + A(r_i) * B(r_i);
   }
@@ -143,7 +143,7 @@
     assert(allIvs.size() == 2);
     IndexHandle i(allIvs[0]), r_j(allIvs[1]);
     IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
-        C(matvecOp.getOutput(0));
+        C(matvecOp.getOutputBuffer(0));
     // Emit scalar form.
     C(i) = C(i) + A(i, r_j) * B(r_j);
   }
@@ -157,7 +157,7 @@
     assert(allIvs.size() == 3);
     IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
     IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
-        C(matmulOp.getOutput(0));
+        C(matmulOp.getOutputBuffer(0));
     // Emit scalar form.
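+    // i and j are the parallel loop dimensions; r_k is the reduction
+    // dimension.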
     C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
   }
@@ -235,7 +235,8 @@
     for (unsigned i = 0; i < nOutputs; ++i) {
       ValueHandleArray indexing(makeCanonicalAffineApplies(
           b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
+      indexedValues[nInputs + i] =
+          std_load(genericOp.getOutputBuffer(i), indexing);
     }

     auto funcOp = genericOp.getFunction();
@@ -248,7 +249,7 @@
       for (unsigned i = 0; i < nOutputs; ++i) {
         ValueHandleArray indexing(makeCanonicalAffineApplies(
             b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
+        std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing);
       }
       return;
     }
@@ -271,8 +272,8 @@
     for (unsigned i = 0; i < nOutputs; ++i) {
       ValueHandleArray indexing(makeCanonicalAffineApplies(
           b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
-                indexing);
+      std_store(map.lookup(yieldOp->getOperand(i)),
+                genericOp.getOutputBuffer(i), indexing);
     }
   }
 };
@@ -337,7 +338,7 @@
       ValueHandleArray indexing(makeCanonicalAffineApplies(
           b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       indexedValues[nLoops + nInputs + i] =
-          std_load(indexedGenericOp.getOutput(i), indexing);
+          std_load(indexedGenericOp.getOutputBuffer(i), indexing);
     }

     if (auto funcOp = indexedGenericOp.getFunction()) {
@@ -349,7 +350,7 @@
       for (unsigned i = 0; i < nOutputs; ++i) {
         ValueHandleArray indexing(makeCanonicalAffineApplies(
             b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
+        std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i),
                   indexing);
       }
       return;
@@ -374,7 +375,7 @@
       ValueHandleArray indexing(makeCanonicalAffineApplies(
           b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       std_store(map.lookup(yieldOp->getOperand(i)),
-                indexedGenericOp.getOutput(i), indexing);
+                indexedGenericOp.getOutputBuffer(i), indexing);
     }
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -171,7 +171,7 @@
       return false;
     return true;
   };
-  if (!llvm::all_of(genericOp.getInputsAndOutputs(),
+  if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(),
                     isStaticMemRefWithIdentityLayout))
     return failure();
   return success();
@@ -195,7 +195,7 @@
   using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
   auto vA = std_load(vector_type_cast(genericOp.getInput(0)));
   auto vB = std_load(vector_type_cast(genericOp.getInput(1)));
-  auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0));
+  auto vectorMemRefC = vector_type_cast(genericOp.getOutputBuffer(0));
   auto vC = std_load(vectorMemRefC);
   auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
                               genericOp.iterator_types());
@@ -262,7 +262,7 @@
   // Transformation applies to buffers only.
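+  // That is, the op must have no tensor results and only buffer (memref)
+  // operands.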
   if (!linOp || !linOp.hasBufferSemantics())
     return failure();
-  if (llvm::none_of(linOp.getInputsAndOutputs(), [](Value v) {
+  if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) {
         return isa_and_nonnull<SubViewOp>(v.getDefiningOp());
       }))
     return failure();
@@ -280,7 +280,7 @@
   LinalgOp linOp = cast<LinalgOp>(op);
   SetVector<Value> subViews;
-  for (auto it : linOp.getInputsAndOutputs())
+  for (auto it : linOp.getInputsAndOutputBuffers())
     if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
       subViews.insert(sv);
   if (!subViews.empty()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -164,7 +164,7 @@
   SmallVector<std::pair<Value, Value>, 8> writebackViews;
   writebackViews.reserve(subViews.size());
   unsigned promotedIdx = 0;
-  for (auto view : op.getInputsAndOutputs()) {
+  for (auto view : op.getInputsAndOutputBuffers()) {
     if (subViews.count(view) != 0) {
       opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
       writebackViews.emplace_back(std::make_pair(
@@ -187,7 +187,7 @@
     // WARNING: MUST use the old op to determine whether the operand view is an
     // output.
     bool isOutput =
-        op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
+        op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue();
     if (isOutput)
       copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first);
   }
@@ -207,7 +207,7 @@
   // nothing.
   SetVector<Value> subViews;
   OpBuilder b(op);
-  for (auto it : op.getInputsAndOutputs())
+  for (auto it : op.getInputsAndOutputBuffers())
     if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
       subViews.insert(sv);
   if (!subViews.empty()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -254,7 +254,7 @@
   SmallVector<Value, 4> res;
   res.reserve(op->getNumOperands());
-  auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin();
+  auto viewIteratorBegin = linalgOp.getInputsAndOutputBuffers().begin();
   for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
        ++viewIndex) {
     Value view = *(viewIteratorBegin + viewIndex);
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -68,7 +68,7 @@
 // -----

 func @generic_exactly_2_views(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected exactly 2 view operands}}
+  // expected-error @+1 {{op expected exactly 2 inputs and buffer operands}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
@@ -96,7 +96,7 @@
 func @foo() { return }

 func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected fun arguments to match number of views}}
+  // expected-error @+1 {{op expected fun arguments to match number of operands}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
@@ -111,7 +111,7 @@
 func @foo(%0: i32) { return }

 func @generic_mismatched_num_returns(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected fun results to match number of output views}}
+  // expected-error @+1 {{op expected fun results(0) to match number of outputs(1)}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
@@ -123,6 +123,36 @@

 // -----

+func @foo(%0: i32, %1: i32, %2: i32) { return }
+
+func @generic_mismatched_fun_arg_type(%0: memref<i32>, %1: memref<f32>) {
+  // expected-error @+1 {{op expected fun argument 2 of the same type as elemental type 'f32' of input 2}}
+  linalg.generic {
+    args_in = 3,
+    args_out = 0,
+    fun = @foo,
+    indexing_maps = [ () -> (0) ],
+    iterator_types = []
+  } %0, %1, %1: memref<i32>, memref<f32>, memref<f32>
+}
+
+// -----
+
+func @foo(%0: i32, %1: i32, %2: f32) -> i32 { return %1: i32 }
+
+func @generic_mismatched_fun_result_type(%0: memref<i32>, %1: memref<f32>) {
+  // expected-error @+1 {{op expected fun result 1 of the same type as elemental type 'f32' of output 1}}
+  linalg.generic {
+    args_in = 2,
+    args_out = 1,
+    fun = @foo,
+    indexing_maps = [ () -> (0) ],
+    iterator_types = []
+  } %0, %0, %1: memref<i32>, memref<i32>, memref<f32>
+}
+
+// -----
+
 func @foo(%0: i32) -> i32 { return %0: i32 }

 func @generic_symbol_in_map(%arg0: memref<i32>) {
@@ -189,7 +219,7 @@
 }

 func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
-  // expected-error @+1 {{op expected fun argument 0 of the same type as elemental type 'f32' of view 0}}
+  // expected-error @+1 {{op expected fun argument 1 of the same type as elemental type 'f32' of input 1}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
@@ -207,7 +237,7 @@
 }

 func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
-  // expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'f32' of view 0}}
+  // expected-error @+1 {{op expected fun result 1 of the same type as elemental type 'f32' of output 1}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
@@ -257,7 +287,7 @@
 // -----

 func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected number of block arguments to match number of views}}
+  // expected-error @+1 {{op expected number of block arguments to match number of operands}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
@@ -271,7 +301,7 @@
 // -----

 func @generic_block_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: 'memref<f32>'}}
+  // expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output operand: 'memref<f32>'}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
@@ -285,7 +315,7 @@
 // -----

 func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected number of block arguments to match number of views + number of loops}}
+  // expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
@@ -299,7 +329,7 @@
 // -----

 func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected block argument 0 to be of IndexType}}
+  // expected-error @+1 {{op expected block argument 0 to be an index}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
@@ -313,7 +343,7 @@
 // -----

 func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output view: 'memref<f32>'}}
+  // expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output operand: 'memref<f32>'}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
@@ -330,7 +360,7 @@
   return %f : f32
 }
 func @indexed_generic_fun_arg_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected fun arguments to match number of views + number of loops}}
+  // expected-error @+1 {{op expected fun arguments to match number of loops + number of operands}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
@@ -346,7 +376,7 @@
   return %val : f32
 }
 func @indexed_generic_fun_induction_var_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected fun argument 0 to be of IndexType}}
+  // expected-error @+1 {{op expected fun argument 0 to be an index}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
@@ -362,7 +392,7 @@
   return %val : i1
 }
 func @indexed_generic_fun_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected fun argument 1 of the same type as elemental type 'f32' of view 0}}
+  // expected-error @+1 {{op expected fun argument 2 of the same type as elemental type 'f32' of input 1}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
@@ -378,7 +408,7 @@
   return %val, %val : i1, i1
 }
 func @indexed_generic_fun_result_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected fun results to match number of output views}}
+  // expected-error @+1 {{op expected fun results to match number of outputs}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
@@ -395,7 +425,7 @@
   return %val_float : f32
 }
 func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
-  // expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'i32' of view 0}}
+  // expected-error @+1 {{op expected fun result 1 of the same type as elemental type 'i32' of output 1}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
@@ -408,7 +438,7 @@
 // -----

 func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
-  // expected-error @+9 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
+  // expected-error @+9 {{type of yield operand 0 ('i1') doesn't match the element type of the enclosing linalg.generic op('f32')}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
@@ -438,36 +468,6 @@

 // -----

-func @generic_result_tensor_count(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
-  // expected-error @+1 {{op expected #output tensor operands (0) to match #results (1)}}
-  %0 = linalg.generic {
-    args_in = 0,
-    args_out = 1,
-    indexing_maps = [ (i) -> (i) ],
-    iterator_types = ["parallel"]
-  } %arg0 {
-    ^bb(%i: f32):
-      linalg.yield %i: f32
-  }: memref<?xf32, (i)[off]->(off + i)> -> tensor<?xf32>
-}
-
-// -----
-
-func @generic_result_tensor_type(%arg0: tensor<?xf32>) {
-  // expected-error @+1 {{op result #0 must be 'tensor<?xf32>', but got 'tensor<?x?xf32>'}}
-  %0 = linalg.generic {
-    args_in = 0,
-    args_out = 1,
-    indexing_maps = [ (i) -> (i) ],
-    iterator_types = ["parallel"]
-  } %arg0 {
-    ^bb(%i: f32):
-      linalg.yield %i: f32
-  }: tensor<?xf32> -> tensor<?x?xf32>
-}
-
-// -----
-
 func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
   // expected-error @+1 {{'linalg.dot' op expected 3 or more operands}}
   linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -157,23 +157,23 @@
 // CHECK-LABEL: func @generic_with_tensor_input
 // CHECK:       linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, #[[strided3D]]>
-func @generic_with_tensor_output(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
-                                 %arg1: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
-  %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
+#trait2 = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel", "parallel"],
+  fun = @foo,
+  library_call = "some_external_function_name_1"
 }
-// CHECK-LABEL: func @generic_with_tensor_output
-// CHECK:       linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
-// CHECK:       return {{.*}} : tensor<?x?x?xf32>
-
 func @generic_with_tensor_input_and_output(%arg0: tensor<?x?xvector<3x4xi4>>,
                                            %arg1: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
-  %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+  %0 = linalg.generic #trait2 %arg0, %arg1 {foo = 1} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
   return %0 : tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @generic_with_tensor_input_and_output
-// CHECK:       linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+// CHECK:       linalg.generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
 // CHECK:       return {{.*}} : tensor<?x?x?xf32>

-#trait2 = {
+#trait3 = {
   args_in = 1,
   args_out = 1,
   indexing_maps = #accesses,
@@ -181,7 +181,7 @@
   library_call = "some_external_function_name_2"
 }

 func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                      %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.generic #trait2 %arg0, %arg1 {
+  linalg.generic #trait3 %arg0, %arg1 {
     ^bb(%a: vector<3x4xi4>, %b: f32) :
       linalg.yield %b : f32
   } {foo = 1}: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
@@ -194,7 +194,7 @@
 //  CHECK:     } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>

 func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                       %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.indexed_generic #trait2 %arg0, %arg1 {
+  linalg.indexed_generic #trait3 %arg0, %arg1 {
     ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
       linalg.yield %b : f32
   } {foo = 1}: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>