diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -31,6 +31,14 @@
   operator SmallVector<Value>();
 };
 
+namespace detail {
+/// Implementation of the method that checks if the given operands
+/// can be dropped, i.e. the remaining operands can compute the loop
+/// bounds of the op.
+bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp,
+                                ArrayRef<OpOperand *> droppedOperands);
+} // namespace detail
+
 /// Checks whether `linalgOp` conforms to ContractionOpInterface.
 // TODO: embed within `isa` if possible / natural.
 bool isaContractionOpInterface(LinalgOp linalgOp);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -958,6 +958,19 @@
         return inversePermutation(getLoopsToShapesMap());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Checks if the given operands can be dropped, and the remaining
+        operands can still compute the bounds of the op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"canOpOperandsBeDropped",
+      /*args=*/(ins "ArrayRef<OpOperand *>":$droppedOperands),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the range of position in the result of the affine map
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -165,6 +165,12 @@
   let regions = (region AnyRegion:$region);
 
   let builders = [
+    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+      "ArrayAttr":$iteratorTypes, "StringAttr":$doc,
+      "StringAttr":$libraryCall,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuild,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
       "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -23,6 +23,25 @@
 /// Include the definitions of the copy operation interface.
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// Interface utility functions
+//===----------------------------------------------------------------------===//
+bool linalg::detail::canOpOperandsBeDroppedImpl(
+    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
+  SmallVector<AffineMap> indexingMaps;
+  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+    if (llvm::is_contained(droppedOperands, opOperand))
+      continue;
+    indexingMaps.push_back(linalgOp.getTiedIndexingMap(opOperand));
+  }
+  // With every operand dropped there is nothing left to derive the loop
+  // bounds from; that is only acceptable for an op without loops. Bailing out
+  // here also keeps `concatAffineMaps` from being called on an empty list.
+  if (indexingMaps.empty())
+    return linalgOp.getNumLoops() == 0;
+  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOpInterface implementation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -266,33 +267,6 @@
   return success(folded);
 }
 
-/// Helper function to find if there is atleast one dimension in an AffineMap
-/// testMap that is contained in `testMapLocation` of `maps` but not in any
-/// other locations
-static bool hasaUniqueDim(ArrayRef<AffineMap> maps, unsigned testMapLocation) {
-  AffineMap testMap = maps[testMapLocation];
-  llvm::SmallDenseSet<unsigned> dimsToCheck;
-  for (auto result : testMap.getResults()) {
-    auto expr = result.dyn_cast<AffineDimExpr>();
-    if (expr != nullptr)
-      dimsToCheck.insert(expr.getPosition());
-  }
-  for (const auto &it : llvm::enumerate(maps)) {
-    if (it.index() == testMapLocation)
-      continue;
-    auto map = it.value();
-    for (auto result : map.getResults()) {
-      auto expr = result.dyn_cast<AffineDimExpr>();
-      if (expr != nullptr) {
-        dimsToCheck.erase(expr.getPosition());
-      }
-      if (dimsToCheck.empty())
-        return false;
-    }
-  }
-  return true;
-}
-
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
@@ -670,16 +644,12 @@
 //===----------------------------------------------------------------------===//
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
-    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
+    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
-  build(builder, result, resultTensorTypes, inputs, outputs,
-        builder.getAffineMapArrayAttr(indexingMaps),
-        builder.getStrArrayAttr(iteratorTypes),
-        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
-        libraryCall.empty() ? StringAttr()
-                            : builder.getStringAttr(libraryCall));
+  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
+        iteratorTypes, doc, libraryCall);
   result.addAttributes(attributes);
   if (!bodyBuild)
     return;
@@ -700,6 +670,20 @@
   bodyBuild(builder, result.location, bodyBlock->getArguments());
 }
 
+void GenericOp::build(
+    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
+    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, resultTensorTypes, inputs, outputs,
+        builder.getAffineMapArrayAttr(indexingMaps),
+        builder.getStrArrayAttr(iteratorTypes),
+        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
+        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
+        bodyBuild, attributes);
+}
+
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
@@ -844,93 +828,158 @@
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
-// Deduplicate redundant args of a linalg generic op.
-// An arg is redundant if it has the same Value and indexing map as another.
-struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
+
+struct DeduplicateAndRemoveDeadOperandsAndResults
+    : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    // Associate each input to an equivalent "canonical" input that has the same
-    // Value and indexing map.
-    //
-    // In the non-duplicate case, input `i` will have canonical input `i`. But
-    // in the case of duplicated inputs, the canonical input could be some other
-    // input `< i`. That is, a later input will have some earlier input as its
-    // canonical input.
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
-    // For later remapping tasks like deduplicating payload block arguments,
-    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
-    // convenient.
-    SmallVector<unsigned> canonicalInputIndices;
-    for (OpOperand *opOperand : genericOp.getInputOperands()) {
-      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
-      // STL-like maps have a convenient behavior for our use case here. In the
-      // case of duplicate keys, the insertion is rejected, and the returned
-      // iterator gives access to the value already in the map.
-      auto pair = canonicalInput.insert(
-          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
-      canonicalInputIndices.push_back(pair.first->second);
+    // Create a map from argument position in the original op to the argument
+    // position in the new op. A dropped argument has no entry.
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+    unsigned numNewArgs = 0;
+    SmallVector<OpOperand *> droppedOpOperands;
+    llvm::SmallDenseSet<unsigned> droppedOutputs;
+
+    // Information needed to build the new op.
+    SmallVector<Value> newInputOperands, newOutputOperands;
+    SmallVector<AffineMap> newIndexingMaps;
+    SmallVector<Type> newResultTypes;
+
+    // Input argument can be dropped if
+    // - it has no uses, or,
+    // - there is a duplicate operand which is accessed using the same
+    //   indexing map.
+    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
+    for (OpOperand *inputOpOperand : genericOp.getInputOperands()) {
+      BlockArgument arg = genericOp.getTiedBlockArgument(inputOpOperand);
+      unsigned argNum = arg.getArgNumber();
+
+      // Check if operand is dead and if dropping the indexing map makes the
+      // loops to shape computation invalid.
+      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
+        // Add the current operands to the list of potentially droppable
+        // operands. If it cannot be dropped, this needs to be popped back.
+        droppedOpOperands.push_back(inputOpOperand);
+        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
+          continue;
+        droppedOpOperands.pop_back();
+      }
+
+      // Check if this operand is a duplicate.
+      AffineMap indexingMap = genericOp.getTiedIndexingMap(inputOpOperand);
+      auto it = dedupedInputs.find(
+          std::make_pair(inputOpOperand->get(), indexingMap));
+      if (it != dedupedInputs.end()) {
+        origToNewPos[argNum] = it->second;
+        droppedOpOperands.push_back(inputOpOperand);
+        continue;
+      }
+
+      // This is a preserved argument. Key the dedup map by its position in
+      // the *new* op so that later duplicates are remapped correctly.
+      unsigned newPos = numNewArgs++;
+      origToNewPos[argNum] = newPos;
+      dedupedInputs[{inputOpOperand->get(), indexingMap}] = newPos;
+      newInputOperands.push_back(inputOpOperand->get());
+      newIndexingMaps.push_back(indexingMap);
     }
 
-    // If there are no duplicate args, then bail out.
-    if (canonicalInput.size() == genericOp.getNumInputs())
-      return failure();
+    // If the op doesn't have tensor semantics, keep all the outputs as
+    // preserved.
+    if (!genericOp.hasTensorSemantics()) {
+      for (OpOperand *outputOpOperand : genericOp.getOutputOperands()) {
+        BlockArgument arg = genericOp.getTiedBlockArgument(outputOpOperand);
+        origToNewPos[arg.getArgNumber()] = numNewArgs++;
+        newOutputOperands.push_back(outputOpOperand->get());
+        newIndexingMaps.push_back(
+            genericOp.getTiedIndexingMap(outputOpOperand));
+      }
+    } else {
+      // Output argument can be dropped if the result has
+      // - no users, and
+      // - it is not used in the payload, and
+      // - the corresponding indexing maps are not needed for loop bound
+      //   computation.
+      for (auto outputOpOperand :
+           llvm::enumerate(genericOp.getOutputOperands())) {
+        Value result = genericOp.getResult(outputOpOperand.index());
+        BlockArgument arg =
+            genericOp.getTiedBlockArgument(outputOpOperand.value());
+        if (result.use_empty() &&
+            !genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+          // Check if the opoperand can be dropped without affecting loop bound
+          // computation. Add the operand to the list of dropped op operand for
+          // checking. If it cannot be dropped, need to pop the value back.
+          droppedOpOperands.push_back(outputOpOperand.value());
+          if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+            droppedOutputs.insert(outputOpOperand.index());
+            continue;
+          }
+          droppedOpOperands.pop_back();
+        }
 
-    // The operands for the newly canonicalized op.
-    SmallVector<Value> newInputOperands;
-    for (OpOperand *opOperand : genericOp.getInputOperands())
-      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
-          opOperand->getOperandNumber())
-        newInputOperands.push_back(opOperand->get());
+        origToNewPos[arg.getArgNumber()] = numNewArgs++;
+        newOutputOperands.push_back(outputOpOperand.value()->get());
+        newIndexingMaps.push_back(
+            genericOp.getTiedIndexingMap(outputOpOperand.value()));
+        newResultTypes.push_back(result.getType());
+      }
+    }
 
-    // Repair the indexing maps by filtering out the ones that have been
-    // eliminated.
-    SmallVector<AffineMap> newIndexingMaps;
-    for (OpOperand *opOperand : genericOp.getInputOperands())
-      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
-          opOperand->getOperandNumber())
-        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
-    for (OpOperand *opOperand : genericOp.getOutputOperands())
-      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
-
-    // Clone the old op with new operands.
-    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+    // Check if there is any change to operands.
+    if (newInputOperands.size() + newOutputOperands.size() ==
+        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
+      return failure();
+
+    // Create the new op with the body being empty.
+    Location loc = genericOp.getLoc();
     auto newOp = rewriter.create<GenericOp>(
-        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
-        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        loc, newResultTypes, newInputOperands, newOutputOperands,
+        rewriter.getAffineMapArrayAttr(newIndexingMaps),
         genericOp.iterator_types(), genericOp.docAttr(),
-        genericOp.library_callAttr());
-
+        genericOp.library_callAttr(),
+        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
+          return;
+        });
     // Copy over unknown attributes. They might be load bearing for some flow.
     ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
-    for (NamedAttribute kv : genericOp->getAttrs()) {
-      if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
+    for (NamedAttribute kv : genericOp->getAttrs())
+      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
         newOp->setAttr(kv.getName(), kv.getValue());
-      }
-    }
 
-    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
-                                newOp.region().begin());
-
-    // Repair the payload entry block by RAUW'ing redundant arguments and
-    // erasing them.
-    Block &payload = newOp.region().front();
-    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
-    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
-      // Iterate in reverse, so that we erase later args first, preventing the
-      // argument list from shifting unexpectedly and invalidating all our
-      // indices.
-      unsigned operandNumber = opOperand->getOperandNumber();
-      if (canonicalInputIndices[operandNumber] == operandNumber)
-        continue;
-      payload.getArgument(operandNumber)
-          .replaceAllUsesWith(
-              payload.getArgument(canonicalInputIndices[operandNumber]));
-      payload.eraseArgument(operandNumber);
+    // Merge the body of the original op with the new op.
+    Block *newOpBlock = &newOp.region().front();
+    Block *origOpBlock = &genericOp.region().front();
+    SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
+    for (auto argNum : llvm::seq<unsigned>(0, origOpBlock->getNumArguments())) {
+      auto it = origToNewPos.find(argNum);
+      if (it != origToNewPos.end())
+        replacements[argNum] = newOpBlock->getArgument(it->second);
+    }
+    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
+
+    // Drop the unused yield args.
+    Block *block = &newOp.region().front();
+    if (!droppedOutputs.empty()) {
+      OpBuilder::InsertionGuard g(rewriter);
+      SmallVector<Value> newYieldVals;
+      YieldOp origYieldOp = cast<YieldOp>(block->getTerminator());
+      rewriter.setInsertionPoint(origYieldOp);
+      for (auto yieldOpOperands : llvm::enumerate(origYieldOp.values()))
+        if (!droppedOutputs.count(yieldOpOperands.index()))
+          newYieldVals.push_back(yieldOpOperands.value());
+      rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
     }
 
-    rewriter.replaceOp(genericOp, newOp->getResults());
+    // Replace all live uses of the op.
+    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
+    unsigned newResultNum = 0;
+    for (auto result : llvm::enumerate(genericOp.getResults()))
+      if (!droppedOutputs.count(result.index()))
+        replacementsVals[result.index()] = newOp.getResult(newResultNum++);
+    rewriter.replaceOp(genericOp, replacementsVals);
     return success();
   }
 };
@@ -1007,72 +1056,13 @@
     return success();
   }
 };
-
-/// Drop dead args of a linalg generic op.
-/// An arg is dead if it has zero uses in the op region.
-struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap> oldIndexingMaps = genericOp.getIndexingMaps();
-    // Maps must be projected permutations.
-    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
-          return !map.isProjectedPermutation();
-        }))
-      return failure();
-    Block &payload = genericOp.region().front();
-    SmallVector<Value> newInputOperands;
-    SmallVector<AffineMap> newIndexingMaps;
-    bool deadArgFound = false;
-    int inputSize = genericOp.getInputOperands().size();
-    for (int i = inputSize - 1; i >= 0; i--) {
-      OpOperand *opOperand = genericOp.getInputOperand(i);
-      // Iterate in reverse, so that we erase later args first, preventing the
-      // argument list from shifting unexpectedly and invalidating all our
-      // indices.
-      if (payload.getArgument(i).use_empty() &&
-          !hasaUniqueDim(oldIndexingMaps, i)) {
-        payload.eraseArgument(i);
-        deadArgFound = true;
-        // remove this indexing map out of consideration for hasaUniqueDim check
-        oldIndexingMaps.erase(oldIndexingMaps.begin() + i);
-      } else {
-        newInputOperands.insert(newInputOperands.begin(), opOperand->get());
-        newIndexingMaps.insert(newIndexingMaps.begin(),
-                               genericOp.getTiedIndexingMap(opOperand));
-      }
-    }
-    // Bail out if there are no dead args.
-    if (!deadArgFound)
-      return failure();
-    for (OpOperand *opOperand : genericOp.getOutputOperands())
-      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
-    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
-
-    auto newOp = rewriter.create<GenericOp>(
-        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
-        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
-        genericOp.iterator_types(), genericOp.docAttr(),
-        genericOp.library_callAttr());
-    // Copy over unknown attributes. They might be load bearing for some flow.
-    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
-    for (NamedAttribute kv : genericOp->getAttrs()) {
-      if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
-        newOp->setAttr(kv.getName(), kv.getValue());
-      }
-    }
-    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
-                                newOp.region().begin());
-    rewriter.replaceOp(genericOp, newOp->getResults());
-    return success();
-  }
-};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
-              DeadArgsGenericOpInputs>(context);
+  results
+      .add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
+          context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -90,3 +90,150 @@
   } -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// Drop dead result.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @drop_dead_results(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %0:4 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<?x?x?xf32>)
+    outs(%arg0, %arg0, %arg0, %arg0
+        : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      %1 = arith.addf %b0, %b0: f32
+      linalg.yield %1, %1, %1, %1 : f32, f32, f32, f32
+    } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0, %0#2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
+//     CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//     CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+//     CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+//         CHECK: func @drop_dead_results(
+//    CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>)
+//         CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+//    CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+//    CHECK-SAME:       outs(%[[ARG0]], %[[ARG0]] :
+//         CHECK:   return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Current argmax lowering to `linalg.generic`. Cannot drop the
+// first return even though it isn't used since it has an internal
+// use.
+#map0 = affine_map<(d0) -> (d0)>
+#map1 = affine_map<(d0) -> ()>
+func.func @argmax_lowering(%arg0 : tensor<?xf32>) -> tensor<i32> {
+  %init0 = linalg.init_tensor [] : tensor<f32>
+  %init1 = linalg.init_tensor [] : tensor<i32>
+  %0:2 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map1],
+    iterator_types = ["reduction"]}
+    ins(%arg0 : tensor<?xf32>)
+    outs(%init0, %init1 : tensor<f32>, tensor<i32>) {
+  ^bb0(%b0: f32, %b1: f32, %b2: i32):
+    %8 = linalg.index 0 : index
+    %9 = arith.index_cast %8 : index to i32
+    %10 = arith.cmpf oge, %b0, %b1 : f32
+    %11 = arith.select %10, %b0, %b1 : f32
+    %12 = arith.cmpf oeq, %b0, %b1 : f32
+    %13 = arith.minsi %9, %b2 : i32
+    %14 = arith.select %10, %9, %b2 : i32
+    %15 = arith.select %12, %13, %14 : i32
+    linalg.yield %11, %15 : f32, i32
+  } -> (tensor<f32>, tensor<i32>)
+  return %0#1 : tensor<i32>
+}
+//      CHECK: func @argmax_lowering(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?xf32>
+//  CHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [] : tensor<f32>
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [] : tensor<i32>
+//      CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME:       outs(%[[INIT0]], %[[INIT1]] :
+//      CHECK:   return %[[GENERIC]]#1
+
+// -----
+
+// Do not remove operand needed for loop dim.
+func.func @loop_dim_operand(%arg0 : tensor<?xf32>) -> tensor<i32> {
+  %cst = arith.constant 0 : i32
+  %init = linalg.init_tensor [] : tensor<i32>
+  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<i32>) -> tensor<i32>
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
+      iterator_types = ["reduction"]}
+      ins(%arg0 : tensor<?xf32>) outs(%fill : tensor<i32>) {
+    ^bb0(%b0: f32, %b1: i32):
+      %1 = linalg.index 0 : index
+      %2 = arith.index_cast %1 : index to i32
+      %3 = arith.addi %b1, %2 : i32
+      linalg.yield %3 : i32
+    } -> tensor<i32>
+  return %0 : tensor<i32>
+}
+//      CHECK: func @loop_dim_operand(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?xf32>
+//      CHECK:   linalg.generic
+// CHECK-SAME:       ins(%[[ARG0]] :
+
+// -----
+
+// Do not remove outs operand needed for loop bound computation.
+func.func @loop_dim_outs_operand(%arg0 : index) -> tensor<i32> {
+  %cst = arith.constant 0 : i32
+  %init1 = linalg.init_tensor [%arg0] : tensor<?xi32>
+  %init = linalg.init_tensor [] : tensor<i32>
+  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<i32>) -> tensor<i32>
+  %0:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
+      iterator_types = ["parallel"]}
+      outs(%init1, %fill : tensor<?xi32>, tensor<i32>) {
+    ^bb0(%b0: i32, %b1: i32):
+      %1 = linalg.index 0 : index
+      %2 = arith.index_cast %1 : index to i32
+      %3 = arith.addi %b1, %2 : i32
+      linalg.yield %2, %3 : i32, i32
+    } -> (tensor<?xi32>, tensor<i32>)
+  return %0#1 : tensor<i32>
+}
+//      CHECK: func @loop_dim_outs_operand(
+// CHECK-SAME:     %[[ARG0:.+]]: index
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%[[ARG0]]]
+//      CHECK:   linalg.generic
+// CHECK-SAME:     outs(%[[INIT]]
+
+// -----
+
+// Drop a dead input that precedes a pair of duplicated inputs. The
+// surviving duplicate must be remapped to the position of its canonical
+// operand in the new op, not to its position in the original op.
+#map = affine_map<(d0) -> (d0)>
+func.func @dead_input_then_duplicate_inputs(%arg0 : tensor<?xf32>,
+    %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [#map, #map, #map, #map],
+      iterator_types = ["parallel"]}
+      ins(%arg0, %arg1, %arg1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
+      outs(%arg0 : tensor<?xf32>) {
+    ^bb0(%b0: f32, %b1: f32, %b2: f32, %b3: f32):
+      %1 = arith.addf %b1, %b2 : f32
+      linalg.yield %1 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+//      CHECK: func @dead_input_then_duplicate_inputs(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
+//      CHECK:   linalg.generic
+// CHECK-SAME:       ins(%[[ARG1]] : tensor<?xf32>)
+// CHECK-SAME:       outs(%[[ARG0]] : tensor<?xf32>)
+// CHECK-NEXT:   ^bb0(%[[B0:[a-zA-Z0-9_]+]]: f32, %[[B1:[a-zA-Z0-9_]+]]: f32):
+// CHECK-NEXT:     %[[ADD:.+]] = arith.addf %[[B0]], %[[B0]] : f32
+// CHECK-NEXT:     linalg.yield %[[ADD]] : f32