diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -335,7 +335,9 @@
   AffineMap indexMap;
   SmallVector<ReassociationIndices> reassociation;
   SmallVector<int64_t> targetShape;
+  SmallVector<int64_t> collapsedDims;
 };
+
 static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
     MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
     llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
@@ -362,8 +364,10 @@
   };
 
   unsigned dim = 0;
-  while (dim < operandShape.size() && isUnitDim(dim))
+  while (dim < operandShape.size() && isUnitDim(dim)) {
+    info.collapsedDims.push_back(dim);
     reassociationGroup.push_back(dim++);
+  }
   while (dim < operandShape.size()) {
     assert(!isUnitDim(dim) && "expected non unit-extent");
     reassociationGroup.push_back(dim);
@@ -373,6 +377,7 @@
     ++dim;
     // Fold all following dimensions that are unit-extent.
     while (dim < operandShape.size() && isUnitDim(dim)) {
+      info.collapsedDims.push_back(dim);
       reassociationGroup.push_back(dim++);
     }
     info.reassociation.push_back(reassociationGroup);
@@ -384,6 +389,24 @@
   return info;
 }
 
+/// Cast the given shaped value (tensor or memref) to the specified shape.
+static Value castToShape(OpBuilder &b, Value shapedValue,
+                         ArrayRef<int64_t> shape) {
+  if (auto tensorType = dyn_cast<RankedTensorType>(shapedValue.getType())) {
+    auto targetType = RankedTensorType::get(shape, tensorType.getElementType(),
+                                            tensorType.getEncoding());
+    return b.create<tensor::CastOp>(shapedValue.getLoc(), targetType,
+                                    shapedValue);
+  }
+
+  auto memrefType = cast<MemRefType>(shapedValue.getType());
+  MemRefLayoutAttrInterface layout;
+  auto targetType = MemRefType::get(shape, memrefType.getElementType(), layout,
+                                    memrefType.getMemorySpace());
+  return b.create<memref::CastOp>(shapedValue.getLoc(), targetType,
+                                  shapedValue);
+}
+
 LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                                    const ControlDropUnitDims &options) {
   SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
@@ -451,7 +474,7 @@
   SmallVector<AffineMap> newIndexingMaps;
   SmallVector<SmallVector<ReassociationIndices>> reassociations;
   SmallVector<SmallVector<int64_t>> targetShapes;
-  SmallVector<bool> collapsed;
+  SmallVector<SmallVector<int64_t>> collapsedDims;
   auto hasCollapsibleType = [](OpOperand &operand) {
     Type operandType = operand.get().getType();
     if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
@@ -470,7 +493,7 @@
           dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
       newIndexingMaps.push_back(newIndexingMap);
       targetShapes.push_back(llvm::to_vector(shape));
-      collapsed.push_back(false);
+      collapsedDims.push_back({});
      reassociations.push_back({});
      continue;
    }
@@ -480,8 +503,7 @@
     reassociations.push_back(replacementInfo.reassociation);
     newIndexingMaps.push_back(replacementInfo.indexMap);
     targetShapes.push_back(replacementInfo.targetShape);
-    collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
-                          indexingMap.getNumResults()));
+    collapsedDims.push_back(replacementInfo.collapsedDims);
   }
 
   // Abort if the indexing maps of the result operation are not invertible
@@ -490,19 +512,40 @@
       !inversePermutation(concatAffineMaps(newIndexingMaps)))
     return failure();
 
-  Location loc = genericOp.getLoc();
   // 4. For each of the operands, collapse the operand to convert
   // from original shape to shape in the modified operation if needed,
   // either through use of reshapes or rank-reducing slices as
   // specified in `options`.
+  Location loc = genericOp.getLoc();
   SmallVector<Value> newOperands;
   for (OpOperand &opOperand : genericOp->getOpOperands()) {
     int64_t idx = opOperand.getOperandNumber();
-    if (!collapsed[idx]) {
+
+    // Check if there are any dims to collapse.
+    if (collapsedDims[idx].empty()) {
       newOperands.push_back(opOperand.get());
       continue;
     }
-    newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
+
+    // Look for operand dims that should be collapsed but have a dynamic size.
+    // Cast such dims to a static size of `1`.
+    SmallVector<int64_t> operandShape = llvm::to_vector(
+        cast<ShapedType>(opOperand.get().getType()).getShape());
+    bool castNeeded = false;
+    for (size_t i = 0, e = collapsedDims[idx].size(); i < e; ++i) {
+      int64_t dim = collapsedDims[idx][i];
+      if (operandShape[dim] != 1) {
+        assert(ShapedType::isDynamic(operandShape[dim]) &&
+               "expected dynamic dim size");
+        operandShape[dim] = 1;
+        castNeeded = true;
+      }
+    }
+
+    Value source = castNeeded
+                       ? castToShape(rewriter, opOperand.get(), operandShape)
+                       : opOperand.get();
+    newOperands.push_back(collapseValue(rewriter, loc, source,
                                         targetShapes[idx], reassociations[idx],
                                         options.rankReductionStrategy));
   }
@@ -527,13 +570,12 @@
   replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
 
   // 6. If any result type changes, insert a reshape/slice to convert from the
-  // original
-  // type to the new type.
+  // original type to the new type.
   SmallVector<Value> resultReplacements;
   for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
     unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
     Value origDest = genericOp.getDpsInitOperand(index)->get();
-    if (!collapsed[opOperandIndex]) {
+    if (collapsedDims[opOperandIndex].empty()) {
       resultReplacements.push_back(result);
       continue;
     }
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -922,3 +922,31 @@
 // CHECK-SLICES-LABEL: func @drop_all_loops
 // CHECK-SLICES: memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref<f32, strided<[]>, 3>
 // CHECK-SLICES: linalg.generic{{.*}}memref<f32, strided<[]>, 3>
+
+// -----
+
+#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+func.func @dynamic_dim_not_collapsible(%arg0: memref<?x28x4xf16>, %arg1: tensor<1x28xf16>) {
+  linalg.generic {indexing_maps = [#map4, #map5], iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg1 : tensor<1x28xf16>) outs(%arg0 : memref<?x28x4xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  }
+  return
+}
+
+// CHECK-LABEL: func @dynamic_dim_not_collapsible(
+// CHECK-SAME: %[[arg0:.*]]: memref<?x28x4xf16>, %[[arg1:.*]]: tensor<1x28xf16>
+// CHECK: %[[c1:.*]] = tensor.collapse_shape %[[arg1]] {{\[}}[0, 1]] : tensor<1x28xf16> into tensor<28xf16>
+// CHECK: %[[cast:.*]] = memref.cast %[[arg0]] : memref<?x28x4xf16> to memref<1x28x4xf16>
+// CHECK: %[[c2:.*]] = memref.collapse_shape %[[cast]] {{\[}}[0, 1], [2]] : memref<1x28x4xf16> into memref<28x4xf16>
+// CHECK: linalg.generic {{.*}} ins(%[[c1]] : tensor<28xf16>) outs(%[[c2]] : memref<28x4xf16>)
+
+// CHECK-SLICES-LABEL: func @dynamic_dim_not_collapsible(
+// CHECK-SLICES-SAME: %[[arg0:.*]]: memref<?x28x4xf16>, %[[arg1:.*]]: tensor<1x28xf16>
+// CHECK-SLICES: %[[c1:.*]] = tensor.extract_slice %[[arg1]][0, 0] [1, 28] [1, 1] : tensor<1x28xf16> to tensor<28xf16>
+// CHECK-SLICES: %[[cast:.*]] = memref.cast %[[arg0]] : memref<?x28x4xf16> to memref<1x28x4xf16>
+// CHECK-SLICES: %[[c2:.*]] = memref.subview %[[cast]][0, 0, 0] [1, 28, 4] [1, 1, 1] : memref<1x28x4xf16> to memref<28x4xf16, strided<[4, 1]>>
+// CHECK-SLICES: linalg.generic {{.*}} ins(%[[c1]] : tensor<28xf16>) outs(%[[c2]] : memref<28x4xf16, strided<[4, 1]>>)
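
Note (not part of the patch): a minimal sketch of how the updated entry point could be driven from a rewrite pattern. Only linalg::dropUnitDims and linalg::ControlDropUnitDims come from the diff above; the pattern name and surrounding boilerplate are illustrative assumptions.

// Illustrative only: wraps linalg::dropUnitDims in an OpRewritePattern so it
// can be run via applyPatternsAndFoldGreedily. With this patch, unit dims
// whose size is dynamic are first cast to a static size of 1 before being
// collapsed away.
#include <utility>

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct DropUnitDimsViaEntryPoint
    : public mlir::OpRewritePattern<mlir::linalg::GenericOp> {
  DropUnitDimsViaEntryPoint(mlir::MLIRContext *ctx,
                            mlir::linalg::ControlDropUnitDims options)
      : OpRewritePattern(ctx), options(std::move(options)) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::linalg::GenericOp genericOp,
                  mlir::PatternRewriter &rewriter) const override {
    // Delegates to the entry point modified in this patch.
    return mlir::linalg::dropUnitDims(rewriter, genericOp, options);
  }

  mlir::linalg::ControlDropUnitDims options;
};
} // namespace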