diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -787,57 +787,27 @@
         return inversePermutation(getLoopsToShapesMap());
       }]
     >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the position in the results of the affine map computed
-        by getLoopsToShapesMap() that represents the shape of an
-        operand (input or output) at a dimension.
-      }],
-      /*retTy=*/"Optional<unsigned>",
-      /*methodName=*/"getOperandDimPositionInLoopsToShapeMap",
-      /*args=*/(ins "unsigned":$operandIdx, "unsigned":$dim),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        unsigned pos = 0;
-        for (OpOperand *opOperand : getInputAndOutputOperands()) {
-          if (opOperand->getOperandNumber() == operandIdx) return pos + dim;
-          pos += getRank(opOperand);
-        }
-        return {};
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the position in the results of the affine map computed
-        by getLoopsToShapesMap() that represents the shape of an
-        input operand at a dimension.
-      }],
-      /*retTy=*/"Optional<unsigned>",
-      /*methodName=*/"getInputValueDimPositionInLoopsToShapeMap",
-      /*args=*/(ins "unsigned":$inputIdx, "unsigned":$dim),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        if (inputIdx >= getNumInputs()) return {};
-        return getOperandDimPositionInLoopsToShapeMap(inputIdx, dim);
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the range of position in the result of the affine map
         computed by getLoopsToShapesMap() which correspond to the
         AffineExprs used to access the outputs of the operation.
       }],
-      /*retTy=*/"std::pair<unsigned, unsigned>",
+      /*retTy=*/"std::pair<int64_t, int64_t>",
       /*methodName=*/"getResultsPositionInLoopsToShapeMap",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        OpOperand *opOperand = getOutputOperand(getNumOutputs()-1);
-        return
-          {*getOperandDimPositionInLoopsToShapeMap(getNumInputs(), 0),
-           (*getOperandDimPositionInLoopsToShapeMap
-              (getNumInputs() + getNumOutputs() - 1,
-               getRank(opOperand) - 1)) + 1};
+        // getLoopsToShapesMap() results list the dimensions of all input
+        // operands followed by those of all output operands, so the outputs
+        // occupy [sum of input ranks, sum of input and output ranks).
+        int64_t inputRankSum = 0;
+        int64_t outputRankSum = 0;
+        for (OpOperand *input : getInputOperands())
+          inputRankSum += getRank(input);
+        for (OpOperand *output : getOutputOperands())
+          outputRankSum += getRank(output);
+        return {inputRankSum, inputRankSum + outputRankSum};
       }]
     >,
     InterfaceMethod<