diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1424,7 +1424,7 @@ /// `linalgOp`. Operation::operand_range getAssumedNonShapedOperands() { Operation::operand_range res{ - getOperation()->getOperands().begin() + getNumShapedOperands(), + getOperation()->getOperands().begin() + getNumInputsAndOutputs(), getOperation()->getOperands().end()}; for (Type t : TypeRange{res}) { (void)t; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -140,7 +140,7 @@ // Rank-polymorphic. // filling_value -> O(ivs) with parallel iterators. ArrayAttr iterator_types() { - unsigned nPar = getInputShapedType(0).getRank(); + int64_t nPar = getRank(getInputOperand(0)); return Builder(getContext()).getStrArrayAttr( SmallVector(nPar, getParallelIteratorTypeName())); } @@ -150,8 +150,8 @@ MLIRContext *context = getContext(); auto maybeInputMap = inputPermutation(); auto maybeOutputMap = outputPermutation(); - unsigned inputRank = getInputShapedType(0).getRank(); - unsigned outputRank = getOutputShapedType(0).getRank(); + int64_t inputRank = getRank(getInputOperand(0)); + int64_t outputRank = getRank(getOutputOperand(0)); return Builder(getContext()).getAffineMapArrayAttr({ extractOrIdentityMap(maybeInputMap, inputRank, context), extractOrIdentityMap(maybeOutputMap, outputRank, context)}); @@ -195,7 +195,7 @@ // Rank-polymorphic. // filling_value -> O(ivs) with parallel iterators. ArrayAttr iterator_types() { - unsigned nPar = getOutputShapedType(0).getRank(); + int64_t nPar = getRank(getOutputOperand(0)); return Builder(getContext()).getStrArrayAttr( SmallVector(nPar, getParallelIteratorTypeName())); } @@ -351,14 +351,14 @@ unsigned getNumOutputFeatureDimensions() { return 1; } unsigned getNumSpatialDimensions() { - return getOutputShapedType(0).getRank() - getNumBatchDimensions() - + return getRank(getOutputOperand(0)) - getNumBatchDimensions() - getNumOutputFeatureDimensions(); } ArrayAttr iterator_types() { // Outer parallel loops are always the number of output dimensions; i.e. // [b, xs, q] in the TF notation above. - unsigned nPar = getOutputShapedType(0).getRank(); + int64_t nPar = getRank(getOutputOperand(0)); unsigned nRed = getNumInputFeatureDimensions(); // Window loops are a special kind of reduction that is never tiled or // parallelized across; i.e. [zs] in the TF notation above whose number @@ -457,7 +457,7 @@ ArrayAttr iterator_types() { // Outer parallel loops are always the number of output dimensions. - unsigned nPar = getOutputShapedType(0).getRank(); + int64_t nPar = getRank(getOutputOperand(0)); // The window loops has the same number loops with output dimensions. unsigned nWin = nPar; SmallVector iters(nPar, getParallelIteratorTypeName()); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -194,21 +194,18 @@ SmallVector LinalgOp::createFlatListOfOperandDims(OpBuilder &b, Location loc) { SmallVector res; - for (Value v : getShapedOperands()) { - ShapedType t = v.getType().template cast(); - for (unsigned i = 0, e = t.getRank(); i < e; ++i) - res.push_back(b.createOrFold(loc, v, i)); + for (OpOperand *opOperand : getInputAndOutputOperands()) { + for (int64_t i = 0, e = getRank(opOperand); i < e; ++i) + res.push_back(b.createOrFold(loc, opOperand->get(), i)); } return res; } SmallVector LinalgOp::createFlatListOfOperandStaticDims() { SmallVector res; - for (Value v : getShapedOperands()) { - ShapedType t = v.getType().template cast(); - assert(t.hasStaticShape() && "expected operands to have static shapes"); - llvm::append_range(res, t.getShape()); - } + assert(!hasDynamicShape() && "expected operands to have static shapes"); + for (OpOperand *opOperand : getInputAndOutputOperands()) + llvm::append_range(res, getShape(opOperand)); return res; } @@ -302,15 +299,14 @@ auto allResultDimValues = applyMapToValues(b, loc, resultShapesFromInputShapesMap, createFlatListOfOperandDims(b, loc)); - unsigned pos = 0; + int64_t pos = 0; ArrayRef shapeExprs = resultShapesFromInputShapesMap.getResults(); - for (auto resultIdx : llvm::seq(0, getNumOutputs())) { - ShapedType resultType = getOutputShapedType(resultIdx); + for (OpOperand *opOperand : getOutputOperands()) { SmallVector shapes; - for (unsigned dim : llvm::seq(0, resultType.getRank())) { + for (int64_t dim : llvm::seq(0, getRank(opOperand))) { if (checkDimExpr.visit(shapeExprs[pos])) shapes.push_back( - b.createOrFold(loc, getOutput(resultIdx), dim)); + b.createOrFold(loc, opOperand->get(), dim)); else shapes.push_back(allResultDimValues[pos]); pos++; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -699,15 +699,19 @@ void GenericOp::getEffects( SmallVectorImpl> &effects) { - getGenericEffectsImpl(effects, getOperation()->getResults(), - getInputBuffers(), getOutputBuffers()); + SmallVector inputBuffers = getInputBufferOperands(); + SmallVector outputBuffers = getOutputBufferOperands(); + getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, + outputBuffers); } void IndexedGenericOp::getEffects( SmallVectorImpl> &effects) { - getGenericEffectsImpl(effects, getOperation()->getResults(), - getInputBuffers(), getOutputBuffers()); + SmallVector inputBuffers = getInputBufferOperands(); + SmallVector outputBuffers = getOutputBufferOperands(); + getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, + outputBuffers); } template diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -2107,8 +2107,10 @@ } void {0}::getEffects(SmallVectorImpl< SideEffects::EffectInstance >&effects) {{ + SmallVector inputBuffers = getInputBufferOperands(); + SmallVector outputBuffers = getOutputBufferOperands(); getGenericEffectsImpl(effects, - getOperation()->getResults(), getInputBuffers(), getOutputBuffers()); + getOperation()->getResults(), inputBuffers, outputBuffers); })FMT"; os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName); } diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -551,8 +551,10 @@ } void {0}::getEffects(SmallVectorImpl< SideEffects::EffectInstance >&effects) {{ - getGenericEffectsImpl(effects, - getOperation()->getResults(), getInputBuffers(), getOutputBuffers()); + SmallVector inputBuffers = getInputBufferOperands(); + SmallVector outputBuffers = getOutputBufferOperands(); + getGenericEffectsImpl(effects, + getOperation()->getResults(), inputBuffers, outputBuffers); } )FMT";