diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -35,38 +35,9 @@ // interface. def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let methods = [ - InterfaceMethod< - "Query the number of inputs from the current operation.", - "unsigned", "getNumInputs" - >, - InterfaceMethod< - "Query the number of outputs from the current operation.", - "unsigned", "getNumOutputs" - >, - InterfaceMethod< - "Query the number of inputs and outputs from the current operation.", - "unsigned", "getNumInputsAndOutputs" - >, - InterfaceMethod< - "Query the input operands from the current operation.", - "Operation::operand_range", "getInputs" - >, - InterfaceMethod< - "Query the output operands from the current operation.", - "Operation::operand_range", "getOutputs" - >, - InterfaceMethod< - "Query the input and output operands from the current operation.", - "Operation::operand_range", "getInputsAndOutputs" - >, - InterfaceMethod< - "Query the iterator types attribute within the current operation.", - "ArrayAttr", "iterator_types" - >, - InterfaceMethod< - "Query the indexing maps attribute within the current operation.", - "ArrayAttr", "indexing_maps" - >, + //========================================================================// + // Loop types handling. + //========================================================================// InterfaceMethod< "Query the number of parallel loops within the current operation.", "unsigned", "getNumParallelLoops" @@ -82,40 +53,98 @@ InterfaceMethod< "Query the number of loops within the current operation.", "unsigned", "getNumLoops">, + + //========================================================================// + // Input arguments handling. + //========================================================================// + InterfaceMethod< + "Query the number of inputs from the current operation.", + "unsigned", "getNumInputs" + >, InterfaceMethod<"Query the input view at the given index.", "Value ", "getInput", (ins "unsigned":$i) >, - InterfaceMethod<"Query the output view at the given index.", - "Value ", "getOutput", (ins "unsigned":$i) - >, InterfaceMethod<[{ Return the index of the given input value `v`, or `None` if the value is not an input. }], "llvm::Optional", "getIndexOfInput", (ins "Value ":$v) >, - InterfaceMethod<[{ - Query the index of the given view value, or `None` if the value is not - a view. - }], - "llvm::Optional", "getIndexOfOutput", (ins "Value ":$view) + InterfaceMethod< + "Query the input operands from the current operation.", + "Operation::operand_range", "getInputs" >, InterfaceMethod<[{ Query the type of the input shape at the given index. }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ - Query the type of the output view at the given index. - }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>, - InterfaceMethod<[{ - Query whether the op has only MemRef input and outputs. - }], "bool", "hasBufferSemantics">, - InterfaceMethod<[{ Query the subset of input operands that are of ranked tensor type. }], "SmallVector", "getInputTensorTypes">, + + + //========================================================================// + // Output arguments handling. 
+ //========================================================================// + InterfaceMethod< + "Query the number of outputs from the current operation.", + "unsigned", "getNumOutputs" + >, + InterfaceMethod<"Query the output buffer at the given index.", + "Value ", "getOutputBuffer", (ins "unsigned":$i) + >, InterfaceMethod<[{ - Query the subset of output operands that are of ranked tensor type. + Query the index of the given buffer value, or `None` if the value is not + part of the output buffers. + }], + "llvm::Optional", "getIndexOfOutputBuffer", (ins "Value ":$view) + >, + InterfaceMethod<[{ + Query the type of the output buffer at the given index. + }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Query the results that are of ranked tensor type. }], "SmallVector", "getOutputTensorTypes">, + InterfaceMethod< + "Query the output buffers (operands) from the current operation.", + "Operation::operand_range", "getOutputBuffers" + >, + //========================================================================// + // Input and Output arguments handling. + //========================================================================// + InterfaceMethod< + "Return the number of inputs and outputs, irrespective of their buffer " + "or tensor type.", + "unsigned", "getNumInputsAndOutputs" + >, + InterfaceMethod< + "Return the number of inputs, irrespective of their buffer or tensor " + "type, and output buffers", + "unsigned", "getNumInputsAndOutputBuffers" + >, + InterfaceMethod< + "Return the range over inputs (irrespective of type) and output buffers.", + "Operation::operand_range", "getInputsAndOutputBuffers" + >, + + //========================================================================// + // Other interface methods. + //========================================================================// + InterfaceMethod< + "Query the iterator types attribute within the current operation.", + "ArrayAttr", "iterator_types" + >, + InterfaceMethod< + "Query the indexing maps attribute within the current operation.", + "ArrayAttr", "indexing_maps" + >, + InterfaceMethod<[{ + Query whether the op has only MemRef input and outputs. + }], "bool", "hasBufferSemantics">, + + //========================================================================// + // Other static interface methods. + //========================================================================// StaticInterfaceMethod<[{ Create an operation of the current type with the given location, operands, and attributes. @@ -128,9 +157,6 @@ attributes); }] >, - - /// Clone an operation with the given location and operands. This is used to - /// abstract away the optional underlying region creation. InterfaceMethod<[{ Clone the current operation with the given location and operands. This is used to abstract away the optional underlying region creation. @@ -536,22 +562,26 @@ mixing input and output ranked tensor values with input and output memrefs. ```mlir - %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} : + %C = linalg.generic #trait_attribute %A, %B {other-attributes} : tensor, - memref, - tensor + memref -> (tensor) ``` - In this case, the number of return values must match the number of output - tensor arguments. The semantics is that the `linalg.generic` op - produces (i.e. allocates and fills) its return values. + In this case, the number of outputs (args_out) must match the sum of (1) the + number of output buffer operands and (2) the number of tensor return values. 
+ The semantics is that the `linalg.generic` op produces (i.e. + allocates and fills) its tensor return values. + + Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. In particular, transformations that create - control flow around linalg.generic operations are not expected to mix with - tensors because SSA values do not escape naturally. Still, transformations - and rewrites that take advantage of tensor SSA values are expected to be - useful and will be added in the near future. + transformations can be applied. Such legalization moves tensor return values + into output buffer operands and updates the region arguments accordingly. + + Transformations that create control-flow around linalg.generic + operations are not expected to work with tensors because SSA values do not + escape naturally. Still, transformations and rewrites that take advantage of + tensor SSA values are expected to be useful and will be added in the near + future. }]; let verifier = [{ return ::verify(*this); }]; } @@ -659,22 +689,26 @@ memrefs. ```mlir - %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes} + %C = linalg.indexed_generic #trait_attribute %A, %B {other-attributes} : tensor, - memref, - tensor + memref -> (tensor) ``` - In this case, the number of return values must match the number of output - tensor arguments. The semantics is that the `linalg.indexed_generic` op - produces (i.e. allocates and fills) its return values. + In this case, the number of outputs (args_out) must match the sum of (1) the + number of output buffer operands and (2) the number of tensor return values. + The semantics is that the `linalg.indexed_generic` op produces (i.e. + allocates and fills) its return values. + Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. In particular, transformations that create - control flow around linalg.generic operations are not expected to mix with - tensors because SSA values do not escape naturally. Still, transformations - and rewrites that take advantage of tensor SSA values are expected to be - useful and will be added in the near future. + transformations can be applied. Such legalization moves tensor return values + into output buffer operands and updates the region arguments accordingly. + + Transformations that create control-flow around linalg.indexed_generic + operations are not expected to work with tensors because SSA values do not + escape naturally. Still, transformations and rewrites that take advantage of + tensor SSA values are expected to be useful and will be added in the near + future. }]; let verifier = [{ return ::verify(*this); }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -58,16 +58,47 @@ class StructuredOpTraits : public OpTrait::TraitBase { private: - /// Return the number of inputs. For internal use only. + /// Return the number of inputs, irrespective of their buffer or tensor type. + /// For internal use only. unsigned nInputs() { return cast(this->getOperation()).getNumInputs(); } - /// Return the number of outputs. For internal use only. + /// Return the number of outputs, irrespective of their buffer or tensor type. + /// For internal use only.
unsigned nOutputs() { return cast(this->getOperation()).getNumOutputs(); } public: + //==========================================================================// + // Loop types handling. + //==========================================================================// + unsigned getNumParallelLoops() { + return getNumIterators( + getParallelIteratorTypeName(), + cast(this->getOperation()).iterator_types()); + } + unsigned getNumReductionLoops() { + return getNumIterators( + getReductionIteratorTypeName(), + cast(this->getOperation()).iterator_types()); + } + unsigned getNumWindowLoops() { + return getNumIterators( + getWindowIteratorTypeName(), + cast(this->getOperation()).iterator_types()); + } + unsigned getNumLoops() { + return getNumIterators( + cast(this->getOperation()).iterator_types()); + } + + //==========================================================================// + // Input arguments handling. + //==========================================================================// + // The `i^th` input argument is always the `i^th` operand regardless of + // whether we have tensors or buffers. + // /// Return the `i`-th input value. Value getInput(unsigned i) { assert(i < nInputs()); @@ -90,28 +121,6 @@ auto range = this->getOperation()->getOperands(); return {range.begin(), range.begin() + nInputs()}; } - /// Return the `i`-th output. - Value getOutput(unsigned i) { - return this->getOperation()->getOperand(nInputs() + i); - } - /// Return the index of `value` in the list of output values if found, - /// llvm::None otherwise. - Optional getIndexOfOutput(Value value) { - auto it = llvm::find(getOutputs(), value); - if (it != getOutputs().end()) - return it - getOutputs().begin(); - return llvm::None; - } - /// Return the `i`-th output buffer type. - ShapedType getOutputShapedType(unsigned i) { - return getOutput(i).getType().template cast(); - } - /// Query whether the op has only MemRef input and outputs. - bool hasBufferSemantics() { - return this->getOperation()->getNumResults() == 0 && - llvm::all_of(getInputsAndOutputs(), - [](Value v) { return v.getType().isa(); }); - } /// Query the subset of input operands that are of ranked tensor type. SmallVector getInputTensorTypes() { SmallVector res; @@ -120,53 +129,97 @@ res.push_back(t); return res; } - /// Query the subset of output operands that are of ranked tensor type. + + //==========================================================================// + // Output arguments handling. + //==========================================================================// + // The `i^th` output argument is an operand (resp. a return value) iff it is + // a value of buffer type (resp. a return value of tensor type). + + /// Return the `i`-th output, asserts that this is a buffer operand and not + /// a tensor result. + Value getOutputBuffer(unsigned i) { + assert(i + this->getOperation()->getNumResults() < nOutputs() && + "overflowing output buffer index"); + return this->getOperation()->getOperand(nInputs() + i); + } + /// Return the index of `value` in the list of output buffers if found, + /// llvm::None otherwise. + Optional getIndexOfOutputBuffer(Value value) { + auto it = llvm::find(getOutputBuffers(), value); + if (it != getOutputBuffers().end()) + return it - getOutputBuffers().begin(); + return llvm::None; + } + /// Return the `i`-th output buffer type. 
+ MemRefType getOutputBufferType(unsigned i) { + return getOutputBuffer(i).getType().template cast(); + } + /// Return the `i`-th output shaped type, irrespective of buffer of tensor + /// type. + ShapedType getOutputShapedType(unsigned i) { + return getShapedType(i + nInputs()); + } + /// Query the subset of results that are of ranked tensor type. SmallVector getOutputTensorTypes() { SmallVector res; - for (Type type : getOutputs().getTypes()) - if (auto t = type.template dyn_cast()) - res.push_back(t); + for (Type type : this->getOperation()->getResults().getTypes()) + res.push_back(type.template cast()); return res; } /// Return the range over outputs. - Operation::operand_range getOutputs() { + Operation::operand_range getOutputBuffers() { auto range = this->getOperation()->getOperands(); return {range.begin() + nInputs(), - range.begin() + getNumInputsAndOutputs()}; + range.begin() + getNumInputsAndOutputBuffers()}; } - /// Return the number of inputs and outputs. + + //==========================================================================// + // Input and Output arguments handling. + //==========================================================================// + /// Return the number of inputs and outputs, irrespective of their buffer or + /// tensor type. unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); } - /// Return the `i`-th buffer type. - ShapedType getShapedType(unsigned i) { - return (i < nInputs()) ? getInputShapedType(i) - : getOutputShapedType(i - nInputs()); - } - /// Return the range over inputs and outputs. - Operation::operand_range getInputsAndOutputs() { + /// Return the number of inputs, irrespective of their buffer or tensor type, + /// and output buffers. + unsigned getNumInputsAndOutputBuffers() { + assert(this->getOperation()->getNumResults() <= nOutputs()); + return nInputs() + nOutputs() - this->getOperation()->getNumResults(); + } + /// Return the range over inputs (irrespective of type) and output buffers. + Operation::operand_range getInputsAndOutputBuffers() { auto range = this->getOperation()->getOperands(); - return {range.begin(), range.begin() + getNumInputsAndOutputs()}; + return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()}; } - unsigned getNumParallelLoops() { - return getNumIterators( - getParallelIteratorTypeName(), - cast(this->getOperation()).iterator_types()); - } - unsigned getNumReductionLoops() { - return getNumIterators( - getReductionIteratorTypeName(), - cast(this->getOperation()).iterator_types()); - } - unsigned getNumWindowLoops() { - return getNumIterators( - getWindowIteratorTypeName(), - cast(this->getOperation()).iterator_types()); + /// Return the `i`-th shaped type, there are 3 cases: + /// 1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise + /// 2. if `i < getNumInputsAndOutputBuffers()` then return the + /// `getOutputBufferType(i - nInputs())`; otherwise + /// 3. return the `i - getNumInputsAndOutputBuffers()` result type. + ShapedType getShapedType(unsigned i) { + if (i < nInputs()) + return getInputShapedType(i); + if (i < getNumInputsAndOutputBuffers()) + return getOutputBufferType(i - nInputs()).template cast(); + return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()] + .template cast(); } - unsigned getNumLoops() { - return getNumIterators( - cast(this->getOperation()).iterator_types()); + + //==========================================================================// + // Other interface methods. 
+ //==========================================================================// + /// Query whether the op has only buffer inputs and no returns. + bool hasBufferSemantics() { + return this->getOperation()->getNumResults() == 0 && + llvm::all_of(getInputs(), + [](Value v) { return v.getType().isa(); }); } + + //==========================================================================// + // Other static interface methods. + //==========================================================================// static LogicalResult verifyTrait(Operation *op) { - auto nOperands = cast(op).getNumInputsAndOutputs(); + auto nOperands = cast(op).getNumInputsAndOutputBuffers(); if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands))) return failure(); return success(); diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -119,7 +119,7 @@ template SmallVector getViewSizes(ConcreteOp linalgOp) { SmallVector res; - for (auto v : linalgOp.getInputsAndOutputs()) { + for (auto v : linalgOp.getInputsAndOutputBuffers()) { MemRefType t = v.getType().template cast(); for (unsigned i = 0; i < t.getRank(); ++i) res.push_back(edsc::intrinsics::dim(v, i)); diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -139,7 +139,11 @@ } void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { - for (auto srcView : src.getOutputs()) { // W + assert(src.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(dst.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + for (auto srcView : src.getOutputBuffers()) { // W // RAW graph for (auto dstView : dst.getInputs()) { // R if (aliases.alias(srcView, dstView)) { // if alias, fill RAW @@ -149,7 +153,7 @@ } } // WAW graph - for (auto dstView : dst.getOutputs()) { // W + for (auto dstView : dst.getOutputBuffers()) { // W if (aliases.alias(srcView, dstView)) { // if alias, fill WAW addDependenceElem(DependenceType::WAW, LinalgOpView{src.getOperation(), srcView}, @@ -167,7 +171,7 @@ } } // WAR graph - for (auto dstView : dst.getOutputs()) { // W + for (auto dstView : dst.getOutputBuffers()) { // W if (aliases.alias(srcView, dstView)) { // if alias, fill WAR addDependenceElem(DependenceType::WAR, LinalgOpView{src.getOperation(), srcView}, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -112,19 +112,20 @@ static LogicalResult verifyBlockArgs(GenericOpType op, Block &block); template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { - auto nViews = op.getNumInputsAndOutputs(); - auto nInputViews = op.getNumInputs(); - if (block.getNumArguments() != nViews) - return op.emitOpError( - "expected number of block arguments to match number of views"); + auto nOperands = op.getNumOperands(); + if (block.getNumArguments() != nOperands) + return op.emitOpError("expected number of block arguments to match number " + "of operands"); - for (unsigned i = 0; i < nViews; ++i) { + // Note: the number and type of yield values are checked in the YieldOp. 
+ auto nInputViews = op.getNumInputs(); + for (unsigned i = 0; i < nOperands; ++i) { auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(i).getType()) return op.emitOpError("expected block argument ") - << i << " of the same type as elemental type of " + << (i + 1) << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") - << "view: " << viewType; + << "operand: " << viewType; } return success(); } @@ -132,27 +133,28 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nViews = op.getNumInputsAndOutputs(); - if (block.getNumArguments() != nViews + nLoops) + auto nOperands = op.getNumOperands(); + if (block.getNumArguments() != nOperands + nLoops) return op.emitOpError( - "expected number of block arguments to match number of views + " + "expected number of block arguments to match number of operands + " "number of loops"); - for (unsigned i = 0; i < nLoops; ++i) { + // Note: the number and type of yield values are checked in the YieldOp. + for (unsigned i = 0; i < nLoops; ++i) if (!block.getArgument(i).getType().isIndex()) return op.emitOpError("expected block argument ") - << i << " to be of IndexType"; - } + << (i + 1) << " to be an index"; - for (unsigned i = 0; i < nViews; ++i) { + for (unsigned i = 0; i < nOperands; ++i) { unsigned memrefArgIndex = i + nLoops; auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(memrefArgIndex).getType()) return op.emitOpError("expected block argument ") - << memrefArgIndex << " of the same type as elemental type of " + << (memrefArgIndex + 1) + << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") - << "view: " << viewType; + << "operand: " << viewType; } return success(); } @@ -160,70 +162,74 @@ template static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType); +template +LogicalResult verifyFuncArgsGeneric(GenericOpType op, FunctionType funType) { + auto res = verifyFuncArgs(op, funType); + if (failed(res)) + return res; + + auto nInputs = op.getNumInputs(); + auto nOutputs = op.getNumOutputs(); + // linalg.generic output element types are exactly the function results. + for (unsigned idx = 0; idx < nOutputs; ++idx) { + ShapedType shapedType = op.getShapedType(nInputs + idx); + if (funType.getResult(idx) != shapedType.getElementType()) + return op.emitOpError("expected function result ") + << (idx + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of output " << (idx + 1); + } + return success(); +} + template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) { - auto nViews = op.getNumInputsAndOutputs(); - auto nInputViews = op.getNumInputs(); - if (funType.getNumInputs() != nViews) - return op.emitOpError("expected fun arguments to match number of views"); - if (funType.getNumResults() != op.getNumOutputs()) + auto nOperands = op.getNumOperands(); + if (funType.getNumInputs() != nOperands) return op.emitOpError( - "expected fun results to match number of output views"); - - for (auto en : llvm::enumerate(op.indexing_maps())) { - auto idx = en.index(); - auto view = (idx < nInputViews) ? 
op.getInputShapedType(idx) - : op.getOutputShapedType(idx - nInputViews); - if (funType.getInput(idx) != view.getElementType()) - return op.emitOpError("expected fun argument ") - << idx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - - if (idx >= nInputViews) { - auto resultIdx = idx - nInputViews; - if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitOpError("expected fun result ") - << resultIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - } + "expected function arguments to match number of operands"); + if (funType.getNumResults() != op.getNumOutputs()) + return op.emitOpError("expected function results(") + << funType.getNumResults() << ") to match number of outputs(" + << op.getNumOutputs() << ")"; + + // linalg.generic operands element types are exactly the first function + // arguments. + for (unsigned idx = 0; idx < nOperands; ++idx) { + ShapedType shapedType = op.getShapedType(idx); + if (funType.getInput(idx) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (idx + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of operand " << (idx + 1); } + return success(); } template <> LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) { auto nLoops = op.getNumLoops(); - auto nInputViews = op.getNumInputs(); auto nOutputs = op.getNumOutputs(); - auto nViews = op.getNumInputsAndOutputs(); - if (funType.getNumInputs() != nViews + nLoops) - return op.emitOpError( - "expected fun arguments to match number of views + number of loops"); + auto nOperands = op.getNumOperands(); + if (funType.getNumInputs() != nOperands + nLoops) + return op.emitOpError("expected function arguments to match number of " + "loops + number of operands"); if (funType.getNumResults() != nOutputs) return op.emitOpError( - "expected fun results to match number of output views"); - for (unsigned i = 0; i < nLoops; ++i) { + "expected function results to match number of outputs"); + for (unsigned i = 0; i < nLoops; ++i) if (!funType.getInput(i).isIndex()) - return op.emitOpError("expected fun argument ") - << i << " to be of IndexType"; - } - for (auto en : llvm::enumerate(op.indexing_maps())) { - auto idx = en.index(); - auto funIdx = nLoops + idx; - auto view = (idx < nInputViews) ? op.getInputShapedType(idx) - : op.getOutputShapedType(idx - nInputViews); - if (funType.getInput(funIdx) != view.getElementType()) - return op.emitOpError("expected fun argument ") - << funIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - - if (idx >= nInputViews) { - auto resultIdx = idx - nInputViews; - if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitOpError("expected fun result ") - << resultIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - } + return op.emitOpError("expected function argument ") + << (i + 1) << " to be an index"; + + // linalg.generic operands element types are exactly the first function + // arguments. 
+ for (unsigned idx = 0; idx < nOperands; ++idx) { + ShapedType shapedType = op.getShapedType(idx); + if (funType.getInput(idx + nLoops) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (idx + nLoops + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of input " << (idx + 1); } + return success(); } @@ -231,9 +237,11 @@ static LogicalResult verifyGenericOp(GenericOpType op) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nViews = op.getNumInputsAndOutputs(); - if (nViews != llvm::size(op.views())) - return op.emitOpError("expected exactly ") << nViews << " view operands"; + auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); + if (nInputsAndOutputBuffers != llvm::size(op.views())) + return op.emitOpError("expected exactly ") + << nInputsAndOutputBuffers + << " inputs (tensor or buffer) and output buffer operands"; auto ®ion = op.region(); auto funOp = op.getFunction(); @@ -246,8 +254,8 @@ } else { if (!funOp || !funOp.getType()) return op.emitOpError( - "expected fun attribute to refer to a defined symbol"); - if (failed(verifyFuncArgs(op, funType))) + "expected function attribute to refer to a defined symbol"); + if (failed(verifyFuncArgsGeneric(op, funType))) return failure(); } @@ -287,22 +295,6 @@ return op.emitOpError("expected the concatenation of maps in indexing_map " "to be invertible"); - auto outputTensorTypes = op.getOutputTensorTypes(); - if (outputTensorTypes.size() != op.getNumResults()) - return op.emitOpError("expected #output tensor operands (") - << outputTensorTypes.size() << ") to match #results (" - << op.getNumResults() << ")"; - - unsigned index = 0; - for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) { - auto resTy = std::get<0>(it); - auto outOpTy = std::get<1>(it); - if (resTy != outOpTy) - return op.emitOpError("result #") - << index << " must be " << outOpTy << ", but got " << resTy; - ++index; - } - return success(); } @@ -731,17 +723,20 @@ template static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) { // The operand number and types must match the view element types. - auto nOutputViews = genericOp.getNumOutputs(); - if (op.getNumOperands() != nOutputViews) - return op.emitOpError("expected ") - << nOutputViews << " operand to match enclosing linalg.generic op"; + auto nOutputs = genericOp.getNumOutputs(); + if (op.getNumOperands() != nOutputs) + return op.emitOpError("expected number of yield values (") + << nOutputs << ") to match the number of operands of the enclosing " + << "linalg.generic op (" << op.getNumOperands() << ")"; - for (unsigned i = 0; i != nOutputViews; ++i) { + for (unsigned i = 0; i != nOutputs; ++i) { auto elementType = genericOp.getOutputShapedType(i).getElementType(); if (op.getOperand(i).getType() != elementType) - return op.emitOpError("type of return operand ") - << i << " (" << op.getOperand(i).getType() - << ") doesn't match view element type (" << elementType << ")"; + return op.emitOpError("type of yield operand ") + << (i + 1) << " (" << op.getOperand(i).getType() + << ") doesn't match " + << "the element type of the enclosing linalg.generic op (" + << elementType << ")"; } return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -67,12 +67,13 @@ // to the `loopRanges` in order to obtain view ranges. 
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, ArrayRef loopRanges) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); auto maps = loopToOperandRangesMaps(op); SmallVector clonedViews; clonedViews.reserve(op.getNumInputsAndOutputs()); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - SmallVector ios(op.getInputsAndOutputs()); + SmallVector ios(op.getInputsAndOutputBuffers()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; @@ -118,10 +119,11 @@ // they must agree by construction (i.e. have the same size) and we just return // the first one. static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); auto maps = loopToOperandRangesMaps(op); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - SmallVector ios(op.getInputsAndOutputs()); + SmallVector ios(op.getInputsAndOutputBuffers()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; @@ -144,6 +146,10 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, unsigned consumerIdx, unsigned producerIdx, OperationFolder *folder) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto subView = dyn_cast_or_null( consumer.getInput(consumerIdx).getDefiningOp()); auto slice = @@ -197,6 +203,10 @@ // Some of these will be lifted in the future with better analysis. static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, LinalgOp consumer) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); if (producer.getNumOutputs() != 1) { LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); return false; @@ -217,6 +227,10 @@ LinalgOp consumer, Value consumedView, LinalgOp producer) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); // Make some simple structural checks that alleviate the need for more // complex analyses. 
if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { @@ -236,6 +250,10 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) return false; // Check for any fusion-preventing dependence to any view read/written that @@ -252,6 +270,8 @@ Optional mlir::linalg::fuseProducerOf( OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, const LinalgDependenceGraph &graph, OperationFolder *folder) { + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); LLVM_DEBUG(dbgs() << "\nStart examining consumer: " << *consumer.getOperation()); for (auto dependence : graph.getDependencesInto( @@ -268,7 +288,7 @@ // Consumer consumes this view, `isStructurallyFusableProducer` also checks // whether it is a strict subview of the producer view. auto producedView = dependence.dependentOpView.view; - auto producerIdx = producer.getIndexOfOutput(producedView).getValue(); + auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue(); // `consumerIdx` and `producerIdx` exist by construction. LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation() << " view: " << producedView @@ -309,7 +329,10 @@ // Save original Linalg ops, we only want to make a pass over those. SmallVector linalgOps; - f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); + f.walk([&](LinalgOp op) { + if (op.hasBufferSemantics()) + linalgOps.push_back(op); + }); Aliases aliases; LinalgDependenceGraph G(aliases, linalgOps); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -90,6 +90,8 @@ class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp) { + assert(copyOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto nPar = copyOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto inputIvs = @@ -98,7 +100,7 @@ permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); SmallVector iivs(inputIvs.begin(), inputIvs.end()); SmallVector oivs(outputIvs.begin(), outputIvs.end()); - IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0)); + IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. // clang-format off @@ -112,11 +114,13 @@ class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp) { + assert(fillOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto ivs = SmallVector(allIvs.begin(), allIvs.begin() + nPar); - IndexedValueType O(fillOp.getOutput(0)); + IndexedValueType O(fillOp.getOutputBuffer(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. nPar > 0 ? 
O(ivs) = ValueHandle(fillOp.value()) @@ -128,10 +132,12 @@ class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { + assert(dotOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 1); IndexHandle r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), - C(dotOp.getOutput(0)); + C(dotOp.getOutputBuffer(0)); // Emit scalar form. C() = C() + A(r_i) * B(r_i); } @@ -142,10 +148,12 @@ public: static void emitScalarImplementation(ArrayRef allIvs, MatvecOp matvecOp) { + assert(matvecOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 2); IndexHandle i(allIvs[0]), r_j(allIvs[1]); IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), - C(matvecOp.getOutput(0)); + C(matvecOp.getOutputBuffer(0)); // Emit scalar form. C(i) = C(i) + A(i, r_j) * B(r_j); } @@ -156,10 +164,12 @@ public: static void emitScalarImplementation(ArrayRef allIvs, MatmulOp matmulOp) { + assert(matmulOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 3); IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), - C(matmulOp.getOutput(0)); + C(matmulOp.getOutputBuffer(0)); // Emit scalar form. C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j); } @@ -169,6 +179,8 @@ class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp) { + assert(convOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); auto maps = loopToOperandRangesMaps(convOp); @@ -219,6 +231,8 @@ public: static void emitScalarImplementation(ArrayRef allIvs, GenericOp genericOp) { + assert(genericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -237,7 +251,8 @@ for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing); + indexedValues[nInputs + i] = + std_load(genericOp.getOutputBuffer(i), indexing); } auto funcOp = genericOp.getFunction(); @@ -250,7 +265,7 @@ for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), genericOp.getOutput(i), indexing); + std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing); } return; } @@ -273,8 +288,8 @@ for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i), - indexing); + std_store(map.lookup(yieldOp->getOperand(i)), + genericOp.getOutputBuffer(i), indexing); } } }; @@ -314,6 +329,8 @@ public: static void emitScalarImplementation(ArrayRef allIvs, IndexedGenericOp indexedGenericOp) { + assert(indexedGenericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -339,7 +356,7 @@ ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, 
indexedGenericOp.getOutputIndexingMap(i), allIvs)); indexedValues[nLoops + nInputs + i] = - std_load(indexedGenericOp.getOutput(i), indexing); + std_load(indexedGenericOp.getOutputBuffer(i), indexing); } if (auto funcOp = indexedGenericOp.getFunction()) { @@ -351,7 +368,7 @@ for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), indexedGenericOp.getOutput(i), + std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i), indexing); } return; @@ -376,7 +393,7 @@ ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); std_store(map.lookup(yieldOp->getOperand(i)), - indexedGenericOp.getOutput(i), indexing); + indexedGenericOp.getOutputBuffer(i), indexing); } } }; @@ -404,6 +421,8 @@ // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). auto linalgOp = cast(op); + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto invertedMap = inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); if (!invertedMap) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -93,6 +93,8 @@ Operation *consumerOp, Value consumedView, function_ref isaOpType) { LinalgOp consumer = dyn_cast(consumerOp); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); if (!consumer) return false; @@ -171,7 +173,7 @@ return false; return true; }; - if (!llvm::all_of(genericOp.getInputsAndOutputs(), + if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(), isStaticMemRefWithIdentityLayout)) return failure(); return success(); @@ -188,6 +190,8 @@ "DRR failure case must be a precondition"); auto genericOp = cast(op); + assert(genericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); edsc::ScopedContext scope(rewriter, op->getLoc()); using edsc::intrinsics::std_load; using edsc::intrinsics::std_store; @@ -195,7 +199,7 @@ using vector_type_cast = edsc::intrinsics::ValueBuilder; auto vA = std_load(vector_type_cast(genericOp.getInput(0))); auto vB = std_load(vector_type_cast(genericOp.getInput(1))); - auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0)); + auto vectorMemRefC = vector_type_cast(genericOp.getOutputBuffer(0)); auto vC = std_load(vectorMemRefC); auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(), genericOp.iterator_types()); @@ -262,7 +266,7 @@ // Transformation applies to buffers only. 
if (!linOp || !linOp.hasBufferSemantics()) return failure(); - if (llvm::none_of(linOp.getInputsAndOutputs(), [](Value v) { + if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) { return isa_and_nonnull(v.getDefiningOp()); })) return failure(); @@ -279,8 +283,10 @@ "DRR failure case must be a precondition"); LinalgOp linOp = cast(op); + assert(linOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); SetVector subViews; - for (auto it : linOp.getInputsAndOutputs()) + for (auto it : linOp.getInputsAndOutputBuffers()) if (auto sv = dyn_cast_or_null(it.getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -155,6 +155,8 @@ SetVector subViews, bool dynamicBuffers, OperationFolder *folder) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); + // 1. Promote the specified views and use them in the new op. ScopedContext scope(b, op.getLoc()); auto promotedBufferAndViews = promoteSubViews( @@ -164,7 +166,7 @@ SmallVector, 8> writebackViews; writebackViews.reserve(subViews.size()); unsigned promotedIdx = 0; - for (auto view : op.getInputsAndOutputs()) { + for (auto view : op.getInputsAndOutputBuffers()) { if (subViews.count(view) != 0) { opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView); writebackViews.emplace_back(std::make_pair( @@ -187,7 +189,7 @@ // WARNING: MUST use the old op to determine whether the operand view is an // output. bool isOutput = - op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue(); + op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue(); if (isOutput) copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first); } @@ -203,11 +205,14 @@ SmallVector toErase; OperationFolder folder(f.getContext()); f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) { + if (!op.hasBufferSemantics()) + return; + // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or // nothing. 
SetVector subViews; OpBuilder b(op); - for (auto it : op.getInputsAndOutputs()) + for (auto it : op.getInputsAndOutputBuffers()) if (auto sv = dyn_cast_or_null(it.getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -173,6 +173,7 @@ static void transformIndexedGenericOpIndices( OpBuilder &b, LinalgOp op, ArrayRef pivs, const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); auto indexedGenericOp = dyn_cast(op.getOperation()); if (!indexedGenericOp) return; @@ -232,6 +233,8 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, ArrayRef ivs, ArrayRef tileSizes, ArrayRef viewSizes, OperationFolder *folder) { + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(ivs.size() == static_cast(llvm::count_if( llvm::make_range(tileSizes.begin(), tileSizes.end()), [](Value v) { return !isZero(v); })) && @@ -254,7 +257,7 @@ SmallVector res; res.reserve(op->getNumOperands()); - auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin(); + auto viewIteratorBegin = linalgOp.getInputsAndOutputBuffers().begin(); for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs(); ++viewIndex) { Value view = *(viewIteratorBegin + viewIndex); @@ -309,6 +312,7 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, ArrayRef permutation, OperationFolder *folder) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); // 1. Enforce the convention that "tiling by zero" skips tiling a particular // dimension. This convention is significantly simpler to handle instead of // adjusting affine maps to account for missing dimensions. @@ -383,6 +387,7 @@ Optional mlir::linalg::tileLinalgOp( OpBuilder &b, LinalgOp op, ArrayRef tileSizes, ArrayRef permutation, OperationFolder *folder) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); if (tileSizes.empty()) return llvm::None; @@ -419,6 +424,8 @@ OpBuilder b(f); OperationFolder folder(f.getContext()); f.walk([tileSizes, &b, &folder](LinalgOp op) { + if (!op.hasBufferSemantics()) + return; auto opLoopsPair = tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder); // If tiling occurred successfully, erase old op. 
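The transform changes above (Fusion, LinalgToLoops, LinalgTransforms, Promotion, Tiling) all follow the same pattern: skip or assert on ops that still carry tensor results, and walk `getInputsAndOutputBuffers()` instead of the old `getInputsAndOutputs()`. The sketch below is not part of the patch; it only illustrates that calling convention. The helper name `walkBufferLinalgOps` and the surrounding boilerplate are assumptions, while the interface methods themselves (`hasBufferSemantics`, `getInputsAndOutputBuffers`, `getOutputBuffer`) are the ones introduced in this change.

```c++
// Illustrative sketch only -- not part of the patch. Shows the calling
// convention the revised LinalgOp interface expects from buffer-only
// transformations. The helper name `walkBufferLinalgOps` is hypothetical.
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::linalg;

static void walkBufferLinalgOps(FuncOp f) {
  f.walk([](LinalgOp op) {
    // Ops that still carry tensor results must first be legalized by a
    // buffer allocation pass; transformations simply skip them.
    if (!op.hasBufferSemantics())
      return;
    // Inputs and output buffers now form one contiguous operand range, so
    // loops over the old getInputsAndOutputs() become:
    for (Value v : op.getInputsAndOutputBuffers()) {
      auto memrefType = v.getType().cast<MemRefType>();
      (void)memrefType; // all such operands are memrefs under buffer semantics
    }
    // Output buffers are addressed through the new accessor; with buffer
    // semantics there are no tensor results, so every output is a buffer.
    for (unsigned i = 0, e = op.getNumOutputs(); i != e; ++i)
      (void)op.getOutputBuffer(i);
  });
}
```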
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -68,7 +68,7 @@ // ----- func @generic_exactly_2_views(%arg0: memref) { - // expected-error @+1 {{op expected exactly 2 view operands}} + // expected-error @+1 {{op expected exactly 2 inputs (tensor or buffer) and output buffer operands}} linalg.generic { args_in = 1, args_out = 1, @@ -81,7 +81,7 @@ // ----- func @generic_undefined_fun(%arg0: memref) { - // expected-error @+1 {{op expected fun attribute to refer to a defined symbol}} + // expected-error @+1 {{op expected function attribute to refer to a defined symbol}} linalg.generic { args_in = 1, args_out = 1, @@ -96,7 +96,7 @@ func @foo() { return } func @generic_mismatched_num_arguments(%arg0: memref) { - // expected-error @+1 {{op expected fun arguments to match number of views}} + // expected-error @+1 {{op expected function arguments to match number of operands}} linalg.generic { args_in = 0, args_out = 1, @@ -111,7 +111,7 @@ func @foo(%0: i32) { return } func @generic_mismatched_num_returns(%arg0: memref) { - // expected-error @+1 {{op expected fun results to match number of output views}} + // expected-error @+1 {{op expected function results(0) to match number of outputs(1)}} linalg.generic { args_in = 0, args_out = 1, @@ -123,6 +123,36 @@ // ----- +func @foo(%0: i32, %1: i32, %2: i32) { return } + +func @generic_mismatched_num_returns(%0: memref, %1: memref) { + // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of operand 2}} + linalg.generic { + args_in = 3, + args_out = 0, + fun = @foo, + indexing_maps = [ affine_map<() -> (0)> ], + iterator_types = [] + } %0, %1, %1: memref, memref, memref +} + +// ----- + +func @foo(%0: i32, %1: i32, %2: f32) -> i32 { return %1: i32} + +func @generic_mismatched_num_returns(%0: memref, %1: memref) { + // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'f32' of output 1}} + linalg.generic { + args_in = 2, + args_out = 1, + fun = @foo, + indexing_maps = [ affine_map<() -> (0)> ], + iterator_types = [] + } %0, %0, %1: memref, memref, memref +} + +// ----- + func @foo(%0: i32) -> i32 { return %0: i32 } func @generic_symbol_in_map(%arg0: memref) { @@ -189,7 +219,7 @@ } func @generic_fun_arg_0_element_type(%arg0: memref(off + i)>>) { - // expected-error @+1 {{op expected fun argument 0 of the same type as elemental type 'f32' of view 0}} + // expected-error @+1 {{op expected function argument 1 of the same type as elemental type 'f32' of operand 1}} linalg.generic { args_in = 0, args_out = 1, @@ -207,7 +237,7 @@ } func @generic_fun_result_0_element_type(%arg0: memref(off + i)>>) { - // expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'f32' of view 0}} + // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'f32' of output 1}} linalg.generic { args_in = 0, args_out = 1, @@ -257,7 +287,7 @@ // ----- func @generic_mismatched_num_arguments(%arg0: memref) { - // expected-error @+1 {{op expected number of block arguments to match number of views}} + // expected-error @+1 {{op expected number of block arguments to match number of operands}} linalg.generic { args_in = 0, args_out = 1, @@ -271,7 +301,7 @@ // ----- func @generic_block_arg_type(%arg0: memref) { - // expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: 'memref'}} + // 
expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output operand: 'memref'}} linalg.generic { args_in = 0, args_out = 1, @@ -285,7 +315,7 @@ // ----- func @indexed_generic_block_arg_count(%arg0: memref) { - // expected-error @+1 {{op expected number of block arguments to match number of views + number of loops}} + // expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -299,7 +329,7 @@ // ----- func @indexed_generic_block_induction_var_arg_type(%arg0: memref) { - // expected-error @+1 {{op expected block argument 0 to be of IndexType}} + // expected-error @+1 {{op expected block argument 1 to be an index}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -313,7 +343,7 @@ // ----- func @indexed_generic_block_arg_type(%arg0: memref) { - // expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output view: 'memref'}} + // expected-error @+1 {{op expected block argument 2 of the same type as elemental type of output operand: 'memref'}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -330,7 +360,7 @@ return %f : f32 } func @indexed_generic_fun_arg_count(%arg0: memref) { - // expected-error @+1 {{op expected fun arguments to match number of views + number of loops}} + // expected-error @+1 {{op expected function arguments to match number of loops + number of operands}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -346,7 +376,7 @@ return %val : f32 } func @indexed_generic_fun_induction_var_arg_type(%arg0: memref) { - // expected-error @+1 {{op expected fun argument 0 to be of IndexType}} + // expected-error @+1 {{op expected function argument 1 to be an index}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -362,7 +392,7 @@ return %val : i1 } func @indexed_generic_fun_arg_type(%arg0: memref) { - // expected-error @+1 {{op expected fun argument 1 of the same type as elemental type 'f32' of view 0}} + // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of input 1}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -378,7 +408,7 @@ return %val, %val : i1, i1 } func @indexed_generic_fun_result_count(%arg0: memref) { - // expected-error @+1 {{op expected fun results to match number of output views}} + // expected-error @+1 {{op expected function results to match number of outputs}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -395,7 +425,7 @@ return %val_float : f32 } func @indexed_generic_fun_result_count(%arg0: memref) { - // expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'i32' of view 0}} + // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'i32' of output 1}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -408,7 +438,7 @@ // ----- func @generic_fun_result_0_element_type(%arg0: memref(off + i)>>) { - // expected-error @+9 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}} + // expected-error @+9 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}} linalg.generic { args_in = 0, args_out = 1, @@ -438,36 +468,6 @@ // ----- -func @generic_result_tensor_count(%arg0: memref(off + i)>>) { - // expected-error @+1 {{op expected #output tensor operands (0) to match #results (1)}} - %0 = linalg.generic { - args_in = 0, - args_out = 1, - indexing_maps = [ affine_map<(i) -> (i)> ], - 
iterator_types = ["parallel"] - } %arg0 { - ^bb(%i: f32): - linalg.yield %i: f32 - }: memref(off + i)>> -> tensor -} - -// ----- - -func @generic_result_tensor_type(%arg0: tensor) { - // expected-error @+1 {{op result #0 must be 'tensor', but got 'tensor'}} - %0 = linalg.generic { - args_in = 0, - args_out = 1, - indexing_maps = [ affine_map<(i) -> (i)> ], - iterator_types = ["parallel"] - } %arg0 { - ^bb(%i: f32): - linalg.yield %i: f32 - }: tensor -> tensor -} - -// ----- - func @generic_fun_result_0_element_type(%arg0: memref) { // expected-error @+1 {{'linalg.dot' op expected 3 or more operands}} linalg.dot(%arg0, %arg0): memref, memref diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -157,23 +157,23 @@ // CHECK-LABEL: func @generic_with_tensor_input // CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor>, memref -func @generic_with_tensor_output(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: tensor) -> (tensor) { - %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : memref, offset: ?, strides: [?, 1]>, tensor -> tensor - return %0 : tensor +#trait2 = { + args_in = 2, + args_out = 1, + indexing_maps = #accesses, + iterator_types = ["parallel", "parallel", "parallel"], + fun = @foo, + library_call = "some_external_function_name_1" } -// CHECK-LABEL: func @generic_with_tensor_output -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref, #[[strided2D]]>, tensor -> tensor -// CHECK: return {{.*}} : tensor - func @generic_with_tensor_input_and_output(%arg0: tensor>, %arg1: tensor) -> (tensor) { - %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : tensor>, tensor -> tensor + %0 = linalg.generic #trait2 %arg0, %arg1 {foo = 1} : tensor>, tensor -> tensor return %0 : tensor } // CHECK-LABEL: func @generic_with_tensor_input_and_output -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor>, tensor -> tensor +// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor>, tensor -> tensor // CHECK: return {{.*}} : tensor -#trait2 = { +#trait3 = { args_in = 1, args_out = 1, indexing_maps = #accesses, @@ -181,7 +181,7 @@ library_call = "some_external_function_name_2" } func @generic_region(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { - linalg.generic #trait2 %arg0, %arg1 { + linalg.generic #trait3 %arg0, %arg1 { ^bb(%a: vector<3x4xi4>, %b: f32) : linalg.yield %b : f32 } {foo = 1}: memref, offset: ?, strides: [?, 1]>, memref @@ -194,7 +194,7 @@ // CHECK: } {foo = 1 : i64}: memref, #[[strided2D]]>, memref func @indexed_generic(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { - linalg.indexed_generic #trait2 %arg0, %arg1 { + 
linalg.indexed_generic #trait3 %arg0, %arg1 { ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) : linalg.yield %b : f32 } {foo = 1}: memref, offset: ?, strides: [?, 1]>, memref
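As a closing illustration of the accounting rule the new verifiers and the round-trip tests above rely on (operands = inputs plus output buffers, results = trailing ranked-tensor outputs), here is a small C++ sketch written against the interface from this patch. It is not part of the change; the function name `checkOutputAccounting` is hypothetical, and the checks merely restate the documented invariants.

```c++
// Illustrative sketch only -- not part of the patch. Restates the
// operand/result accounting for linalg.generic / linalg.indexed_generic
// after this change. The helper name `checkOutputAccounting` is hypothetical.
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::linalg;

static LogicalResult checkOutputAccounting(LinalgOp op) {
  Operation *operation = op.getOperation();
  // args_out is split between trailing memref operands (output buffers) and
  // ranked tensor results, so results can never outnumber the outputs.
  unsigned numTensorResults = operation->getNumResults();
  if (numTensorResults > op.getNumOutputs())
    return failure();
  // Operands cover the inputs (tensor or buffer) plus the output buffers.
  if (operation->getNumOperands() < op.getNumInputsAndOutputBuffers())
    return failure();
  // Results, when present, are always ranked tensors.
  for (Type t : operation->getResultTypes())
    if (!t.isa<RankedTensorType>())
      return failure();
  return success();
}
```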