diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -317,7 +317,7 @@ /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!$_op.isOutputTensor(opOperand)) + if (!$_op.isOutput(opOperand)) return false; return payloadUsesValueFromOperand(opOperand); }] @@ -613,7 +613,13 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getInputAndOutputOperands(); + OpOperandVector result; + result.reserve($_op.getNumOperands()); + llvm::transform( + this->getOperation()->getOpOperands(), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; }] >, //===------------------------------------------------------------------===// @@ -691,13 +697,8 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ SmallVector res; - // MLIR currently does not support dependent interfaces or interface - // inheritance. By construction all ops with StructuredOpInterface must - // implement DestinationStyleOpInterface. - // TODO: reevalute the need for a cast when a better mechanism exists. - auto iface = cast(*this->getOperation()); - for (OpOperand *opOperand : iface.getInputAndOutputOperands()) - llvm::append_range(res, getShape(opOperand)); + for (OpOperand &opOperand : this->getOperation()->getOpOperands()) + llvm::append_range(res, getShape(&opOperand)); return res; }] >, @@ -789,9 +790,8 @@ // TODO: reevalute the need for a cast when a better mechanism exists. //========================================================================// - ValueRange getInputs() { - return cast(*this->getOperation()) - .getInputs(); + OperandRange getInputs() { + return cast(*this->getOperation()).getInputs(); } int64_t getNumInputs() { @@ -799,7 +799,7 @@ .getNumInputs(); } - ValueRange getOutputs() { + OperandRange getOutputs() { return cast(*this->getOperation()) .getOutputs(); } @@ -809,12 +809,7 @@ .getNumOutputs(); } - int64_t getNumInputsAndOutputs() { - return cast(*this->getOperation()) - .getNumInputsAndOutputs(); - } - - OpOperandVector getInputOperands() { + MutableArrayRef getInputOperands() { return cast(*this->getOperation()) .getInputOperands(); } @@ -824,17 +819,12 @@ .getInputOperand(i); } - OpOperandVector getInputBufferOperands() { - return cast(*this->getOperation()) - .getInputBufferOperands(); - } - - OpOperandVector getInputTensorOperands() { + void setOutputOperand(int64_t i, Value value) { return cast(*this->getOperation()) - .getInputTensorOperands(); + .setOutputOperand(i, value); } - OpOperandVector getOutputOperands() { + MutableArrayRef getOutputOperands() { return cast(*this->getOperation()) .getOutputOperands(); } @@ -844,44 +834,14 @@ .getOutputOperand(i); } - void setOutputOperand(int64_t i, Value value) { - return cast(*this->getOperation()) - .setOutputOperand(i, value); - } - - OpOperandVector getOutputBufferOperands() { - return cast(*this->getOperation()) - .getOutputBufferOperands(); - } - - OpOperandVector getOutputTensorOperands() { + bool isInput(OpOperand *opOperand) { return cast(*this->getOperation()) - .getOutputTensorOperands(); + .isInput(opOperand); } - SmallVector getOutputBufferTypes() { + bool isOutput(OpOperand *opOperand) { return cast(*this->getOperation()) - .getOutputBufferTypes(); - } - - SmallVector getOutputTensorTypes() { - return cast(*this->getOperation()) - .getOutputTensorTypes(); - } - - OpOperandVector getInputAndOutputOperands() { - return cast(*this->getOperation()) - .getInputAndOutputOperands(); - } - - bool isInputTensor(OpOperand *opOperand) { - return cast(*this->getOperation()) - .isInputTensor(opOperand); - } - - bool isOutputTensor(OpOperand *opOperand) { - return cast(*this->getOperation()) - .isOutputTensor(opOperand); + .isOutput(opOperand); } bool isScalar(OpOperand *opOperand) { @@ -938,331 +898,177 @@ let verifyWithRegions = 1; } -// The 'DestinationStyleOpInterface' provides access to the methods relevant -// for destination-style ops. A destination-style operation has 'n' input -// arguments and 'm' output arguments. Each op that wants to implement -// DestinationStyleOpInterface needs to define getInputs() and getOutputs() -// methods. +// Ops that are in destination style have designated output operands, which act +// as initial tensor values for the results of the operation or the output +// buffers to which the results of the op will be written. +// +// Output operands must be tensors or memrefs. Input operands can have any +// type. All non-output operands are inputs. + +// It is assumed that the inputs of the op are the operands at position [0; +// getNumOperands() - getNumOutputs()). The outputs of the op are the operands +// at position [getNumOperands() - getNumOutputs(); getNumOperands()). In other +// words, all input operands come first. + +// If the op has "tensor semantics", then the input operands are either scalars +// or tensors. The output operands are tensors and every tensor output is tied +// to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output +// tensor is tied to the i-th OpResult. The op may not have any additional +// OpResults. Output operands and their tied OpResults have the same type. +// +// If the op has "buffer semantics", then the input operands are either memrefs +// or other non-tensor types, e.g. scalar types. Furthermore, the output +// operands are memrefs and the op has no results. +// +// Destination-passing style abstraction makes certain transformations easier. +// For example, tiling implementation can extract/insert slices from/into the +// destination of an op and use the resulting shaped value as an iter_arg in +// the surrounding loop structure. As another example, bufferization does not +// have to allocate new buffers for destinations (in case of in-place +// bufferization) and can directly reuse the existing destination buffer. +// +// Example of a destination style op: `%r = tensor.insert_slice %t into %d`, +// where `%t` is the single input and `%d` is the single output. `%d` is tied +// to `%r`. +// +// Example of an op that is not in destination style: `%r = tensor.pad %t`. +// This op is not in destination style because `%r` and `%t` have different +// shape. +// +// Each op that wants to implement DestinationStyleOpInterface needs to define +// the getNumOutputs() method. def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { let cppNamespace = "::mlir::linalg"; let methods = [ //===------------------------------------------------------------------===// - // Num input/output arguments handling. + // Operands handling. //===------------------------------------------------------------------===// - // `getInputs` must be defined by each op that wants to implement the - // DestinationStyleOpInterface. - InterfaceMethod< - /*desc=*/[{ - Return the input shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"getInputs", - /*args=*/(ins) - >, - // These special methods rely on `getInputs` and `getOutputs` being defined - // by each op that wants to implement the DestinationStyleOpInterface. - InterfaceMethod< - /*desc=*/[{ - Return the number of inputs. - }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op.getInputs().size(); - }] - >, - // `getOutputs` must be defined by each op that wants to implement the - // DestinationStyleOpInterface. - InterfaceMethod< - /*desc=*/[{ - Return the output shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"getOutputs", - /*args=*/(ins) - >, - InterfaceMethod< - /*desc=*/[{ - Return the number of outputs. - }], + // This method has to be defined for every DPS op. The default + // implementation returns 1. For DPS ops with multiple results this method + // has to be overriden. + // The operand list is assumed to start with the input operands and end + // with the output operands. Therefore, all methods to access the inputs + // and outputs can be expressed if the number of output operands is know. + InterfaceMethod< + /*desc=*/[{ Return the number of outputs. }], /*retTy=*/"int64_t", /*methodName=*/"getNumOutputs", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op.getOutputs().size(); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the number of inputs and outputs. - }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputsAndOutputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return this->getOperation()->getNumOperands(); - }] - >, - //===------------------------------------------------------------------===// - // Input operands handling. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/[{ - Return the input operands. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - int64_t numInputs = getNumInputs(); - OpOperandVector result; - result.reserve(numInputs); - llvm::transform( - this->getOperation()->getOpOperands().take_front(numInputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the `i`-th input operand. - }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getInputOperand", - /*args=*/(ins "int64_t":$i), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumInputs()); - return &this->getOperation()->getOpOperand(i); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the subset of input operands that are of buffer type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputBufferOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; - }] + /*defaultImplementation=*/"return $_op.getNumOutputs();" >, InterfaceMethod< - /*desc=*/[{ - Return the subset of input operands that are of tensor type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputTensorOperands", + /*desc=*/"Return the output shape operands.", + /*retTy=*/"OperandRange", + /*methodName=*/"getOutputs", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + MutableArrayRef operands = getOutputOperands(); + return OperandRange(operands.data(), operands.size()); }] >, - //===------------------------------------------------------------------===// - // Output operands handling. - //===------------------------------------------------------------------===// InterfaceMethod< - /*desc=*/[{ - Return the output operands. - }], - /*retTy=*/"OpOperandVector", + /*desc=*/"Return the output operands.", + /*retTy=*/"MutableArrayRef", /*methodName=*/"getOutputOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numOutputs = getNumOutputs(); - OpOperandVector result; - result.reserve(numOutputs); - llvm::transform( - this->getOperation()->getOpOperands() - .take_back(numOutputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; + return MutableArrayRef(&this->getOperation()->getOpOperand(getNumInputs()), + $_op.getNumOutputs()); }] >, InterfaceMethod< - /*desc=*/[{ - Return the `i`-th output operand. - }], + /*desc=*/"Return the `i`-th output operand.", /*retTy=*/"OpOperand*", /*methodName=*/"getOutputOperand", /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); + assert(i >= 0 && i < $_op.getNumOutputs()); return &this->getOperation()->getOpOperand(getNumInputs() + i); }] >, InterfaceMethod< - /*desc=*/[{ - Set the `i`-th output operand. - }], + /*desc=*/"Set the `i`-th output operand.", /*retTy=*/"void", /*methodName=*/"setOutputOperand", /*args=*/(ins "int64_t":$i, "Value":$value), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); + assert(i >= 0 && i < $_op.getNumOutputs()); this->getOperation()->setOperand(getNumInputs() + i, value); }] >, InterfaceMethod< - /*desc=*/[{ - Return the subset of output operands that are of buffer type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputBufferOperands", + /*desc=*/"Return the number of inputs.", + /*retTy=*/"int64_t", + /*methodName=*/"getNumInputs", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + return this->getOperation()->getNumOperands() - $_op.getNumOutputs(); }] >, InterfaceMethod< - /*desc=*/[{ - Return the subset of output operands that are of tensor type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputTensorOperands", + /*desc=*/"Return the input operands.", + /*retTy=*/"OperandRange", + /*methodName=*/"getInputs", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + MutableArrayRef operands = getInputOperands(); + return OperandRange(operands.data(), operands.size()); }] >, InterfaceMethod< - /*desc=*/[{ - Return the types of the subset of output operands that are of buffer type. - }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputBufferTypes", + /*desc=*/"Return the input operands.", + /*retTy=*/"MutableArrayRef", + /*methodName=*/"getInputOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputBufferOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; + return MutableArrayRef(&this->getOperation()->getOpOperand(0), + getNumInputs()); }] >, InterfaceMethod< - /*desc=*/[{ - Return the types of the subset of output operands that are of tensor type. - }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputTensorTypes", - /*args=*/(ins), + /*desc=*/[{ Return the `i`-th input operand. }], + /*retTy=*/"OpOperand*", + /*methodName=*/"getInputOperand", + /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputTensorOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; + assert(i >= 0 && i < getNumInputs()); + return &this->getOperation()->getOpOperand(i); }] >, //===------------------------------------------------------------------===// // Input and Output arguments handling. //===------------------------------------------------------------------===// InterfaceMethod< - /*desc=*/[{ - Return the range over input and output operands. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputAndOutputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - int64_t numInputsAndOutputs = getNumInputsAndOutputs(); - OpOperandVector result; - result.reserve(numInputsAndOutputs); - llvm::transform( - this->getOperation()->getOpOperands(), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return true if `opOperand` is an input tensor. - }], + /*desc=*/"Return true if `opOperand` is an input.", /*retTy=*/"bool", - /*methodName=*/"isInputTensor", + /*methodName=*/"isInput", /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() < $_op.getNumInputs()) - return true; - return false; + return opOperand->getOperandNumber() < $_op.getNumInputs(); }] >, InterfaceMethod< - /*desc=*/[{ - Return true if `opOperand` is an output tensor. - }], + /*desc=*/"Return true if `opOperand` is an output.", /*retTy=*/"bool", - /*methodName=*/"isOutputTensor", + /*methodName=*/"isOutput", /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() >= $_op.getNumInputs()) - return true; - return false; + return opOperand->getOperandNumber() >= $_op.getNumInputs(); }] >, InterfaceMethod< - /*desc=*/[{ - Return true if the `opOperand` is a scalar value. - }], + /*desc=*/"Return true if the `opOperand` is a scalar value.", /*retTy=*/"bool", /*methodName=*/"isScalar", /*args=*/(ins "OpOperand*":$opOperand), @@ -1273,9 +1079,7 @@ }] >, InterfaceMethod< - /*desc=*/[{ - Return the result tied to `opOperand`. - }], + /*desc=*/"Return the result tied to `opOperand`.", /*retTy=*/"OpResult", /*methodName=*/"getTiedOpResult", /*args=*/(ins "OpOperand*":$opOperand), @@ -1292,9 +1096,7 @@ // Other interface methods. //===------------------------------------------------------------------===// InterfaceMethod< - /*desc=*/[{ - Return whether the op has only MemRef input and outputs. - }], + /*desc=*/"Return whether the op has only MemRef input and outputs.", /*retTy=*/"bool", /*methodName=*/"hasBufferSemantics", /*args=*/(ins), @@ -1309,9 +1111,7 @@ }] >, InterfaceMethod< - /*desc=*/[{ - Return whether the op has only RankedTensor input and outputs. - }], + /*desc=*/"Return whether the op has only RankedTensor input and outputs.", /*retTy=*/"bool", /*methodName=*/"hasTensorSemantics", /*args=*/(ins), diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -215,6 +215,7 @@ getRegionBuilder() { return nullptr; } + int64_t getNumOutputs() { return getOutputs().size(); } }]; let hasCanonicalizer = 1; @@ -275,14 +276,12 @@ } // Implement functions necessary for DestinationStyleOpInterface. - mlir::ValueRange getOutputs() { return getInits(); } - unsigned getNumInputs() { return getInputs().size(); }; - unsigned getNumOutputs() { return getInits().size(); }; static std::function)> getRegionBuilder() { return nullptr; } + int64_t getNumOutputs() { return getInits().size(); } }]; let hasCustomAssemblyFormat = 1; diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -29,9 +29,9 @@ SmallVector argTypes; SmallVector argLocs; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType())); - argLocs.push_back(opOperand->get().getLoc()); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + argTypes.push_back(getElementTypeOrSelf(opOperand.get().getType())); + argLocs.push_back(opOperand.get().getLoc()); } ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -165,48 +165,58 @@ LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation() << " and " << *dst.getOperation() << "\n"); if (src.hasTensorSemantics() && dst.hasTensorSemantics()) { - for (OpOperand *dstOpOperand : dst.getInputOperands()) { + for (OpOperand &dstOpOperand : dst.getInputOperands()) { + if (!dstOpOperand.get().getType().isa()) + continue; // Check if the operand is defined by the src. - auto definingOp = dstOpOperand->get().getDefiningOp(); + auto definingOp = dstOpOperand.get().getDefiningOp(); if (definingOp && definingOp == src) - addDependenceElem(DependenceType::RAW, dstOpOperand->get(), - dstOpOperand); + addDependenceElem(DependenceType::RAW, dstOpOperand.get(), + &dstOpOperand); } - for (OpOperand *dstOpOperand : dst.getOutputOperands()) { + for (OpOperand &dstOpOperand : dst.getOutputOperands()) { // Check if the operand is defined by the src. - auto definingOp = dstOpOperand->get().getDefiningOp(); + auto definingOp = dstOpOperand.get().getDefiningOp(); if (definingOp && definingOp == src) { - if (dst.isInitTensor(dstOpOperand)) { - addDependenceElem(DependenceType::RAW, dstOpOperand->get(), - dstOpOperand); + if (dst.isInitTensor(&dstOpOperand)) { + addDependenceElem(DependenceType::RAW, dstOpOperand.get(), + &dstOpOperand); } - addDependenceElem(DependenceType::WAW, dstOpOperand->get(), - dstOpOperand); + addDependenceElem(DependenceType::WAW, dstOpOperand.get(), + &dstOpOperand); } } return; } assert(src.hasBufferSemantics() && dst.hasBufferSemantics() && "unhandled dependence tracking for mixed buffer/tensor operations"); - for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W + for (OpOperand &srcOpOperand : src.getOutputOperands()) { // W // RAW graph - for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R - if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias - addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand); + for (OpOperand &dstOpOperand : dst.getInputOperands()) { // R + if (!dstOpOperand.get().getType().isa()) + continue; + if (aliases.alias(srcOpOperand.get(), dstOpOperand.get())) // RAW alias + addDependenceElem(DependenceType::RAW, &srcOpOperand, &dstOpOperand); + } // WAW graph - for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W - if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias - addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand); + for (OpOperand &dstOpOperand : dst.getOutputOperands()) // W + if (aliases.alias(srcOpOperand.get(), dstOpOperand.get())) // WAW alias + addDependenceElem(DependenceType::WAW, &srcOpOperand, &dstOpOperand); } - for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R + for (OpOperand &srcOpOperand : src.getInputOperands()) { // R + if (!srcOpOperand.get().getType().isa()) + continue; // RAR graph - for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R - if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias - addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand); + for (OpOperand &dstOpOperand : dst.getInputOperands()) { // R + if (!dstOpOperand.get().getType().isa()) + continue; + if (aliases.alias(srcOpOperand.get(), dstOpOperand.get())) // RAR alias + addDependenceElem(DependenceType::RAR, &srcOpOperand, &dstOpOperand); + } // WAR graph - for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W - if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias - addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand); + for (OpOperand &dstOpOperand : dst.getOutputOperands()) // W + if (aliases.alias(srcOpOperand.get(), dstOpOperand.get())) // WAR alias + addDependenceElem(DependenceType::WAR, &srcOpOperand, &dstOpOperand); } } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -31,10 +31,10 @@ bool linalg::detail::canOpOperandsBeDroppedImpl( linalg::LinalgOp linalgOp, ArrayRef droppedOperands) { SmallVector indexingMaps; - for (auto *opOperand : linalgOp.getInputAndOutputOperands()) { - if (llvm::is_contained(droppedOperands, opOperand)) + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + if (llvm::is_contained(droppedOperands, &opOperand)) continue; - indexingMaps.push_back(linalgOp.getMatchingIndexingMap(opOperand)); + indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand)); } return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); } @@ -491,9 +491,9 @@ SmallVector LinalgOp::createFlatListOfOperandDims(OpBuilder &b, Location loc) { SmallVector res; - for (OpOperand *opOperand : getInputAndOutputOperands()) { - for (int64_t i = 0, e = getRank(opOperand); i < e; ++i) - res.push_back(createFoldedDimOp(b, loc, opOperand->get(), i)); + for (OpOperand &opOperand : getOperation()->getOpOperands()) { + for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i) + res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i)); } return res; } @@ -501,8 +501,8 @@ SmallVector LinalgOp::createFlatListOfOperandStaticDims() { SmallVector res; assert(!hasDynamicShape() && "expected operands to have static shapes"); - for (OpOperand *opOperand : getInputAndOutputOperands()) - llvm::append_range(res, getShape(opOperand)); + for (OpOperand &opOperand : getOperation()->getOpOperands()) + llvm::append_range(res, getShape(&opOperand)); return res; } @@ -563,10 +563,10 @@ getResultsPositionInLoopsToShapeMap(LinalgOp &op) { int64_t inputRankSum = 0; int64_t outputRankSum = 0; - for (OpOperand *input : op.getInputOperands()) - inputRankSum += op.getRank(input); - for (OpOperand *output : op.getOutputOperands()) - outputRankSum += op.getRank(output); + for (OpOperand &input : op.getInputOperands()) + inputRankSum += op.getRank(&input); + for (OpOperand &output : op.getOutputOperands()) + outputRankSum += op.getRank(&output); return {inputRankSum, inputRankSum + outputRankSum}; } @@ -609,11 +609,11 @@ createFlatListOfOperandDims(b, loc)); int64_t pos = 0; ArrayRef shapeExprs = resultShapesFromInputShapesMap.getResults(); - for (OpOperand *opOperand : getOutputOperands()) { + for (OpOperand &opOperand : getOutputOperands()) { SmallVector shapes; - for (int64_t dim : llvm::seq(0, getRank(opOperand))) { + for (int64_t dim : llvm::seq(0, getRank(&opOperand))) { if (checkDimExpr.visit(shapeExprs[pos])) - shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim)); + shapes.push_back(createOrFoldDimOp(b, loc, opOperand.get(), dim)); else shapes.push_back( getValueOrCreateConstantIndexOp(b, loc, allResultDimValues[pos])); @@ -644,32 +644,32 @@ // All input/output operands must be indexed. if (static_cast(linalgOp.getIndexingMapsArray().size()) != - linalgOp.getNumInputsAndOutputs()) + linalgOp->getNumOperands()) return op->emitOpError("expected the number of indexing_map (") << linalgOp.getIndexingMapsArray().size() << ") to be equal to the number of input/output operands (" - << linalgOp.getNumInputsAndOutputs() << ")"; + << linalgOp->getNumOperands() << ")"; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); // Symbols disallowed. if (indexingMap.getNumSymbols() != 0) return op->emitOpError("unexpected symbols in indexing_map #") - << opOperand->getOperandNumber(); + << opOperand.getOperandNumber(); // Domain must be consistent. unsigned numLoops = linalgOp.getNumLoops(); if (indexingMap.getNumDims() != numLoops) return op->emitOpError("expected indexing_map #") - << opOperand->getOperandNumber() << " to have " << numLoops + << opOperand.getOperandNumber() << " to have " << numLoops << " dim(s) to match the number of loops"; - int64_t rank = linalgOp.getRank(opOperand); + int64_t rank = linalgOp.getRank(&opOperand); if (indexingMap.getNumResults() != rank) return op->emitOpError("expected operand rank (") << rank << ") to match the result rank of indexing_map #" - << opOperand->getOperandNumber() << " (" + << opOperand.getOperandNumber() << " (" << indexingMap.getNumResults() << ")"; } @@ -688,13 +688,13 @@ if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) { for (int64_t &range : endLoopRangeValues) range -= 1; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); SmallVector startIndices = indexingMap.compose(startLoopRangeValues); SmallVector endIndices = indexingMap.compose(endLoopRangeValues); - ArrayRef shape = linalgOp.getShape(opOperand); + ArrayRef shape = linalgOp.getShape(&opOperand); for (auto dim : llvm::seq(0, shape.size())) { // Ignore dynamic dimension or the case that the dimension size is 0 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0) @@ -725,17 +725,16 @@ if (indexingMap.getResult(dim).dyn_cast()) { if (inferredDimSize != shape[dim]) { return op->emitOpError("inferred input/output operand #") - << opOperand->getOperandNumber() - << " has shape's dimension #" << dim << " to be " - << inferredDimSize << ", but found " << shape[dim]; + << opOperand.getOperandNumber() << " has shape's dimension #" + << dim << " to be " << inferredDimSize << ", but found " + << shape[dim]; } } else { if (inferredDimSize > shape[dim]) { return op->emitOpError("inferred input/output operand #") - << opOperand->getOperandNumber() - << " has shape's dimension #" << dim - << " to be greater than or equal to " << inferredDimSize - << ", but found " << shape[dim]; + << opOperand.getOperandNumber() << " has shape's dimension #" + << dim << " to be greater than or equal to " + << inferredDimSize << ", but found " << shape[dim]; } } } @@ -755,7 +754,7 @@ // not used). Block &block = linalgOp->getRegion(0).front(); - if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments()) + if (linalgOp->getNumOperands() != block.getNumArguments()) return op->emitOpError("expected as many non-induction variable region " "arguments as the number of input/output operands"); @@ -781,30 +780,36 @@ // This means an op that constructs a tensor out of indices cannot be a // LinalgOp at the moment. For now this will have to be a special op until we // have output shape operands that are not tensors. - int64_t numInputs = dstStyleOp.getNumInputs(); int64_t numOutputs = dstStyleOp.getNumOutputs(); if (numOutputs == 0) return op->emitOpError("expected at least one output operand"); - if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) - return failure(); - // Verify the number of results matches the number of output tensors. - if (op->getNumResults() != dstStyleOp.getOutputTensorOperands().size()) - return op->emitOpError("expected the number of results (") - << op->getNumResults() - << ") to be equal to the number of output tensors (" - << dstStyleOp.getOutputTensorOperands().size() << ")"; + + SmallVector outputBufferOperands, outputTensorOperands; + for (OpOperand &operand : dstStyleOp.getOutputOperands()) { + Type type = operand.get().getType(); + if (type.isa()) + outputBufferOperands.push_back(&operand); + if (type.isa()) + outputTensorOperands.push_back(&operand); + } // Simplifying assumption: either full tensor or full buffer mode. // This allows simpler verification of output operands vs result types // without premature tracking of which operand is what in mixed-mode. // TODO: relax when mixed-mode needs to pass verification. - if (!dstStyleOp.getOutputBufferOperands().empty() && - !dstStyleOp.getOutputTensorOperands().empty()) + if (!outputBufferOperands.empty() && !outputTensorOperands.empty()) return op->emitOpError( "expected output operands to all have tensor type or " "all have buffer type"); - for (OpOperand *opOperand : dstStyleOp.getOutputTensorOperands()) { + // Verify the number of results matches the number of output tensors. + if (op->getNumResults() != outputTensorOperands.size()) + return op->emitOpError("expected the number of results (") + << op->getNumResults() + << ") to be equal to the number of output tensors (" + << outputTensorOperands.size() << ")"; + + for (OpOperand *opOperand : outputTensorOperands) { OpResult result = dstStyleOp.getTiedOpResult(opOperand); if (result.getType() != opOperand->get().getType()) return op->emitOpError("expected type of operand #") @@ -813,6 +818,5 @@ << " to match type of corresponding result (" << result.getType() << ")"; } - return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -837,10 +837,14 @@ &effects, ValueRange results, ValueRange inputBuffers, ValueRange outputs) { for (Value value : inputBuffers) { + if (!value.getType().isa()) + continue; effects.emplace_back(MemoryEffects::Read::get(), value, SideEffects::DefaultResource::get()); } for (Value value : outputs) { + if (!value.getType().isa()) + continue; effects.emplace_back(MemoryEffects::Read::get(), value, SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), value, @@ -851,10 +855,8 @@ void GenericOp::getEffects( SmallVectorImpl> &effects) { - SmallVector inputBuffers = getInputBufferOperands(); - SmallVector outputBuffers = getOutputBufferOperands(); - getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, - outputBuffers); + getGenericEffectsImpl(effects, getOperation()->getResults(), getInputs(), + getOutputs()); } LogicalResult GenericOp::verify() { return success(); } @@ -887,7 +889,7 @@ // Check if there is any change to operands. if (newInputOperands.size() + newOutputOperands.size() == - static_cast(genericOp.getNumInputsAndOutputs())) + genericOp->getNumOperands()) return failure(); // Create the new op with the body being empty. @@ -939,35 +941,34 @@ SmallVector &newIndexingMaps) const { llvm::SmallDenseMap origToNewPos; llvm::SmallDenseMap, unsigned> dedupedInputs; - for (const auto &inputOpOperand : - llvm::enumerate(genericOp.getInputOperands())) { + for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) { + OpOperand *inputOpOperand = &en.value(); // Check if operand is dead and if dropping the indexing map makes the // loops to shape computation invalid. - if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) { + if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) { // Add the current operands to the list of potentially droppable // operands. If it cannot be dropped, this needs to be popped back. - droppedOpOperands.push_back(inputOpOperand.value()); + droppedOpOperands.push_back(inputOpOperand); if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) continue; droppedOpOperands.pop_back(); } // Check if this operand is a duplicate. - AffineMap indexingMap = - genericOp.getMatchingIndexingMap(inputOpOperand.value()); + AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand); auto it = dedupedInputs.find( - std::make_pair(inputOpOperand.value()->get(), indexingMap)); + std::make_pair(inputOpOperand->get(), indexingMap)); if (it != dedupedInputs.end()) { - origToNewPos[inputOpOperand.index()] = it->second; - droppedOpOperands.push_back(inputOpOperand.value()); + origToNewPos[en.index()] = it->second; + droppedOpOperands.push_back(inputOpOperand); continue; } // This is a preserved argument. - origToNewPos[inputOpOperand.index()] = newInputOperands.size(); - dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] = + origToNewPos[en.index()] = newInputOperands.size(); + dedupedInputs[{inputOpOperand->get(), indexingMap}] = newInputOperands.size(); - newInputOperands.push_back(inputOpOperand.value()->get()); + newInputOperands.push_back(inputOpOperand->get()); newIndexingMaps.push_back(indexingMap); } return origToNewPos; @@ -988,12 +989,11 @@ // If the op doesnt have tensor semantics, keep all the outputs as // preserved. if (!genericOp.hasTensorSemantics()) { - for (const auto &outputOpOperand : - llvm::enumerate(genericOp.getOutputOperands())) { - origToNewPos[outputOpOperand.index()] = newOutputOperands.size(); - newOutputOperands.push_back(outputOpOperand.value()->get()); + for (const auto &en : llvm::enumerate(genericOp.getOutputOperands())) { + origToNewPos[en.index()] = newOutputOperands.size(); + newOutputOperands.push_back(en.value().get()); newIndexingMaps.push_back( - genericOp.getMatchingIndexingMap(outputOpOperand.value())); + genericOp.getMatchingIndexingMap(&en.value())); } } else { // Output argument can be dropped if the result has @@ -1002,23 +1002,22 @@ // - the corresponding indexing maps are not needed for loop bound // computation. auto yieldOp = cast(genericOp.getBody()->getTerminator()); - for (const auto &outputOpOperand : - llvm::enumerate(genericOp.getOutputOperands())) { - Value result = genericOp.getResult(outputOpOperand.index()); + for (const auto &en : llvm::enumerate(genericOp.getOutputOperands())) { + OpOperand *outputOpOperand = &en.value(); + Value result = genericOp.getResult(en.index()); AffineMap indexingMap = - genericOp.getMatchingIndexingMap(outputOpOperand.value()); - auto key = - std::make_tuple(outputOpOperand.value()->get(), indexingMap, - yieldOp->getOperand(outputOpOperand.index())); + genericOp.getMatchingIndexingMap(outputOpOperand); + auto key = std::make_tuple(outputOpOperand->get(), indexingMap, + yieldOp->getOperand(en.index())); // Do not drop an out if its value is used in the payload. - if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) { + if (!genericOp.payloadUsesValueFromOperand(outputOpOperand)) { if (result.use_empty()) { // Check if the opoperand can be dropped without affecting loop // bound computation. Add the operand to the list of dropped op // operand for checking. If it cannot be dropped, need to pop the // value back. - droppedOpOperands.push_back(outputOpOperand.value()); + droppedOpOperands.push_back(outputOpOperand); if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) { continue; } @@ -1032,17 +1031,17 @@ // - The same yield value is used. auto it = dedupedOutpts.find(key); if (it != dedupedOutpts.end()) { - origToNewPos[outputOpOperand.index()] = it->second; - droppedOpOperands.push_back(outputOpOperand.value()); + origToNewPos[en.index()] = it->second; + droppedOpOperands.push_back(outputOpOperand); continue; } } - origToNewPos[outputOpOperand.index()] = newOutputOperands.size(); + origToNewPos[en.index()] = newOutputOperands.size(); dedupedOutpts[key] = newOutputOperands.size(); - newOutputOperands.push_back(outputOpOperand.value()->get()); + newOutputOperands.push_back(outputOpOperand->get()); newIndexingMaps.push_back( - genericOp.getMatchingIndexingMap(outputOpOperand.value())); + genericOp.getMatchingIndexingMap(outputOpOperand)); } } @@ -1064,24 +1063,26 @@ // Replace all arguments in the original op, with arguments from the // canonicalized op. auto updateReplacements = - [&](OpOperandVector &origOperands, OpOperandVector &newOperands, + [&](MutableArrayRef &origOperands, + MutableArrayRef &newOperands, const llvm::SmallDenseMap &map) { for (const auto &origOperand : llvm::enumerate(origOperands)) { auto it = map.find(origOperand.index()); if (it == map.end()) continue; - OpOperand *newOperand = newOperands[it->second]; - replacements[origOperand.value()->getOperandNumber()] = - newOpBlock->getArgument(newOperand->getOperandNumber()); + OpOperand &newOperand = newOperands[it->second]; + replacements[origOperand.value().getOperandNumber()] = + newOpBlock->getArgument(newOperand.getOperandNumber()); } }; - OpOperandVector origInputOperands = genericOp.getInputOperands(); - OpOperandVector newInputOperands = newOp.getInputOperands(); + MutableArrayRef origInputOperands = genericOp.getInputOperands(); + MutableArrayRef newInputOperands = newOp.getInputOperands(); updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos); - OpOperandVector origOutputOperands = genericOp.getOutputOperands(); - OpOperandVector newOutputOperands = newOp.getOutputOperands(); + MutableArrayRef origOutputOperands = + genericOp.getOutputOperands(); + MutableArrayRef newOutputOperands = newOp.getOutputOperands(); updateReplacements(origOutputOperands, newOutputOperands, origOutsToNewOutsPos); @@ -1234,10 +1235,8 @@ void ReduceOp::getEffects( SmallVectorImpl> &effects) { - SmallVector inputBuffers = getInputBufferOperands(); - SmallVector outputBuffers = getOutputBufferOperands(); - getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, - outputBuffers); + getGenericEffectsImpl(effects, getOperation()->getResults(), getInputs(), + getOutputs()); } static ParseResult parseDstStyleOp( @@ -1562,14 +1561,14 @@ LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { + for (OpOperand &opOperand : op->getOpOperands()) { // Linalg "inputs" may be either tensor or memref type. // tensor<0xelt_type> is a convention that may not always mean // "0 iterations". Only erase in cases we see memref<...x0x...>. - auto mt = opOperand->get().getType().dyn_cast(); + auto mt = opOperand.get().getType().dyn_cast(); if (!mt) continue; - if (llvm::is_contained(op.getShape(opOperand), 0)) { + if (llvm::is_contained(op.getShape(&opOperand), 0)) { rewriter.eraseOp(op); return success(); } @@ -1585,10 +1584,10 @@ PatternRewriter &rewriter) const override { // If no operand comes from a tensor::CastOp and can be folded then fail. bool hasTensorCastOperand = - llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { - if (opOperand->get().isa()) + llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) { + if (opOperand.get().isa()) return false; - auto castOp = opOperand->get().getDefiningOp(); + auto castOp = opOperand.get().getDefiningOp(); return castOp && canFoldIntoConsumerOp(castOp); }); if (!hasTensorCastOperand) @@ -1599,18 +1598,17 @@ SmallVector newOperands; newOperands.reserve(op->getNumOperands()); // Inputs may fold. - for (OpOperand *opOperand : op.getInputOperands()) { - auto tensorCastOp = opOperand->get().getDefiningOp(); + for (Value input : op.getInputs()) { + auto tensorCastOp = input.getDefiningOp(); newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.getSource() - : opOperand->get()); + : input); } // Init tensors may fold, in which case the resultType must also change. - for (OpOperand *opOperand : op.getOutputOperands()) { - auto tensorCastOp = opOperand->get().getDefiningOp(); + for (Value output : op.getOutputs()) { + auto tensorCastOp = output.getDefiningOp(); bool fold = canFoldIntoConsumerOp(tensorCastOp); - newOperands.push_back(fold ? tensorCastOp.getOperand() - : opOperand->get()); + newOperands.push_back(fold ? tensorCastOp.getOperand() : output); newResultTypes.push_back(newOperands.back().getType()); } // Clone op. @@ -1669,8 +1667,8 @@ OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); Value newOperand = rewriter.create(loc, resultType, outOperand->get()); - SmallVector newOperands = linalgOp.getInputOperands(); - SmallVector outputOperands = linalgOp.getOutputOperands(); + SmallVector newOperands = linalgOp.getInputs(); + SmallVector outputOperands = linalgOp.getOutputs(); outputOperands[resultNumber] = newOperand; newOperands.append(outputOperands.begin(), outputOperands.end()); @@ -1693,14 +1691,14 @@ /// For each of the operand in `operands` this function maps the static sizes of /// dimensions to their affine dim expressions. -static void populateMap(LinalgOp linalgOp, ArrayRef operands, +static void populateMap(LinalgOp linalgOp, MutableArrayRef operands, llvm::DenseMap &affineExprToSize) { - for (OpOperand *opOperand : operands) { - if (linalgOp.isScalar(opOperand)) + for (OpOperand &opOperand : operands) { + if (linalgOp.isScalar(&opOperand)) continue; - Value src = opOperand->get(); + Value src = opOperand.get(); auto sourceType = src.getType().cast(); - auto sourceMap = linalgOp.getMatchingIndexingMap(opOperand); + auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand); // Get the `sourceShape` of the `sourceType`. If the operand is a result of // `tensor.cast` operation and source of the cast operation has a static @@ -1743,7 +1741,7 @@ return; auto sourceType = src.getType().cast(); Type resultType = sourceType; - if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) { + if (sourceType.hasStaticShape() && linalgOp.isOutput(opOperand)) { resultTypes.push_back(resultType); return; } @@ -1776,7 +1774,7 @@ unsigned index = opOperand->getOperandNumber(); newOperands[index] = newOperand; } - if (linalgOp.isOutputTensor(opOperand)) + if (linalgOp.isOutput(opOperand)) resultTypes.push_back(resultType); } @@ -1803,8 +1801,7 @@ // For each of the affine dim expression, check if the size is known. If // known add that in the map. - populateMap(linalgOp, linalgOp.getInputAndOutputOperands(), - affineExprToSize); + populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize); SmallVector newOperands; SmallVector resultTypes; @@ -1812,12 +1809,12 @@ // `changeNeeded` is `false` if the operands of `linalgOp` require no // change in their types. bool changeNeeded = false; - newOperands.reserve(linalgOp.getNumInputsAndOutputs()); + newOperands.reserve(linalgOp->getNumOperands()); resultTypes.reserve(linalgOp.getNumOutputs()); // Iterate over all the operands and update the static sizes. - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - createNewOperandWithStaticSizes(loc, rewriter, opOperand, + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + createNewOperandWithStaticSizes(loc, rewriter, &opOperand, affineExprToSize, linalgOp, newOperands, resultTypes, changeNeeded); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp @@ -112,16 +112,16 @@ tileSizes[position] = sliceOp.getMixedSizes()[result.index()]; } - SmallVector valuesToTile = linalgOp.getInputAndOutputOperands(); + SmallVector valuesToTile = linalgOp->getOperands(); SmallVector tiledOperands = makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile, tileOffsets, tileSizes, sizeBounds, /*omitPartialTileCheck=*/true); SmallVector resultTensorTypes; - for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) + for (OpOperand &opOperand : linalgOp.getOutputOperands()) resultTensorTypes.push_back( - tiledOperands[opOperand->getOperandNumber()].getType()); + tiledOperands[opOperand.getOperandNumber()].getType()); Operation *newOp = linalgOp.clone(rewriter, linalgLoc, resultTensorTypes, tiledOperands); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -41,12 +41,12 @@ // New input operands for the cloned op. SmallVector newInputBuffers; newInputBuffers.reserve(op.getNumInputs()); - for (OpOperand *opOperand : op.getInputOperands()) { - if (op.isScalar(opOperand)) { - newInputBuffers.push_back(opOperand->get()); + for (OpOperand &opOperand : op.getInputOperands()) { + if (op.isScalar(&opOperand)) { + newInputBuffers.push_back(opOperand.get()); continue; } - FailureOr buffer = getBuffer(rewriter, opOperand->get(), options); + FailureOr buffer = getBuffer(rewriter, opOperand.get(), options); if (failed(buffer)) return failure(); newInputBuffers.push_back(*buffer); @@ -118,7 +118,7 @@ auto genericOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. - if (genericOp.isOutputTensor(&opOperand)) + if (genericOp.isOutput(&opOperand)) return {genericOp.getTiedOpResult(&opOperand)}; return {}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -68,17 +68,17 @@ if (!outputType || !outputType.hasStaticShape()) return failure(); - if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) { - return operand->get().getType().isa(); + if (!llvm::all_of(genericOp.getInputs(), [](Value input) { + return input.getType().isa(); })) return failure(); // Make sure all element types are the same. - auto getOperandElementType = [](OpOperand *operand) { - return operand->get().getType().cast().getElementType(); + auto getOperandElementType = [](Value value) { + return value.getType().cast().getElementType(); }; - if (!llvm::all_equal(llvm::map_range(genericOp.getInputAndOutputOperands(), - getOperandElementType))) + if (!llvm::all_equal( + llvm::map_range(genericOp->getOperands(), getOperandElementType))) return failure(); // We can only handle the case where we have int/float elements. @@ -95,8 +95,8 @@ [](AffineMap map) { return map.isPermutation(); })) return failure(); - for (OpOperand *operand : genericOp.getOutputOperands()) { - if (genericOp.payloadUsesValueFromOperand(operand)) + for (OpOperand &operand : genericOp.getOutputOperands()) { + if (genericOp.payloadUsesValueFromOperand(&operand)) return failure(); } @@ -115,15 +115,15 @@ int numInputs = genericOp.getNumInputs(); SmallVector inputValues(numInputs); for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) { - if (!matchPattern(operand.value()->get(), + if (!matchPattern(operand.value().get(), m_Constant(&inputValues[operand.index()]))) return failure(); } // Identified this as a potential candidate for folding. Now check the // policy to see whether we are allowed to proceed. - for (auto *operand : genericOp.getInputOperands()) { - if (!controlFn(operand)) + for (OpOperand &operand : genericOp.getInputOperands()) { + if (!controlFn(&operand)) return failure(); } @@ -171,8 +171,8 @@ APIntOrFloatArray computeFnInputs; auto inputShapes = llvm::to_vector<4>( - llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) { - return operand->get().getType().cast().getShape(); + llvm::map_range(genericOp.getInputs(), [](Value value) { + return value.getType().cast().getShape(); })); // Given a `linearIndex`, remap it to a linear index to access linalg op diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -194,7 +194,7 @@ } /// Create the peeled generic op with an empty body. - SmallVector outsOperands = genericOp.getOutputOperands(); + SmallVector outsOperands = genericOp.getOutputs(); outsOperands.append(newInitValues.begin(), newInitValues.end()); SmallVector resultTypes = llvm::to_vector(genericOp.getResultTypes()); resultTypes.append(newResultTypes.begin(), newResultTypes.end()); @@ -212,9 +212,7 @@ PatternRewriter &rewriter) const { /// Append all results from the peeledGenericOps as `ins` operand for the /// residual generic op. - SmallVector residualGenericOpOperands = llvm::to_vector( - llvm::map_range(genericOp.getInputOperands(), - [](OpOperand *operand) { return operand->get(); })); + SmallVector residualGenericOpOperands = genericOp.getInputs(); unsigned origNumResults = genericOp.getNumResults(); unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults(); SmallVector extraIns; @@ -226,8 +224,8 @@ /// Add indexing maps for the newly added operands. Use the same map /// as those used for the new results of the peeledGenericOp. auto indexingMaps = llvm::to_vector( - llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) { - return genericOp.getMatchingIndexingMap(operand); + llvm::map_range(genericOp.getInputOperands(), [&](OpOperand &operand) { + return genericOp.getMatchingIndexingMap(&operand); })); for (auto resultNum : llvm::seq(origNumResults, peeledGenericOpNumResults)) { @@ -235,8 +233,8 @@ indexingMaps.push_back( peeledGenericOp.getIndexingMapMatchingResult(result)); } - for (OpOperand *outOperand : genericOp.getOutputOperands()) - indexingMaps.push_back(genericOp.getMatchingIndexingMap(outOperand)); + for (OpOperand &outOperand : genericOp.getOutputOperands()) + indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand)); auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps); return rewriter.create( @@ -263,8 +261,8 @@ genericOp, "only operations with tensor semantics are handled"); } - if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) { - return !genericOp.getMatchingIndexingMap(outOperand).isPermutation(); + if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand &outOperand) { + return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation(); })) { return rewriter.notifyMatchFailure( genericOp, "unhandled decomposition of generic op with out operand not " diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -55,10 +55,9 @@ bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { GenericOp genericOp = dyn_cast_or_null(op); return genericOp && - llvm::all_of( - genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) { - return !typeConverter.isLegal(opOperand->get().getType()); - }); + llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) { + return !typeConverter.isLegal(opOperand.get().getType()); + }); } /// A conversion patttern for detensoring `linalg.generic` ops. diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -377,21 +377,21 @@ SmallVector reassociationMaps; SmallVector newInputOutputTypes; bool doCanonicalization = false; - for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { - auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context); + for (OpOperand &opOperand : genericOp->getOpOperands()) { + auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context); if (replacementInfo) { reassociationMaps.push_back(replacementInfo->reassociation); newIndexingMaps.push_back(replacementInfo->indexMap); newInputOutputTypes.push_back(replacementInfo->type); doCanonicalization |= - replacementInfo->type != opOperand->get().getType(); + replacementInfo->type != opOperand.get().getType(); } else { // If replaceUnitExtents cannot handle this case, maintain the same // type, indexing map, and create a set of mappings representing an // identity matrix. - newInputOutputTypes.push_back(opOperand->get().getType()); - newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(opOperand)); - int64_t origRank = genericOp.getRank(opOperand); + newInputOutputTypes.push_back(opOperand.get().getType()); + newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand)); + int64_t origRank = genericOp.getRank(&opOperand); auto maps = llvm::to_vector<8>(llvm::map_range( llvm::seq(0, origRank), [&](int64_t dim) -> Attribute { return AffineMapAttr::get( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -90,7 +90,7 @@ // Only allow fusing the producer of an input operand for now. // TODO: allow fusing the producer of an output operand. - if (!consumer.isInputTensor(fusedOperand)) + if (!consumer.isInput(fusedOperand)) return false; // Get the consumer index map. The number of results of the consumer index @@ -128,10 +128,10 @@ addToCoveredDims(operandMap); } - for (OpOperand *operand : producer.getInputOperands()) { + for (OpOperand &operand : producer.getInputOperands()) { AffineMap newIndexingMap = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( - operand, producerResultIndexMap, consumerIndexMap); + &operand, producerResultIndexMap, consumerIndexMap); addToCoveredDims(newIndexingMap); } if (!coveredDims.all()) @@ -179,7 +179,7 @@ } } // TODO: allow fusing the producer of an output operand. - assert(consumer.isInputTensor(fusedOperand) && + assert(consumer.isInput(fusedOperand) && "expected producer of input operand"); // 3. Consumer input operands up to consumerIdx (exclusive). for (BlockArgument bbArg : consumerBlock.getArguments().take_front( @@ -267,7 +267,7 @@ auto producer = cast(producerResult.getOwner()); auto consumer = cast(fusedOperand->getOwner()); // TODO: allow fusing the producer of an output operand. - assert(consumer.isInputTensor(fusedOperand) && + assert(consumer.isInput(fusedOperand) && "expected producer of input operand"); // Compute the fused operands list and indexing maps. @@ -278,52 +278,53 @@ fusedOutputOperands.reserve(producer.getNumOutputs() + consumer.getNumOutputs()); fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs()); - fusedIndexMaps.reserve(producer.getNumInputsAndOutputs() + - consumer.getNumInputsAndOutputs()); + fusedIndexMaps.reserve(producer->getNumOperands() + + consumer->getNumOperands()); // In the following, numbering matches that of `generateFusedTensorOpRegion`. // 3. Consumer input operands/maps up to consumerIdx (exclusive). - SmallVector consumerInputs = consumer.getInputOperands(); - SmallVector::iterator it = - llvm::find(consumerInputs, fusedOperand); + auto consumerInputs = consumer.getInputOperands(); + auto *it = llvm::find_if(consumerInputs, [&](OpOperand &operand) { + return &operand == fusedOperand; + }); assert(it != consumerInputs.end() && "expected to find the consumer operand"); - for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { - fusedInputOperands.push_back(opOperand->get()); - fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); + for (OpOperand &opOperand : llvm::make_range(consumerInputs.begin(), it)) { + fusedInputOperands.push_back(opOperand.get()); + fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand)); } // 4. Splice in producer's input operands/maps. AffineMap producerResultIndexMap = producer.getIndexingMapMatchingResult(producerResult); - for (OpOperand *opOperand : producer.getInputOperands()) { - fusedInputOperands.push_back(opOperand->get()); + for (OpOperand &opOperand : producer.getInputOperands()) { + fusedInputOperands.push_back(opOperand.get()); // Compute indexing maps for the producer args in the fused operation. AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( - opOperand, producerResultIndexMap, + &opOperand, producerResultIndexMap, consumer.getMatchingIndexingMap(fusedOperand)); fusedIndexMaps.push_back(map); } // 5. Remaining consumer's input operands/maps (drop past index // `consumerIdx`). - for (OpOperand *opOperand : + for (OpOperand &opOperand : llvm::make_range(std::next(it), consumerInputs.end())) { - fusedInputOperands.push_back(opOperand->get()); - fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); + fusedInputOperands.push_back(opOperand.get()); + fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand)); } // 6. Collect all of the producer outputs. - for (OpOperand *opOperand : producer.getOutputOperands()) { - fusedOutputOperands.push_back(opOperand->get()); + for (OpOperand &opOperand : producer.getOutputOperands()) { + fusedOutputOperands.push_back(opOperand.get()); AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( - opOperand, producerResultIndexMap, + &opOperand, producerResultIndexMap, consumer.getMatchingIndexingMap(fusedOperand)); fusedIndexMaps.push_back(map); - fusedResultTypes.push_back(opOperand->get().getType()); + fusedResultTypes.push_back(opOperand.get().getType()); } // 7. All of consumer's output operands (skip operands: added by the builder). - for (OpOperand *opOperand : consumer.getOutputOperands()) { - fusedOutputOperands.push_back(opOperand->get()); - fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); - fusedResultTypes.push_back(opOperand->get().getType()); + for (OpOperand &opOperand : consumer.getOutputOperands()) { + fusedOutputOperands.push_back(opOperand.get()); + fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand)); + fusedResultTypes.push_back(opOperand.get().getType()); } // Generate the fused op. @@ -373,13 +374,13 @@ LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { // Find the first operand that is defined by another generic op on tensors. - for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { - if (!areElementwiseOpsFusable(opOperand)) + for (OpOperand &opOperand : genericOp->getOpOperands()) { + if (!areElementwiseOpsFusable(&opOperand)) continue; - if (!controlFn(opOperand)) + if (!controlFn(&opOperand)) continue; - FailureOr fusedOp = fuseElementwiseOps(rewriter, opOperand); + FailureOr fusedOp = fuseElementwiseOps(rewriter, &opOperand); if (succeeded(fusedOp)) { auto replacements = fusedOp.value()->getResults().take_back(genericOp.getNumResults()); @@ -721,18 +722,18 @@ SmallVector expandedOpOperands; expandedOpOperands.reserve(genericOp.getNumInputs()); - for (OpOperand *opOperand : genericOp.getInputOperands()) { - if (opOperand == fusableOpOperand) { + for (OpOperand &opOperand : genericOp.getInputOperands()) { + if (&opOperand == fusableOpOperand) { expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc() : collapsingReshapeOp.getSrc()); continue; } - if (genericOp.isInputTensor(opOperand)) { - AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); - auto opOperandType = opOperand->get().getType().cast(); + if (auto opOperandType = + opOperand.get().getType().dyn_cast()) { + AffineMap indexingMap = genericOp.getMatchingIndexingMap(&opOperand); RankedTensorType expandedOperandType = getExpandedType(opOperandType, indexingMap, expansionInfo); - if (expandedOperandType != opOperand->get().getType()) { + if (expandedOperandType != opOperand.get().getType()) { // Reshape the operand to get the right type. SmallVector reassociation = getReassociationForExpansion(indexingMap, expansionInfo); @@ -745,22 +746,22 @@ /*isExpandingReshape=*/true))) return llvm::None; expandedOpOperands.push_back(rewriter.create( - genericOp.getLoc(), expandedOperandType, opOperand->get(), + genericOp.getLoc(), expandedOperandType, opOperand.get(), reassociation)); continue; } } - expandedOpOperands.push_back(opOperand->get()); + expandedOpOperands.push_back(opOperand.get()); } Location loc = genericOp.getLoc(); SmallVector outputs; - for (OpOperand *opOperand : genericOp.getOutputOperands()) { - AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); - auto opOperandType = opOperand->get().getType().cast(); + for (OpOperand &opOperand : genericOp.getOutputOperands()) { + AffineMap indexingMap = genericOp.getMatchingIndexingMap(&opOperand); + auto opOperandType = opOperand.get().getType().cast(); RankedTensorType expandedOutputType = getExpandedType(opOperandType, indexingMap, expansionInfo); - if (expandedOutputType != opOperand->get().getType()) { + if (expandedOutputType != opOperand.get().getType()) { SmallVector reassociation = getReassociationForExpansion(indexingMap, expansionInfo); if (failed(reshapeLikeShapesAreCompatible( @@ -772,10 +773,10 @@ /*isExpandingReshape=*/true))) return llvm::None; outputs.push_back(rewriter.create( - genericOp.getLoc(), expandedOutputType, opOperand->get(), + genericOp.getLoc(), expandedOutputType, opOperand.get(), reassociation)); } else { - outputs.push_back(opOperand->get()); + outputs.push_back(opOperand.get()); } } @@ -833,20 +834,21 @@ LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { + for (OpOperand &opOperand : genericOp.getInputOperands()) { tensor::CollapseShapeOp reshapeOp = - opOperand->get().getDefiningOp(); + opOperand.get().getDefiningOp(); if (!reshapeOp) continue; // Fold only if // - The tensor reshape op is folding. // - All constraints of fusing with reshape by expansion are met. - if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) || - (!controlFoldingReshapes(opOperand))) + if (!isFusableWithReshapeByDimExpansion(genericOp, &opOperand) || + (!controlFoldingReshapes(&opOperand))) continue; Optional> replacementValues = - fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter); + fuseWithReshapeByExpansion(genericOp, reshapeOp, &opOperand, + rewriter); if (!replacementValues) return failure(); rewriter.replaceOp(genericOp, *replacementValues); @@ -1416,8 +1418,8 @@ // Get the input operands. auto inputOperands = llvm::to_vector( - llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) { - return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo, + llvm::map_range(genericOp.getInputOperands(), [&](OpOperand &opOperand) { + return getCollapsedOpOperand(loc, genericOp, &opOperand, collapsingInfo, rewriter); })); @@ -1426,9 +1428,9 @@ SmallVector outputOperands; resultTypes.reserve(genericOp.getNumOutputs()); outputOperands.reserve(genericOp.getNumOutputs()); - for (OpOperand *output : genericOp.getOutputOperands()) { - Value newOutput = - getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter); + for (OpOperand &output : genericOp.getOutputOperands()) { + Value newOutput = getCollapsedOpOperand(loc, genericOp, &output, + collapsingInfo, rewriter); outputOperands.push_back(newOutput); resultTypes.push_back(newOutput.getType()); } @@ -1494,23 +1496,23 @@ LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { + for (OpOperand &opOperand : genericOp->getOpOperands()) { tensor::ExpandShapeOp reshapeOp = - opOperand->get().getDefiningOp(); + opOperand.get().getDefiningOp(); if (!reshapeOp) continue; SmallVector collapsableIterationDims = - getCollapsableIterationSpaceDims(genericOp, opOperand, + getCollapsableIterationSpaceDims(genericOp, &opOperand, reshapeOp.getReassociationIndices()); if (collapsableIterationDims.empty() || - !controlFoldingReshapes(opOperand)) { + !controlFoldingReshapes(&opOperand)) { continue; } Optional> replacements = collapseGenericOpIterationDims(genericOp, collapsableIterationDims, - opOperand, rewriter); + &opOperand, rewriter); if (!replacements) { return rewriter.notifyMatchFailure( genericOp, "failed to do the fusion by collapsing transformation"); @@ -1543,8 +1545,8 @@ PatternRewriter &rewriter) const override { if (!genericOp.hasTensorSemantics()) return failure(); - for (OpOperand *opOperand : genericOp.getInputOperands()) { - Operation *def = opOperand->get().getDefiningOp(); + for (OpOperand &opOperand : genericOp.getInputOperands()) { + Operation *def = opOperand.get().getDefiningOp(); TypedAttr constantAttr; auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool { { @@ -1573,7 +1575,7 @@ return false; }; - auto resultValue = opOperand->get().dyn_cast(); + auto resultValue = opOperand.get().dyn_cast(); if (!def || !resultValue || !isScalarOrSplatConstantOp(def)) continue; @@ -1583,21 +1585,21 @@ SmallVector fusedIndexMaps; SmallVector fusedOperands; SmallVector fusedLocs{genericOp.getLoc()}; - fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs()); + fusedIndexMaps.reserve(genericOp->getNumOperands()); fusedOperands.reserve(genericOp.getNumInputs()); fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs()); - for (OpOperand *inputOperand : genericOp.getInputOperands()) { - if (inputOperand == opOperand) + for (OpOperand &inputOperand : genericOp.getInputOperands()) { + if (&inputOperand == &opOperand) continue; - Value inputValue = inputOperand->get(); + Value inputValue = inputOperand.get(); fusedIndexMaps.push_back( - genericOp.getMatchingIndexingMap(inputOperand)); + genericOp.getMatchingIndexingMap(&inputOperand)); fusedOperands.push_back(inputValue); fusedLocs.push_back(inputValue.getLoc()); } - for (OpOperand *outputOperand : genericOp.getOutputOperands()) + for (OpOperand &outputOperand : genericOp.getOutputOperands()) fusedIndexMaps.push_back( - genericOp.getMatchingIndexingMap(outputOperand)); + genericOp.getMatchingIndexingMap(&outputOperand)); // Check if the operation shapes to loops map is computable. if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { @@ -1609,7 +1611,7 @@ Value scalarConstant = rewriter.create( def->getLoc(), constantAttr, constantAttr.getType()); - SmallVector outputOperands = genericOp.getOutputOperands(); + SmallVector outputOperands = genericOp.getOutputs(); auto fusedOp = rewriter.create( rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(), /*inputs=*/fusedOperands, @@ -1624,7 +1626,7 @@ Region ®ion = genericOp->getRegion(0); Block &entryBlock = *region.begin(); BlockAndValueMapping mapping; - mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()), + mapping.map(entryBlock.getArgument(opOperand.getOperandNumber()), scalarConstant); Region &fusedRegion = fusedOp->getRegion(0); rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(), @@ -1655,9 +1657,9 @@ rewriter.startRootUpdate(op); bool modifiedOutput = false; Location loc = op.getLoc(); - for (OpOperand *opOperand : op.getOutputOperands()) { - if (!op.payloadUsesValueFromOperand(opOperand)) { - Value operandVal = opOperand->get(); + for (OpOperand &opOperand : op.getOutputOperands()) { + if (!op.payloadUsesValueFromOperand(&opOperand)) { + Value operandVal = opOperand.get(); auto operandType = operandVal.getType().dyn_cast(); if (!operandType) continue; @@ -1681,7 +1683,7 @@ Value emptyTensor = rewriter.create( loc, operandType.getShape(), operandType.getElementType(), dynamicDims); - op->setOperand(opOperand->getOperandNumber(), emptyTensor); + op->setOperand(opOperand.getOperandNumber(), emptyTensor); } } if (!modifiedOutput) { @@ -1703,14 +1705,14 @@ return failure(); bool fillFound = false; Block &payload = genericOp.getRegion().front(); - for (OpOperand *opOperand : genericOp.getInputOperands()) { - if (!genericOp.payloadUsesValueFromOperand(opOperand)) + for (OpOperand &opOperand : genericOp.getInputOperands()) { + if (!genericOp.payloadUsesValueFromOperand(&opOperand)) continue; - FillOp fillOp = opOperand->get().getDefiningOp(); + FillOp fillOp = opOperand.get().getDefiningOp(); if (!fillOp) continue; fillFound = true; - payload.getArgument(opOperand->getOperandNumber()) + payload.getArgument(opOperand.getOperandNumber()) .replaceAllUsesWith(fillOp.value()); } return success(fillFound); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -68,7 +68,7 @@ bool fromSubViewOpOnly = false) { // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { + for (OpOperand &opOperand : op->getOpOperands()) { // The method `getRangeFromOperandShape` requires using SubViewOp or // ExtractSliceOps. If the value isn't defined from there continue. // todo: The method should be adapted to get the values from @@ -77,12 +77,12 @@ // `std` dialect and add the method to `ViewInterface`. if (fromSubViewOpOnly && !isa_and_nonnull( - opOperand->get().getDefiningOp())) + opOperand.get().getDefiningOp())) continue; - AffineMap map = op.getMatchingIndexingMap(opOperand); + AffineMap map = op.getMatchingIndexingMap(&opOperand); LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: " - << opOperand->getOperandNumber() << "\n"); + << opOperand.getOperandNumber() << "\n"); LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange map: " << map << "\n"); SmallVector shapeRanges(map.getNumResults(), nullptr); @@ -94,8 +94,8 @@ LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " << loopDepth << "\n"); LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: " - << opOperand->get() << "\n"); - return ShapeDimension{opOperand->get(), + << opOperand.get() << "\n"); + return ShapeDimension{opOperand.get(), static_cast(en.index())}; } } @@ -104,7 +104,7 @@ } static SmallVector getTiledOperands(LinalgOp producer) { - return producer.getInputAndOutputOperands(); + return producer->getOperands(); } /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges` @@ -137,7 +137,7 @@ } SmallVector clonedShapes; - clonedShapes.reserve(producer.getNumInputsAndOutputs()); + clonedShapes.reserve(producer->getNumOperands()); // Compute subranges for all tensor input/output operands. clonedShapes.append(makeTiledShapes( @@ -150,15 +150,18 @@ // fully dynamic at construction time. SmallVector resultTypes; resultTypes.reserve(producer->getNumResults()); - for (RankedTensorType t : producer.getOutputTensorTypes()) { - unsigned rank = t.getRank(); + for (Type t : TypeRange{producer.getOutputs()}) { + auto tensorType = t.dyn_cast(); + if (!tensorType) + continue; + unsigned rank = tensorType.getRank(); SmallVector staticOffsetsVector( rank, ShapedType::kDynamicStrideOrOffset); SmallVector staticSizesVector(rank, ShapedType::kDynamicSize); SmallVector staticStridesVector( rank, ShapedType::kDynamicStrideOrOffset); resultTypes.push_back(tensor::ExtractSliceOp::inferResultType( - t.cast(), staticOffsetsVector, staticSizesVector, + tensorType, staticOffsetsVector, staticSizesVector, staticStridesVector)); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -161,7 +161,7 @@ allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; } erase_value(tileIvs, OpFoldResult()); - SmallVector tiledOperands = producerOp.getInputAndOutputOperands(); + SmallVector tiledOperands = producerOp->getOperands(); tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs, tileSizes, producerLoopBounds, /**omitPartialTileCheck=*/false); @@ -439,14 +439,16 @@ int64_t split = std::distance(iterTypes.begin(), it); // Helper to fuse the producers greedily using a queue of fusion candidates. - auto fuseProducersGreedily = [&](ArrayRef operands) { - SmallVector candidates(operands.begin(), operands.end()); + auto fuseProducersGreedily = [&](MutableArrayRef operands) { + SmallVector candidates = llvm::to_vector( + llvm::map_range(operands, [](OpOperand &operand) { return &operand; })); while (!candidates.empty()) { FailureOr fusedProducer = tileLoopNest.fuseProducer(b, candidates.pop_back_val()); if (failed(fusedProducer)) continue; - candidates.append(fusedProducer->getInputAndOutputOperands()); + for (OpOperand &operand : fusedProducer->getOperation()->getOpOperands()) + candidates.push_back(&operand); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -50,18 +50,19 @@ if (failed(generalizeNamedOpPrecondition(linalgOp))) return rewriter.notifyMatchFailure(linalgOp, "preconditions not met"); - SmallVector inputOperands = linalgOp.getInputOperands(); - SmallVector outputOperands = linalgOp.getOutputOperands(); + SmallVector inputOperands = linalgOp.getInputs(); + SmallVector outputOperands = linalgOp.getOutputs(); SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); SmallVector iterators = linalgOp.getIteratorTypesArray(); - SmallVector resultTypes = linalgOp.getOutputTensorTypes(); - SmallVector types(resultTypes.begin(), resultTypes.end()); + SmallVector resultTypes = linalgOp.hasTensorSemantics() + ? TypeRange{linalgOp.getOutputs()} + : TypeRange{}; // All named ops have a region attached that can be inlined. assert(linalgOp->getNumRegions() == 1 && "expect named op to have one region attached"); GenericOp genericOp = - rewriter.create(linalgOp.getLoc(), types, inputOperands, + rewriter.create(linalgOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps, iterators); rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(), genericOp.getRegion().begin()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -111,7 +111,7 @@ static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padOp) { for (OpOperand &use : padOp.getResult().getUses()) { auto linalgUser = dyn_cast(use.getOwner()); - if (!linalgUser || !linalgUser.isInputTensor(&use)) { + if (!linalgUser || !linalgUser.isInput(&use)) { LLVM_DEBUG(DBGS() << "Found a use of " << *(padOp) << "\nthat is not an input tensor of a LinalgOp, " << "cannot hoist\n" diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp @@ -41,24 +41,25 @@ SmallVector scalarOperands; SmallVector newIndexingMaps; SmallVector newOperands; - for (OpOperand *opOperand : genericOp.getInputOperands()) { - AffineMap map = genericOp.getMatchingIndexingMap(opOperand); - if (genericOp.isInputTensor(opOperand) && map.isConstant()) { - scalarOperands.emplace_back(opOperand->getOperandNumber()); + for (OpOperand &opOperand : genericOp.getInputOperands()) { + AffineMap map = genericOp.getMatchingIndexingMap(&opOperand); + if (genericOp.isInput(&opOperand) && map.isConstant()) { + scalarOperands.emplace_back(opOperand.getOperandNumber()); } else { newIndexingMaps.emplace_back(map); - newOperands.emplace_back(opOperand->get()); + newOperands.emplace_back(opOperand.get()); } } if (scalarOperands.empty()) return failure(); - for (OpOperand *opOperand : genericOp.getOutputOperands()) - newIndexingMaps.emplace_back(genericOp.getMatchingIndexingMap(opOperand)); + for (OpOperand &opOperand : genericOp.getOutputOperands()) + newIndexingMaps.emplace_back( + genericOp.getMatchingIndexingMap(&opOperand)); Location loc = genericOp->getLoc(); - SmallVector outputOperands = genericOp.getOutputOperands(); + SmallVector outputOperands = genericOp.getOutputs(); auto newOp = rewriter.create( loc, genericOp->getResultTypes(), newOperands, outputOperands, newIndexingMaps, genericOp.getIteratorTypesArray()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -67,8 +67,8 @@ // 2. Compute the interchanged indexing maps. SmallVector newIndexingMaps; - for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { - AffineMap m = genericOp.getMatchingIndexingMap(opOperand); + for (OpOperand &opOperand : genericOp->getOpOperands()) { + AffineMap m = genericOp.getMatchingIndexingMap(&opOperand); if (!permutationMap.isEmpty()) m = m.compose(permutationMap); newIndexingMaps.push_back(m); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -131,29 +131,30 @@ assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); SmallVector indexedValues; - indexedValues.reserve(linalgOp.getNumInputsAndOutputs()); + indexedValues.reserve(linalgOp->getNumOperands()); auto allIvsPlusDims = SmallVector(allIvs.begin(), allIvs.end()); // TODO: Avoid the loads if the corresponding argument of the // region has no uses. // 1.a. Emit load from input operand or for scalars access the operand itself. - for (OpOperand *inputOperand : linalgOp.getInputOperands()) { - if (linalgOp.isScalar(inputOperand)) { - indexedValues.push_back(inputOperand->get()); + for (OpOperand &inputOperand : linalgOp.getInputOperands()) { + if (linalgOp.isScalar(&inputOperand)) { + indexedValues.push_back(inputOperand.get()); continue; } auto indexing = makeCanonicalAffineApplies( - b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims); + b, loc, linalgOp.getMatchingIndexingMap(&inputOperand), allIvsPlusDims); indexedValues.push_back( - b.create(loc, inputOperand->get(), indexing)); + b.create(loc, inputOperand.get(), indexing)); } // 1.b. Emit load from output views. - for (OpOperand *outputOperand : linalgOp.getOutputOperands()) { + for (OpOperand &outputOperand : linalgOp.getOutputOperands()) { SmallVector indexing = makeCanonicalAffineApplies( - b, loc, linalgOp.getMatchingIndexingMap(outputOperand), allIvsPlusDims); + b, loc, linalgOp.getMatchingIndexingMap(&outputOperand), + allIvsPlusDims); indexedValues.push_back( - b.create(loc, outputOperand->get(), indexing)); + b.create(loc, outputOperand.get(), indexing)); } // TODO: When a region inliner exists, use it. @@ -161,11 +162,13 @@ // 3. Emit store. SmallVector, 8> indexing; SmallVector outputBuffers; - for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) { + for (OpOperand &outputOperand : linalgOp.getOutputOperands()) { + if (!outputOperand.get().getType().isa()) + continue; indexing.push_back(makeCanonicalAffineApplies( - b, loc, linalgOp.getMatchingIndexingMap(outputOperand), + b, loc, linalgOp.getMatchingIndexingMap(&outputOperand), allIvsPlusDims)); - outputBuffers.push_back(outputOperand->get()); + outputBuffers.push_back(outputOperand.get()); } inlineRegionAndEmitStore(b, loc, linalgOp, indexedValues, indexing, outputBuffers); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -145,15 +145,15 @@ assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand"); auto vUseFullTileBuffers = options.useFullTileBuffers.value_or(llvm::SmallBitVector()); - vUseFullTileBuffers.resize(linalgOp.getNumInputsAndOutputs(), + vUseFullTileBuffers.resize(linalgOp->getNumOperands(), options.useFullTileBuffersDefault); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - int64_t operandNumber = opOperand->getOperandNumber(); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + int64_t operandNumber = opOperand.getOperandNumber(); if (options.operandsToPromote && !options.operandsToPromote->count(operandNumber)) continue; - Operation *op = opOperand->get().getDefiningOp(); + Operation *op = opOperand.get().getDefiningOp(); if (auto sv = dyn_cast_or_null(op)) { subViews[operandNumber] = sv; useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber]; @@ -326,13 +326,13 @@ // operands are not views. This is to support cases such as FillOp taking // extra scalars etc. Keep a reference to output buffers; SmallVector opViews; - opViews.reserve(op.getNumInputsAndOutputs()); + opViews.reserve(op->getNumOperands()); SmallVector, 8> writebackViews; writebackViews.reserve(promotedBuffersAndViews->size()); - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { - int64_t operandNumber = opOperand->getOperandNumber(); + for (OpOperand &opOperand : op->getOpOperands()) { + int64_t operandNumber = opOperand.getOperandNumber(); if (options.subViews.count(operandNumber) != 0) { - if (options.useFullTileBuffers[opOperand->get()]) + if (options.useFullTileBuffers[opOperand.get()]) opViews.push_back( (*promotedBuffersAndViews)[operandNumber].fullLocalView); else @@ -340,10 +340,10 @@ (*promotedBuffersAndViews)[operandNumber].partialLocalView); if (operandNumber >= op.getNumInputs()) writebackViews.emplace_back(std::make_pair( - opOperand->get(), + opOperand.get(), (*promotedBuffersAndViews)[operandNumber].partialLocalView)); } else { - opViews.push_back(opOperand->get()); + opViews.push_back(opOperand.get()); } } op->setOperands(0, opViews.size(), opViews); @@ -371,12 +371,12 @@ if (!linalgOp || !linalgOp.hasBufferSemantics()) return failure(); // Check that at least one of the requested operands is indeed a subview. - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand &opOperand : linalgOp->getOpOperands()) { auto sv = - isa_and_nonnull(opOperand->get().getDefiningOp()); + isa_and_nonnull(opOperand.get().getDefiningOp()); if (sv) { if (!options.operandsToPromote || - options.operandsToPromote->count(opOperand->getOperandNumber())) + options.operandsToPromote->count(opOperand.getOperandNumber())) return success(); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -116,8 +116,8 @@ SmallVector newInputs; SmallVector newMaps; // Calculate the new shapes and indexing maps of the input operands. - for (OpOperand *operand : op.getInputOperands()) { - AffineMap map = op.getMatchingIndexingMap(operand); + for (OpOperand &operand : op.getInputOperands()) { + AffineMap map = op.getMatchingIndexingMap(&operand); SmallVector newShape; SmallVector exprs; SmallVector reassociation; @@ -126,11 +126,11 @@ unsigned dim = map.getDimPosition(idx); if (reductionDim == dim) { if (control.innerParallel) { - newShape.push_back(op.getShape(operand)[idx] / ratio); + newShape.push_back(op.getShape(&operand)[idx] / ratio); newShape.push_back(ratio); } else { newShape.push_back(ratio); - newShape.push_back(op.getShape(operand)[idx] / ratio); + newShape.push_back(op.getShape(&operand)[idx] / ratio); } reassociation.push_back({index++, index++}); if (control.innerParallel) { @@ -143,7 +143,7 @@ } continue; } - newShape.push_back(op.getShape(operand)[idx]); + newShape.push_back(op.getShape(&operand)[idx]); if (control.innerParallel) { exprs.push_back( b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1)); @@ -156,15 +156,15 @@ newMaps.push_back( AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext())); // If the shape is unchanged the input doesn't change. - if (newShape == op.getShape(operand)) { - newInputs.push_back(operand->get()); + if (newShape == op.getShape(&operand)) { + newInputs.push_back(operand.get()); continue; } Type newType = RankedTensorType::get( newShape, - operand->get().getType().cast().getElementType()); + operand.get().getType().cast().getElementType()); Value newInput = b.create( - loc, newType, operand->get(), reassociation); + loc, newType, operand.get(), reassociation); newInputs.push_back(newInput); } @@ -234,7 +234,7 @@ // from the previous op. unsigned intermRank = newOutputShape.size(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); - SmallVector outputOperands = op.getOutputOperands(); + SmallVector outputOperands = op.getOutputs(); SmallVector reductionIteratorTypes; SmallVector exprs; for (unsigned i : llvm::seq(0, intermRank)) { @@ -386,9 +386,9 @@ // Step 2. Reindex / expand indexing maps. // Reindex existing input indexings: k -> k * splitFactor + k'. SmallVector newMaps; - newMaps.reserve(op.getNumInputsAndOutputs() + 1); - for (OpOperand *o : op.getInputOperands()) - newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor)); + newMaps.reserve(op->getNumOperands() + 1); + for (OpOperand &o : op.getInputOperands()) + newMaps.push_back(scaleReductionDim(op, o, reductionDimPos, splitFactor)); // Provision a new indexing for the shape-only tensor. auto nDims = op.getNumLoops() + 1; auto redDim = getAffineDimExpr(reductionDimPos, context); @@ -398,8 +398,8 @@ // TODO: a subset of these may not reduce along reducePos and should be // reindexed: k -> k * splitFactor + k', when multi-reduction support is // available. - for (OpOperand *o : op.getOutputOperands()) - newMaps.push_back(insertParallelDim(op, *o, reductionDimPos, + for (OpOperand &o : op.getOutputOperands()) + newMaps.push_back(insertParallelDim(op, o, reductionDimPos, reductionDimSize / splitFactor)); // Step 3. Handle operands. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -320,11 +320,11 @@ Operation *clonedOp = b.clone(*op.getOperation()); auto destinationStyleOp = dyn_cast(clonedOp); if (destinationStyleOp) { - for (OpOperand *outOperand : destinationStyleOp.getOutputOperands()) { - auto it = llvm::find(dest, outOperand->get()); + for (OpOperand &outOperand : destinationStyleOp.getOutputOperands()) { + auto *it = llvm::find(dest, outOperand.get()); assert(it != dest.end() && "dest operand not found in dest"); unsigned destNum = std::distance(dest.begin(), it); - outOperand->set(destBbArgs[destNum]); + outOperand.set(destBbArgs[destNum]); } } @@ -503,7 +503,7 @@ // Tile the `operandValuesToUse` that either match the `op` operands // themselves or the tile loop arguments forwarding them. assert(operandValuesToUse.size() == - static_cast(op.getNumInputsAndOutputs()) && + static_cast(op->getNumOperands()) && "expect the number of operands and inputs and outputs to match"); SmallVector valuesToTile = operandValuesToUse; SmallVector sizeBounds = diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -86,7 +86,7 @@ LinalgOpTy> { /// Return the destination operands. SmallVector getDestinationOperands(Operation *op, OpBuilder &b) const { - return cast(op).getOutputOperands(); + return cast(op).getOutputs(); } /// Return the loop iterator type. @@ -127,14 +127,12 @@ // specified could lead to out of bounds accesses. Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); - SmallVector valuesToTile = linalgOp.getInputAndOutputOperands(); + SmallVector valuesToTile = linalgOp->getOperands(); SmallVector tiledOperands = makeTiledShapes( b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true); - SmallVector resultTensorTypes = llvm::to_vector(llvm::map_range( - linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) { - return tiledOperands[opOperand->getOperandNumber()].getType(); - })); + SmallVector resultTensorTypes = + getTensorOutputTypes(linalgOp, tiledOperands); Operation *tiledOp = linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); @@ -224,23 +222,23 @@ return op->emitOpError("expected operation to have buffer semantics"); SmallVector indexedValues; - indexedValues.reserve(linalgOp.getNumInputsAndOutputs()); + indexedValues.reserve(linalgOp->getNumOperands()); Location linalgOpLoc = op->getLoc(); /// Load the data corresponding to the block arguments that /// represent input operands. - for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) { - if (!linalgOp.payloadUsesValueFromOperand(operand)) { + for (OpOperand &operand : linalgOp->getOpOperands()) { + if (!linalgOp.payloadUsesValueFromOperand(&operand)) { indexedValues.push_back(nullptr); continue; } - if (linalgOp.isScalar(operand)) { - indexedValues.push_back(operand->get()); + if (linalgOp.isScalar(&operand)) { + indexedValues.push_back(operand.get()); continue; } SmallVector indices = getIndicesForAccess( - builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(operand), ivs); + builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs); Value load = - builder.create(linalgOpLoc, operand->get(), indices); + builder.create(linalgOpLoc, operand.get(), indices); indexedValues.push_back(load); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -272,10 +272,10 @@ b.setInsertionPointAfter(opToPad); // Make a copy of the shaped operands and update it. SmallVector newOperands; - newOperands.reserve(opToPad.getNumInputsAndOutputs()); - for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) { + newOperands.reserve(opToPad->getNumOperands()); + for (OpOperand &opOperand : opToPad->getOpOperands()) { FailureOr paddedOperand = padOperandToSmallestStaticBoundingBox( - b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings); + b, opToPad, &opOperand, paddingDimensions, paddingValues, packPaddings); // Exit if `paddingDimensions` cannot be bounded statically. if (failed(paddedOperand)) return failure(); @@ -425,15 +425,15 @@ // Hoist the padding. for (const auto &en : enumerate(options.hoistPaddings)) { - if (static_cast(en.index()) >= paddedOp.getNumInputsAndOutputs()) + if (static_cast(en.index()) >= paddedOp->getNumOperands()) break; - OpOperand *opOperand = &paddedOp->getOpOperand(en.index()); - auto padOp = opOperand->get().getDefiningOp(); + OpOperand &opOperand = paddedOp->getOpOperand(en.index()); + auto padOp = opOperand.get().getDefiningOp(); if (!padOp || en.value() == 0) continue; // Fail hoisting if the operand shape is not fully static. - if (llvm::any_of(paddedOp.getShape(opOperand), ShapedType::isDynamic)) + if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) return failure(); tensor::PadOp hoistedOp; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -459,35 +459,35 @@ // 3. Turn all BBArgs into vector.transfer_read / load. Location loc = linalgOp.getLoc(); Value zero = b.create(loc, 0); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber()); - if (linalgOp.isScalar(opOperand)) { - bvm.map(bbarg, opOperand->get()); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + BlockArgument bbarg = block->getArgument(opOperand.getOperandNumber()); + if (linalgOp.isScalar(&opOperand)) { + bvm.map(bbarg, opOperand.get()); continue; } VectorType readType; AffineMap map; // TODO: can we keep this simplification? - // if (linalgOp.getShape(opOperand).empty()) { + // if (linalgOp.getShape(&opOperand).empty()) { // readType = VectorType::get({}, bbarg.getType()); // } else { - if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) { + if (opOperand.getOperandNumber() < linalgOp.getNumInputs()) { map = inverseAndBroadcastProjectedPermutation( - linalgOp.getMatchingIndexingMap(opOperand)); + linalgOp.getMatchingIndexingMap(&opOperand)); readType = VectorType::get(commonVectorShape, - getElementTypeOrSelf(opOperand->get())); + getElementTypeOrSelf(opOperand.get())); } else { map = inversePermutation( - reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand))); - readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), - getElementTypeOrSelf(opOperand->get())); + reindexIndexingMap(linalgOp.getMatchingIndexingMap(&opOperand))); + readType = VectorType::get(map.compose(linalgOp.getShape(&opOperand)), + getElementTypeOrSelf(opOperand.get())); } // } - auto shape = linalgOp.getShape(opOperand); + auto shape = linalgOp.getShape(&opOperand); SmallVector indices(shape.size(), zero); Value readValue = b.create( - loc, readType, opOperand->get(), indices, map); + loc, readType, opOperand.get(), indices, map); // Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. if (readValue.getType().cast().getRank() == 0) @@ -495,7 +495,7 @@ LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue); bvm.map(bbarg, readValue); - bvm.map(opOperand->get(), readValue); + bvm.map(opOperand.get(), readValue); } SmallVector hooks; @@ -538,12 +538,12 @@ LDBG("reduction precondition failed: no reduction iterator"); return failure(); } - for (OpOperand *opOperand : op.getOutputOperands()) { - AffineMap indexingMap = op.getMatchingIndexingMap(opOperand); + for (OpOperand &opOperand : op.getOutputOperands()) { + AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand); if (indexingMap.isPermutation()) continue; - Operation *reduceOp = matchLinalgReduction(opOperand); + Operation *reduceOp = matchLinalgReduction(&opOperand); if (!reduceOp || !getCombinerOpKind(reduceOp)) { LDBG("reduction precondition failed: reduction detection failed"); return failure(); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -179,8 +179,8 @@ return false; // TODO: relax the restrictions on indexing map. - for (OpOperand *opOperand : op.getOutputOperands()) { - if (!op.getMatchingIndexingMap(opOperand).isPermutation()) + for (OpOperand &opOperand : op.getOutputOperands()) { + if (!op.getMatchingIndexingMap(&opOperand).isPermutation()) return false; } return hasOnlyScalarElementwiseOp(op->getRegion(0)); @@ -490,19 +490,20 @@ assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) && "expected as many entries for proc info as number of loops, even if " "they are null entries"); - SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); + SmallVector iterArgInitValues = linalgOp.hasBufferSemantics() + ? SmallVector{} + : linalgOp.getOutputs(); SmallVector lbs, ubs, steps; unpackRanges(b, loc, loopRanges, lbs, ubs, steps); LoopNest loopNest = mlir::scf::buildLoopNest( b, loc, lbs, ubs, steps, iterArgInitValues, [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) { - assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() && + assert(iterArgs.size() == iterArgInitValues.size() && "expect the number of output tensors and iter args to match"); - SmallVector operandValuesToUse = - linalgOp.getInputAndOutputOperands(); + SmallVector operandValuesToUse = linalgOp->getOperands(); if (!iterArgs.empty()) { - operandValuesToUse = linalgOp.getInputOperands(); + operandValuesToUse = linalgOp.getInputs(); operandValuesToUse.append(iterArgs.begin(), iterArgs.end()); } return bodyBuilderFn(b, loc, ivs, operandValuesToUse); @@ -530,7 +531,9 @@ ValueRange)> bodyBuilderFn, ArrayRef /*procInfo*/) { - SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); + SmallVector iterArgInitValues = linalgOp.hasBufferSemantics() + ? SmallVector{} + : linalgOp.getOutputs(); assert(iterArgInitValues.empty() && "unexpected AffineForOp init values"); SmallVector lbs, ubs, steps; unpackRanges(b, loc, loopRanges, lbs, ubs, steps); @@ -546,9 +549,8 @@ mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps, [&](OpBuilder &b, Location loc, ValueRange ivs) { - SmallVector operandValuesToUse = - linalgOp.getInputAndOutputOperands(); - bodyBuilderFn(b, loc, ivs, operandValuesToUse); + bodyBuilderFn(b, loc, ivs, + linalgOp->getOperands()); }); } @@ -695,7 +697,9 @@ ValueRange)> bodyBuilderFn, ArrayRef procInfo) { - SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); + SmallVector iterArgInitValues = linalgOp.hasBufferSemantics() + ? SmallVector{} + : linalgOp.getOutputs(); assert(iterArgInitValues.empty() && "unexpected ParallelOp init values"); // This function may be passed more iterator types than ranges. assert(iteratorTypes.size() >= loopRanges.size() && @@ -725,9 +729,7 @@ generateParallelLoopNest( b, loc, lbs, ubs, steps, iteratorTypes, procInfo, [&](OpBuilder &b, Location loc, ValueRange ivs) { - SmallVector operandValuesToUse = - linalgOp.getInputAndOutputOperands(); - bodyBuilderFn(b, loc, ivs, operandValuesToUse); + bodyBuilderFn(b, loc, ivs, linalgOp->getOperands()); }, ivs); @@ -905,25 +907,27 @@ } SmallVector getTensorOutputTypes(LinalgOp op, ValueRange operands) { - // TODO: use an interface/adaptor to avoid leaking position in - // `tiledOperands`. + if (op.hasBufferSemantics()) + return {}; return llvm::to_vector( - llvm::map_range(op.getOutputTensorOperands(), [&](OpOperand *opOperand) { - return operands[opOperand->getOperandNumber()].getType(); + llvm::map_range(op.getOutputOperands(), [&](OpOperand &opOperand) { + return operands[opOperand.getOperandNumber()].getType(); })); } SmallVector insertSlicesBack(OpBuilder &builder, Location loc, LinalgOp op, ValueRange operands, ValueRange results) { + if (op.hasBufferSemantics()) + return {}; SmallVector tensorResults; tensorResults.reserve(results.size()); // Insert a insert_slice for each output tensor. unsigned resultIdx = 0; - for (OpOperand *opOperand : op.getOutputTensorOperands()) { + for (OpOperand &opOperand : op.getOutputOperands()) { // TODO: use an interface/adaptor to avoid leaking position in // `tiledOperands`. - Value outputTensor = operands[opOperand->getOperandNumber()]; + Value outputTensor = operands[opOperand.getOperandNumber()]; if (auto sliceOp = outputTensor.getDefiningOp()) { Value inserted = builder.create( loc, sliceOp.getSource().getType(), results[resultIdx], @@ -958,23 +962,26 @@ computeTileSizes(builder, loc, tileSizes, sizeBounds); assert(static_cast(valuesToTile.size()) == - linalgOp.getNumInputsAndOutputs() && + linalgOp->getNumOperands() && "expected one value to tile for every operand"); SmallVector> allSliceParams; allSliceParams.reserve(valuesToTile.size()); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - Value shapedOp = valuesToTile[opOperand->getOperandNumber()]; + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + Value shapedOp = valuesToTile[opOperand.getOperandNumber()]; LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp); - AffineMap map = linalgOp.getMatchingIndexingMap(opOperand); + AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand); // Use `opOperand` as is if it is not tiled and not an output tensor. Having // an extract/insert slice pair for all output tensors simplifies follow up // transformations such as padding and bufferization since the // extract/insert slice pairs make the accessed iteration argument // subdomains explicit. - if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) { + + Type operandType = opOperand.get().getType(); + if (!isTiled(map, tileSizes) && !(operandType.isa() && + linalgOp.isOutput(&opOperand))) { allSliceParams.push_back(llvm::None); - LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: " - << opOperand->get().getType() << "\n"); + LLVM_DEBUG(llvm::dbgs() + << ": not tiled: use shape: " << operandType << "\n"); continue; } LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n"); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -104,8 +104,7 @@ auto yieldOp = cast(op.getRegion().front().getTerminator()); if (auto arg = yieldOp.getOperand(0).dyn_cast()) { if (arg.getOwner()->getParentOp() == op) { - OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()]; - return isZeroValue(t->get()); + return isZeroValue(op->getOperand(arg.getArgNumber())); } } return isZeroValue(yieldOp.getOperand(0)); @@ -223,8 +222,8 @@ return failure(); // Modify operand structure of producer and consumer. Location loc = prod.getLoc(); - SmallVector inputOps = prod.getInputOperands(); - SmallVector outputOps = op.getOutputOperands(); + SmallVector inputOps = prod.getInputs(); + SmallVector outputOps = op.getOutputs(); SmallVector fusedIndexMaps = prod.getIndexingMapsArray(); inputOps.push_back(op.getInputOperand(1 - other)->get()); fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -194,14 +194,14 @@ /// no annotations are found or inadmissible constructs occur. static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { bool annotated = false; - for (OpOperand *t : op.getInputAndOutputOperands()) { - auto map = op.getMatchingIndexingMap(t); - auto enc = getSparseTensorEncoding(t->get().getType()); + for (OpOperand &t : op->getOpOperands()) { + auto map = op.getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); if (enc) annotated = true; - assert(map.getNumResults() == op.getRank(t)); + assert(map.getNumResults() == op.getRank(&t)); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { - unsigned tensor = t->getOperandNumber(); + unsigned tensor = t.getOperandNumber(); AffineExpr a = map.getResult(toOrigDim(enc, d)); if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d))) return false; // inadmissible affine expression @@ -291,13 +291,13 @@ std::vector inDegree(n, 0); // in-degree of each node. auto iteratorTypes = op.getIteratorTypesArray(); // Iterate over the indexing maps of every tensor in the tensor expression. - for (OpOperand *t : op.getInputAndOutputOperands()) { + for (OpOperand &t : op->getOpOperands()) { // Skip tensor during cycle resolution. - if (t == skip) + if (&t == skip) continue; // Get map and encoding. - auto map = op.getMatchingIndexingMap(t); - auto enc = getSparseTensorEncoding(t->get().getType()); + auto map = op.getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); assert(map.getNumDims() == n); // Skip dense tensor constraints when not requested. if (!(mask & SortMask::kIncludeDense) && !enc) @@ -314,7 +314,7 @@ // Push unrelated loops into sparse iteration space, so these // will be skipped more often. if (mask & SortMask::kIncludeUndef) { - unsigned tensor = t->getOperandNumber(); + unsigned tensor = t.getOperandNumber(); for (unsigned i = 0; i < n; i++) if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) || merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) { @@ -534,16 +534,16 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op) { Location loc = op.getLoc(); - assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); + assert(op->getNumOperands() == op.getNumInputs() + 1); // For every tensor, find lower and upper bound on dimensions, set the // same bounds on loop indices, and obtain dense or sparse buffer(s). auto dynShape = {ShapedType::kDynamicSize}; SmallVector args; - for (OpOperand *t : op.getInputAndOutputOperands()) { - unsigned tensor = t->getOperandNumber(); - auto shape = op.getShape(t); - auto map = op.getMatchingIndexingMap(t); - auto enc = getSparseTensorEncoding(t->get().getType()); + for (OpOperand &t : op->getOpOperands()) { + unsigned tensor = t.getOperandNumber(); + auto shape = op.getShape(&t); + auto map = op.getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); // Scan all dimensions of current tensor. args.clear(); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { @@ -560,23 +560,23 @@ MemRefType::get(dynShape, getIndexOverheadType(builder, enc)); auto dim = builder.getIndexAttr(d); codegen.pointers[tensor][idx] = - builder.create(loc, ptrTp, t->get(), dim); + builder.create(loc, ptrTp, t.get(), dim); codegen.indices[tensor][idx] = - builder.create(loc, indTp, t->get(), dim); + builder.create(loc, indTp, t.get(), dim); } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) { // Singleton dimension, fetch indices. auto indTp = MemRefType::get(dynShape, getIndexOverheadType(builder, enc)); auto dim = builder.getIndexAttr(d); codegen.indices[tensor][idx] = - builder.create(loc, indTp, t->get(), dim); + builder.create(loc, indTp, t.get(), dim); } else { // Dense dimension, nothing to fetch. assert(merger.isDimLevelType(tensor, idx, DimLvlType::kDense)); } // Find upper bound in current dimension. unsigned p = toOrigDim(enc, d); - Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p); + Value up = linalg::createOrFoldDimOp(builder, loc, t.get(), p); if (ShapedType::isDynamic(shape[p])) args.push_back(up); assert(codegen.highs[tensor][idx] == nullptr); @@ -585,21 +585,21 @@ // Perform the required bufferization. Dense inputs materialize // from the input tensors. Dense outputs need special handling. // Sparse inputs use sparse primitives to obtain the values. - Type elementType = getElementTypeOrSelf(t->get().getType()); + Type elementType = getElementTypeOrSelf(t.get().getType()); if (!enc) { // Non-annotated dense tensors. auto denseTp = MemRefType::get(shape, elementType); if (tensor < op.getNumInputs()) codegen.buffers[tensor] = - builder.create(loc, denseTp, t->get()); + builder.create(loc, denseTp, t.get()); else codegen.buffers[tensor] = genOutputBuffer(codegen, builder, op, denseTp, args); - } else if (t != codegen.sparseOut) { + } else if (&t != codegen.sparseOut) { // Annotated sparse tensors (not involved in output). auto sparseTp = MemRefType::get(dynShape, elementType); codegen.buffers[tensor] = - builder.create(loc, sparseTp, t->get()); + builder.create(loc, sparseTp, t.get()); } } } @@ -845,15 +845,15 @@ return val; } // Load during insertion. - OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; - if (t == codegen.sparseOut) { + OpOperand &t = op->getOpOperand(merger.exp(exp).tensor); + if (&t == codegen.sparseOut) { if (codegen.redCustom != -1u) - return genInsertionLoadReduce(merger, codegen, builder, op, t); - return genInsertionLoad(codegen, builder, op, t); + return genInsertionLoadReduce(merger, codegen, builder, op, &t); + return genInsertionLoad(codegen, builder, op, &t); } // Actual load. SmallVector args; - Value ptr = genSubscript(codegen, builder, op, t, args); + Value ptr = genSubscript(codegen, builder, op, &t, args); if (codegen.curVecLength > 1) return genVectorLoad(codegen, builder, ptr, args); return builder.create(op.getLoc(), ptr, args); @@ -1093,9 +1093,9 @@ if (merger.exp(exp).kind == Kind::kTensor) { // Inspect tensor indices. bool atLevel = ldx == -1u; - OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; - auto map = op.getMatchingIndexingMap(t); - auto enc = getSparseTensorEncoding(t->get().getType()); + OpOperand &t = op->getOpOperand(merger.exp(exp).tensor); + auto map = op.getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { AffineExpr a = map.getResult(toOrigDim(enc, d)); if (!isInvariantAffine(codegen, a, ldx, atLevel)) @@ -1105,7 +1105,7 @@ if (!atLevel) return; OpOperand *lhs = op.getOutputOperand(0); - if (lhs == t) { + if (lhs == &t) { // Start or end a scalarized reduction if (atStart) { Kind kind = merger.exp(last).kind; @@ -1288,9 +1288,9 @@ /// This prevents effective vectorization. static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, unsigned idx) { - for (OpOperand *t : op.getInputAndOutputOperands()) { - if (!getSparseTensorEncoding(t->get().getType())) { - auto map = op.getMatchingIndexingMap(t); + for (OpOperand &t : op->getOpOperands()) { + if (!getSparseTensorEncoding(t.get().getType())) { + auto map = op.getMatchingIndexingMap(&t); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { AffineExpr a = map.getResult(d); // Report non-unit stride if innermost index appears at an outer @@ -1856,7 +1856,7 @@ // information for all tensors to loop indices in the kernel. if (op.getNumOutputs() != 1) return failure(); - unsigned numTensors = op.getNumInputsAndOutputs(); + unsigned numTensors = op->getNumOperands(); unsigned numLoops = op.iterator_types().getValue().size(); Merger merger(numTensors, numLoops); if (!findSparseAnnotations(merger, op)) @@ -1919,12 +1919,12 @@ // sparse input tensor in succession until an acylic // iteration graph results. std::vector topSort; - for (OpOperand *t : op.getInputOperands()) { - unsigned tensor = t->getOperandNumber(); - Value tval = t->get(); + for (OpOperand &t : op.getInputOperands()) { + unsigned tensor = t.getOperandNumber(); + Value tval = t.get(); auto srcEnc = getSparseTensorEncoding(tval.getType()); - if (!srcEnc || - !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly, t)) + if (!srcEnc || !computeIterationGraph(merger, op, topSort, + SortMask::kSparseOnly, &t)) continue; // Found an input tensor that resolves the cycle by inserting a // conversion into a sparse tensor that adheres to the iteration @@ -1936,7 +1936,7 @@ auto srcTp = tval.getType().cast(); auto dstEnc = SparseTensorEncodingAttr::get( op->getContext(), srcEnc.getDimLevelType(), - permute(getContext(), op.getMatchingIndexingMap(t), + permute(getContext(), op.getMatchingIndexingMap(&t), topSort), // new order srcEnc.getPointerBitWidth(), srcEnc.getIndexBitWidth()); auto dstTp = RankedTensorType::get(srcTp.getShape(), diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -895,10 +895,10 @@ // argument is considered a tensor, indexed by the implicit loop // bounds. This includes rank-0 tensor arguments. if (arg.getOwner()->getParentOp() == op) { - OpOperand *t = op.getInputAndOutputOperands()[argN]; - if (!op.isScalar(t)) + OpOperand &t = op->getOpOperand(argN); + if (!op.isScalar(&t)) return addExp(kTensor, argN); - v = t->get(); // get scalar value + v = t.get(); // get scalar value } // Any other argument (marked as scalar argument for the generic op // or belonging to an enveloping op) is considered invariant. diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -275,7 +275,7 @@ // ----- // CHECK-LABEL: func @remove_deadargs_generic_basic -// CHECK-SAME: (%[[ARG0:.*]]: tensor) -> tensor { +// CHECK-SAME: (%[[ARG0:.*]]: tensor) -> tensor { // CHECK: %[[GENERIC_OP:.*]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]] : tensor) // CHECK-SAME: outs({{.*}} : tensor) { diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -121,26 +121,6 @@ // CHECK-SAME: outs({{.*}} : memref>) // CHECK-SAME: {foo = 1 : i64} -func.func @generic_with_tensor_input(%arg0: tensor>, - %arg1: memref>) { - %cst = arith.constant 0.0 : f32 - linalg.generic #trait_0 - ins(%arg0, %cst : tensor>, f32) - outs(%arg1 : memref>) - attrs = {foo = 1} { - ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) : - linalg.yield %1 : f32 - } - return -} -// CHECK-LABEL: func @generic_with_tensor_input -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], -// CHECK-SAME: library_call = "some_external_function_name_1"} -// CHECK-SAME: ins({{.*}}, {{.*}} : tensor>, f32) -// CHECK-SAME: outs({{.*}} : memref>) -// CHECK-SAME: {foo = 1 : i64} - // ----- #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> @@ -300,27 +280,19 @@ func.func @named_ops(%a3: memref, %b3: memref, %c3: memref, %ta3: tensor, %tb3: tensor, %tc3: tensor) - -> (tensor, tensor) + -> (tensor) { linalg.batch_matmul ins(%a3, %b3: memref, memref) outs(%c3: memref) - linalg.batch_matmul ins(%ta3, %tb3: tensor, tensor) - outs(%c3: memref) %res1 = linalg.batch_matmul ins(%ta3, %tb3: tensor, tensor) outs(%tc3: tensor) -> tensor - %res2 = linalg.batch_matmul - ins(%ta3, %b3: tensor, memref) - outs(%tc3: tensor) - -> tensor - return %res1, %res2 : tensor, tensor + return %res1 : tensor } // CHECK-LABEL: func @named_ops // CHECK: linalg.batch_matmul // CHECK: linalg.batch_matmul -// CHECK: linalg.batch_matmul -// CHECK: linalg.batch_matmul // ----- diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -26,7 +26,7 @@ return; TypeSwitch(op) .Case([&](linalg::LinalgOp linalgOp) { - SmallVector inputOperands = linalgOp.getInputOperands(); + OperandRange inputOperands = linalgOp.getInputs(); operandSet.insert(inputOperands.begin(), inputOperands.end()); }) .Default([&](Operation *operation) { @@ -144,7 +144,7 @@ if (expandOp->hasOneUse()) { OpOperand &use = *expandOp->getUses().begin(); auto linalgOp = dyn_cast(use.getOwner()); - if (linalgOp && linalgOp.isOutputTensor(&use)) + if (linalgOp && linalgOp.isOutput(&use)) return true; } return false; diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -38,14 +38,14 @@ // Tile and Fuse for tensors inputs (TODO: all tensor operands). bool changed = false; for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - if (opOperand->get().getType().isa()) { + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + if (opOperand.get().getType().isa()) { // TODO: LinalgDependenceGraph should be able to update itself. // The current naive and expensive reconstruction of the graph should be // removed. linalg::Aliases aliases; linalg::LinalgDependenceGraph graph(aliases, linalgOps); - auto info = fuseProducerOfBuffer(b, *opOperand, graph); + auto info = fuseProducerOfBuffer(b, opOperand, graph); if (failed(info)) continue; auto *originalOp = info->originalProducer.getOperation(); @@ -54,11 +54,11 @@ std::find(linalgOps.begin(), linalgOps.end(), originalOp); *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); changed = true; - } else if (opOperand->get().getType().isa()) { + } else if (opOperand.get().getType().isa()) { // Tile and Fuse tensor input. - if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) + if (opOperand.getOperandNumber() >= linalgOp.getNumInputs()) continue; - auto info = fuseProducerOfTensor(b, *opOperand); + auto info = fuseProducerOfTensor(b, opOperand); if (failed(info)) continue; auto *originalOp = info->originalProducer.getOperation(); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2829,8 +2829,9 @@ } // To conform with interface requirement on operand naming. - mlir::ValueRange inputs() { return getInputs(); } - mlir::ValueRange outputs() { return getOutputs(); } + mlir::OperandRange inputs() { return getInputs(); } + mlir::OperandRange outputs() { return getOutputs(); } + int64_t getNumOutputs() { return 1; } }]; } @@ -2888,8 +2889,9 @@ } // To conform with interface requirement on operand naming. - mlir::ValueRange inputs() { return getInputs(); } - mlir::ValueRange outputs() { return getOutputs(); } + mlir::OperandRange inputs() { return getInputs(); } + mlir::OperandRange outputs() { return getOutputs(); } + int64_t getNumOutputs() { return 1; } }]; } diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -563,6 +563,10 @@ return regionBuilder; } + int64_t getNumOutputs() {{ + return 1; + } + // Generic methods. static unsigned getNumRegionArgs(); std::string getLibraryCallName(); @@ -638,8 +642,8 @@ AffineMap tensorMap = AffineMap::getMultiDimIdentityMap( getNumParallelLoops(), context); SmallVector indexingMaps; - for (OpOperand *opOperand : getInputAndOutputOperands()) - indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap); + for (OpOperand &opOperand : getOperation()->getOpOperands()) + indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap); return Builder(getContext()).getAffineMapArrayAttr(indexingMaps); } )FMT"; @@ -654,10 +658,9 @@ } void {0}::getEffects(SmallVectorImpl< SideEffects::EffectInstance >&effects) {{ - SmallVector inputBuffers = getInputBufferOperands(); - SmallVector outputBuffers = getOutputBufferOperands(); + if (hasTensorSemantics()) return; getGenericEffectsImpl(effects, - getOperation()->getResults(), inputBuffers, outputBuffers); + getOperation()->getResults(), getInputs(), getOutputs()); } )FMT";