diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -31,7 +31,10 @@ Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool", /*default=*/"false", "Only folds the one-trip loops from Linalg ops on tensors " - "(for testing purposes only)"> + "(for testing purposes only)">, + Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool", + /*default=*/"false", + "Generate rank-reducing slices instead of reassociative reshapes"> ]; let dependentDialects = [ "linalg::LinalgDialect", "AffineDialect", "memref::MemRefDialect" diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -129,8 +129,12 @@ void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns); /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on -/// tensors. -void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns); +/// tensors via reassociative reshape ops. +void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns); + +/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on +/// tensors via rank-reducing slices. +void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns); /// Patterns that are used to inline constant operands into linalg generic ops. void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -54,6 +54,16 @@ /// single deallocate if it exists or nullptr. Optional findDealloc(Value allocValue); +/// Return the dimensions of the given memref value. +SmallVector getMixedSizes(OpBuilder &builder, Location loc, + Value value); + +/// Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and +/// appropriate sizes (i.e. `memref.getSizes()`) to reduce the rank of `memref` +/// to that of `targetShape`. +Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, + Value memref, + ArrayRef targetShape); } // namespace memref } // namespace mlir diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1954,6 +1954,15 @@ /// Return the dimensions of the source type that are dropped when /// the result is rank-reduced. llvm::SmallBitVector getDroppedDims(); + + /// Given a `value`, asserted to be of MemRefType, build a SubViewOp that + /// results in a rank reduction to the desired memref shape and return the + /// new value created. + /// If the shape of `value` is already the `desiredShape`, just return + /// `value`. + /// If the shape of `value` cannot be rank-reduced to `desiredShape`, fail. + static FailureOr rankReduceIfNeeded( + OpBuilder &b, Location loc, Value value, ArrayRef desiredShape); }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -42,6 +42,10 @@ using namespace mlir; using namespace mlir::linalg; +namespace { +enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice }; +} // namespace + /// Implements a pass that canonicalizes the uses of unit-extent dimensions for /// broadcasting. For example, /// @@ -349,9 +353,9 @@ }; struct UnitExtentReplacementInfo { - Type type; AffineMap indexMap; - ArrayAttr reassociation; + SmallVector reassociation; + SmallVector targetShape; }; } // namespace @@ -371,8 +375,6 @@ AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); ArrayRef shape = genericOp.getShape(opOperand); ArrayRef exprs = indexingMap.getResults(); - SmallVector reassociations; - SmallVector reassociationMaps; SmallVector newIndexExprs; SmallVector newShape; @@ -391,99 +393,110 @@ } int64_t dim = 0; + SmallVector reassociation; + ReassociationIndices reassociationGroup; // Fold dimensions that are unit-extent at the beginning of the tensor. while (dim < origRank && isUnitExtent(dim)) - reassociations.push_back(getAffineDimExpr(dim++, context)); + reassociationGroup.push_back(dim++); while (dim < origRank) { - reassociations.push_back(getAffineDimExpr(dim, context)); + assert(!isUnitExtent(dim) && "expected non unit-extent"); + reassociationGroup.push_back(dim); newIndexExprs.push_back(exprs[dim]); newShape.push_back(shape[dim]); - // Fold all following dimensions that are unit-extent. - while (dim + 1 < origRank && isUnitExtent(dim + 1)) { - ++dim; - reassociations.push_back(getAffineDimExpr(dim, context)); - } - reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get( - origRank, /*symbolCount = */ 0, reassociations, context))); - reassociations.clear(); ++dim; + // Fold all following dimensions that are unit-extent. + while (dim < origRank && isUnitExtent(dim)) + reassociationGroup.push_back(dim++); + reassociation.push_back(reassociationGroup); + reassociationGroup.clear(); } - // Compute the tensor or scalar replacement type. - Type elementType = getElementTypeOrSelf(opOperand->get()); - Type replacementType; - if (elementType == opOperand->get().getType()) { - replacementType = elementType; - } else if (actualType.isa()) { - replacementType = RankedTensorType::get(newShape, elementType); - } else { - auto memrefType = actualType.cast(); - replacementType = MemRefType::get(newShape, elementType, {}, - memrefType.getMemorySpaceAsInt()); - } - UnitExtentReplacementInfo info = {replacementType, - AffineMap::get(indexingMap.getNumDims(), - indexingMap.getNumSymbols(), - newIndexExprs, context), - ArrayAttr::get(context, reassociationMaps)}; + // Return if the rank was not reduced. + if (origRank == static_cast(newShape.size())) + return std::nullopt; + + UnitExtentReplacementInfo info = { + /*indexMap=*/AffineMap::get(indexingMap.getNumDims(), + indexingMap.getNumSymbols(), newIndexExprs, + context), + /*reassociation=*/reassociation, /*targetShape=*/newShape}; return info; } namespace { -SmallVector -convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) { - SmallVector reassociationExprs; - for (auto attr : affineMapArrayAttr) - reassociationExprs.push_back( - llvm::to_vector<4>(attr.cast().getValue().getResults())); - return reassociationExprs; -} - /// Pattern to replace tensor/buffer operands/results that are unit extents. struct ReplaceUnitExtents : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - // Return the original value if the type is unchanged, or reshape it. Return a - // nullptr if this is an unsupported type. - Value maybeExpand(Value result, Type origResultType, - ArrayAttr reassociationMap, Location loc, + ReplaceUnitExtents(MLIRContext *ctx, + RankReductionStrategy rankReductionStrategy) + : OpRewritePattern(ctx), + rankReductionStrategy(rankReductionStrategy) {} + + // Expand the given value. + Value expandValue(Value result, Value origOutput, + ArrayRef reassociation, Location loc, PatternRewriter &rewriter) const { - if (origResultType == result.getType()) - return result; - if (origResultType.isa()) { - return rewriter.create( - loc, origResultType, result, - convertAffineMapArrayToExprs(reassociationMap)); - } - if (origResultType.isa()) { - return rewriter.create( - loc, origResultType, result, - convertAffineMapArrayToExprs(reassociationMap)); + // There are no results for memref outputs. + auto origResultType = origOutput.getType().cast(); + if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { + unsigned rank = origResultType.getRank(); + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + SmallVector sizes = + tensor::getMixedSizes(rewriter, loc, origOutput); + SmallVector strides(rank, rewriter.getIndexAttr(1)); + return rewriter.createOrFold( + loc, result, origOutput, offsets, sizes, strides); } - return nullptr; - }; - // Return the original value if the type is unchanged, or reshape it. Return a - // nullptr if this is an unsupported type. - Value maybeCollapse(Value operand, Type newInputOutputType, - ArrayAttr reassociationMap, Location loc, - PatternRewriter &rewriter) const { - auto operandType = operand.getType(); - if (operandType == newInputOutputType) - return operand; - if (operandType.isa()) { - return rewriter.create( - loc, newInputOutputType, operand, - convertAffineMapArrayToExprs(reassociationMap)); + assert(rankReductionStrategy == + RankReductionStrategy::ReassociativeReshape && + "unknown rank reduction strategy"); + return rewriter.create(loc, origResultType, result, + reassociation); + } + + // Collapse the given value. + Value collapseValue(Value operand, ArrayRef targetShape, + ArrayRef reassociation, + Location loc, PatternRewriter &rewriter) const { + if (auto memrefType = operand.getType().dyn_cast()) { + if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { + FailureOr rankReducingExtract = + memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand, + targetShape); + assert(succeeded(rankReducingExtract) && "not a unit-extent collapse"); + return *rankReducingExtract; + } + + assert(rankReductionStrategy == + RankReductionStrategy::ReassociativeReshape && + "unknown rank reduction strategy"); + MemRefLayoutAttrInterface layout; + auto targetType = + MemRefType::get(targetShape, memrefType.getElementType(), layout, + memrefType.getMemorySpace()); + return rewriter.create(loc, targetType, operand, + reassociation); } - if (operandType.isa()) { - return rewriter.create( - loc, newInputOutputType, operand, - convertAffineMapArrayToExprs(reassociationMap)); + if (auto tensorType = operand.getType().dyn_cast()) { + if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { + FailureOr rankReducingExtract = + tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand, + targetShape); + assert(succeeded(rankReducingExtract) && "not a unit-extent collapse"); + return *rankReducingExtract; + } + + assert(rankReductionStrategy == + RankReductionStrategy::ReassociativeReshape && + "unknown rank reduction strategy"); + auto targetType = + RankedTensorType::get(targetShape, tensorType.getElementType()); + return rewriter.create(loc, targetType, operand, + reassociation); } - return nullptr; - }; + llvm_unreachable("unsupported operand type"); + } LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { @@ -495,71 +508,59 @@ return failure(); MLIRContext *context = rewriter.getContext(); Location loc = genericOp.getLoc(); + SmallVector oldOutputs(genericOp.getOutputs().begin(), + genericOp.getOutputs().end()); SmallVector newIndexingMaps; - SmallVector reassociationMaps; - SmallVector newInputOutputTypes; - bool doCanonicalization = false; + SmallVector> reassociations; + SmallVector> targetShapes; + SmallVector collapsed; for (OpOperand &opOperand : genericOp->getOpOperands()) { auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context); if (replacementInfo) { - reassociationMaps.push_back(replacementInfo->reassociation); + reassociations.push_back(replacementInfo->reassociation); newIndexingMaps.push_back(replacementInfo->indexMap); - newInputOutputTypes.push_back(replacementInfo->type); - doCanonicalization |= - replacementInfo->type != opOperand.get().getType(); + targetShapes.push_back(replacementInfo->targetShape); + collapsed.push_back(true); } else { - // If replaceUnitExtents cannot handle this case, maintain the same - // type, indexing map, and create a set of mappings representing an - // identity matrix. - newInputOutputTypes.push_back(opOperand.get().getType()); + // If replaceUnitExtents cannot handle this case (or no unit dim was + // removed), maintain the same type, indexing map, and create a set of + // mappings representing an identity matrix. newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand)); - int64_t origRank = genericOp.getRank(&opOperand); - auto maps = llvm::to_vector<8>(llvm::map_range( - llvm::seq(0, origRank), [&](int64_t dim) -> Attribute { - return AffineMapAttr::get( - AffineMap::get(origRank, /*symbolCount = */ 0, - getAffineDimExpr(dim, context), context)); - })); - reassociationMaps.push_back(ArrayAttr::get(context, maps)); + reassociations.emplace_back(); + targetShapes.emplace_back(); + collapsed.push_back(false); } } - // If the indexing maps of the result operation are not invertible (i.e. not - // legal), abort. - if (!doCanonicalization || + // Abort if the indexing maps of the result operation are not invertible + // (i.e. not legal) or if no dimension was reduced. + if (!llvm::any_of(collapsed, [](bool c) { return c; }) || !inversePermutation(concatAffineMaps(newIndexingMaps))) return failure(); - // If any operand type change, insert a reshape to convert from the original - // type to the new type. - // TODO: get rid of flattenedIdx which assumes operand order and contiguity. - unsigned flattenedIdx = 0; - auto insertReshapes = [&](ValueRange values) { - SmallVector res; - res.reserve(values.size()); - for (auto operand : values) { - auto reshapedValue = - maybeCollapse(operand, newInputOutputTypes[flattenedIdx], - reassociationMaps[flattenedIdx], loc, rewriter); - assert(reshapedValue && - "expected ranked MemRef or Tensor operand type"); - res.push_back(reshapedValue); - ++flattenedIdx; + // Insert rank reductions. + SmallVector newOperands; + for (OpOperand &opOperand : genericOp->getOpOperands()) { + int64_t idx = opOperand.getOperandNumber(); + if (!collapsed[idx]) { + newOperands.push_back(opOperand.get()); + continue; } - return res; - }; - - SmallVector newInputs = insertReshapes(genericOp.getInputs()); - SmallVector newOutputs = insertReshapes(genericOp.getOutputs()); + newOperands.push_back(collapseValue(opOperand.get(), targetShapes[idx], + reassociations[idx], loc, rewriter)); + } // If any result type changes, insert a reshape to convert from the original // type to the new type. - SmallVector resultTypes; + ArrayRef newInputs = + ArrayRef(newOperands).take_front(genericOp.getNumDpsInputs()); + ArrayRef newOutputs = + ArrayRef(newOperands).take_back(genericOp.getNumDpsInits()); + SmallVector resultTypes; resultTypes.reserve(genericOp.getNumResults()); for (unsigned i : llvm::seq(0, genericOp.getNumResults())) - resultTypes.push_back( - newInputOutputTypes[i + genericOp.getNumDpsInputs()]); + resultTypes.push_back(newOutputs[i].getType()); GenericOp replacementOp = rewriter.create( loc, resultTypes, newInputs, newOutputs, newIndexingMaps, genericOp.getIteratorTypesArray()); @@ -569,20 +570,24 @@ // If any result tensor has a modified shape, then add reshape to recover // the original shape. - SmallVector resultReplacements; + SmallVector resultReplacements; for (const auto &result : llvm::enumerate(replacementOp.getResults())) { unsigned index = result.index() + replacementOp.getNumDpsInputs(); - auto origResultType = genericOp.getResult(result.index()).getType(); - - auto newResult = maybeExpand(result.value(), origResultType, - reassociationMaps[index], loc, rewriter); - assert(newResult && - "unexpected output type other than ranked MemRef or Tensor"); - resultReplacements.push_back(newResult); + Value origOutput = oldOutputs[result.index()]; + if (!collapsed[result.index() + genericOp.getNumDpsInputs()]) { + resultReplacements.push_back(result.value()); + continue; + } + resultReplacements.push_back(expandValue( + result.value(), origOutput, reassociations[index], loc, rewriter)); } + rewriter.replaceOp(genericOp, resultReplacements); return success(); } + +private: + RankReductionStrategy rankReductionStrategy; }; } // namespace @@ -656,14 +661,16 @@ /// Patterns that are used to canonicalize the use of unit-extent dims for /// broadcasting. -void mlir::linalg::populateFoldUnitExtentDimsPatterns( +void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns( RewritePatternSet &patterns) { auto *context = patterns.getContext(); - patterns.add, - RankReducedInsertSliceOp>( - context); + patterns.add(context, + RankReductionStrategy::ReassociativeReshape); + // TODO: Patterns unrelated to unit dim folding should be factored out. + patterns + .add, + RankReducedInsertSliceOp>(context); linalg::FillOp::getCanonicalizationPatterns(patterns, context); tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); tensor::EmptyOp::getCanonicalizationPatterns(patterns, context); @@ -673,6 +680,14 @@ memref::populateResolveShapedTypeResultDimsPatterns(patterns); } +void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns( + RewritePatternSet &patterns) { + auto *context = patterns.getContext(); + patterns.add(context, + RankReductionStrategy::ExtractInsertSlice); + patterns.add(context); +} + namespace { /// Pass that removes unit-extent dims within generic ops. struct LinalgFoldUnitExtentDimsPass @@ -681,10 +696,13 @@ Operation *op = getOperation(); MLIRContext *context = op->getContext(); RewritePatternSet patterns(context); - if (foldOneTripLoopsOnly) + if (foldOneTripLoopsOnly) { patterns.add(context); - else - populateFoldUnitExtentDimsPatterns(patterns); + } else if (useRankReducingSlices) { + populateFoldUnitExtentDimsViaSlicesPatterns(patterns); + } else { + populateFoldUnitExtentDimsViaReshapesPatterns(patterns); + } (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -109,6 +109,21 @@ return NoneType::get(type.getContext()); } +SmallVector memref::getMixedSizes(OpBuilder &builder, + Location loc, Value value) { + auto memrefType = value.getType().cast(); + SmallVector result; + for (int64_t i = 0; i < memrefType.getRank(); ++i) { + if (memrefType.isDynamicDim(i)) { + Value size = builder.create(loc, value, i); + result.push_back(size); + } else { + result.push_back(builder.getIndexAttr(memrefType.getDimSize(i))); + } + } + return result; +} + //===----------------------------------------------------------------------===// // Utility functions for propagating static information //===----------------------------------------------------------------------===// @@ -2912,6 +2927,35 @@ mixedStrides); } +Value mlir::memref::createCanonicalRankReducingSubViewOp( + OpBuilder &b, Location loc, Value memref, ArrayRef targetShape) { + auto memrefType = memref.getType().cast(); + unsigned rank = memrefType.getRank(); + SmallVector offsets(rank, b.getIndexAttr(0)); + SmallVector sizes = getMixedSizes(b, loc, memref); + SmallVector strides(rank, b.getIndexAttr(1)); + auto targetType = SubViewOp::inferRankReducedResultType( + targetShape, memrefType, offsets, sizes, strides) + .cast(); + return b.createOrFold(loc, targetType, memref, offsets, + sizes, strides); +} + +FailureOr SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc, + Value value, + ArrayRef desiredShape) { + auto sourceMemrefType = value.getType().dyn_cast(); + assert(sourceMemrefType && "not a ranked memref type"); + auto sourceShape = sourceMemrefType.getShape(); + if (sourceShape.equals(desiredShape)) + return value; + auto maybeRankReductionMask = + mlir::computeRankReductionMask(sourceShape, desiredShape); + if (!maybeRankReductionMask) + return failure(); + return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape); +} + /// Helper method to check if a `subview` operation is trivially a no-op. This /// is the case if the all offsets are zero, all strides are 1, and the source /// shape is same as the size of the subview. In such cases, the subview can diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(linalg-fold-unit-extent-dims))" | FileCheck %s +// RUN: mlir-opt %s -linalg-fold-unit-extent-dims -split-input-file | FileCheck %s +// RUN: mlir-opt %s -linalg-fold-unit-extent-dims="use-rank-reducing-slices" -cse -split-input-file | FileCheck %s --check-prefix=CHECK-SLICES #accesses = [ affine_map<(i, j, k, l, m) -> (i, k, m)>, @@ -26,11 +27,57 @@ // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @drop_one_trip_loops // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2]] +// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] +// CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()> +// CHECK-SLICES-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-SLICES-LABEL: func @drop_one_trip_loops +// CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0] [%{{.*}}, 1, %{{.*}}] [1, 1, 1] : tensor to tensor +// CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1, 1] : tensor to tensor +// CHECK-SLICES: linalg.generic +// CHECK-SLICES-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]] +// CHECK-SLICES-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SLICES: tensor.insert_slice %{{.*}} into %{{.*}}[0, 0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1, 1] : tensor into tensor + + +// ----- + +#accesses = [ + affine_map<(i, j, k, l, m) -> (i, k, m)>, + affine_map<(i, j, k, l, m) -> ()>, + affine_map<(i, j, k, l, m) -> (i, k, j, l, m)> +] + +#trait = { + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"], + indexing_maps = #accesses, + library_call = "some_external_func" +} + +func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, %shape: tensor) -> tensor { + %0 = linalg.generic #trait + ins(%arg0, %arg1 : tensor<1x1x1xf32>, f32) + outs(%shape : tensor) { + ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) : + linalg.yield %arg3 : f32 + } -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, d0, 0)> +// CHECK-LABEL: func @drop_one_trip_loops_all_ones +// CHECK: tensor.collapse_shape %{{.*}} [] +// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]] +// CHECK-SAME: iterator_types = ["parallel"] +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] + // ----- #accesses = [ @@ -871,3 +918,8 @@ // CHECK: memref.collapse_shape // CHECK-SAME: [] : memref<1x1xf32, 3> into memref // CHECK: linalg.generic{{.*}}memref + +// CHECK-SLICES-LABEL: func @drop_all_loops +// CHECK-SLICES: memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref, 3> +// CHECK-SLICES: linalg.generic{{.*}}memref, 3> +