diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -263,6 +263,9 @@
     SmallVector<OpFoldResult> getMixedHighPad() {
       return getMixedPadImpl(static_high(), high());
     }
+
+    // Return the pad value if it is a constant, or a null Value otherwise.
+    Value getConstantPaddingValue();
   }];

   let builders = [
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -104,8 +104,8 @@
 /// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
 /// before the consumer Linalg op, until enough canonicalizations have applied.
 struct FusionInfo {
-  LinalgOp originalProducer;
-  LinalgOp fusedProducer;
+  Operation *originalProducer;
+  Operation *fusedProducer;
 };

 /// Fuses producer into consumer if the producer is structurally feasible and
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1011,7 +1011,7 @@
                         ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   unsigned rank = sourceType.getRank();
-  SmallVector<int64_t, 4> staticVector(ShapedType::kDynamicSize, rank);
+  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
   build(b, result, source, staticVector, staticVector, low, high, attrs);
 }

@@ -1109,6 +1109,18 @@
   return success();
 }

+Value PadTensorOp::getConstantPaddingValue() {
+  auto yieldOp = dyn_cast<linalg::YieldOp>(region().front().getTerminator());
+  if (!yieldOp || yieldOp.values().size() != 1)
+    return {};
+
+  Value padValue = yieldOp.values().front();
+  if (matchPattern(padValue, m_Constant()))
+    return padValue;
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -185,8 +186,8 @@
     tileSizes.push_back(it->second.size);
     sizeBounds.push_back(nullptr);
     loopRanges.push_back(it->second);
-    LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
-                            << loopRanges.back() << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "tiled loop#" << i << loopRanges.back() << "\n");
   } else {
     auto shapeDim = getShapeDefiningLoopRange(producer, i);
     Value dim = b.createOrFold<memref::DimOp>(loc, shapeDim.shape,
@@ -194,8 +195,8 @@
     tileSizes.push_back(zero);
     sizeBounds.push_back(dim);
     loopRanges.push_back(Range{zero, dim, one});
-    LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
-                            << loopRanges.back() << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "full loop#" << i << loopRanges.back() << "\n");
   }
 }

@@ -290,6 +291,166 @@
   return fuse(b, producerOp, fusedLoopsAndRanges);
 }

+/// Fuses the linalg.pad_tensor `producerOp` into the loops immediately
+/// enclosing the consumer that uses `producerOp`'s result via
+/// `consumerOpOperand`. The fusion is done by creating a subtensor of the
+/// source tensor and creating a new linalg.pad_tensor for it.
+static PadTensorOp fusePadTensor(OpBuilder &builder, PadTensorOp producerOp,
+                                 OpOperand &consumerOpOperand) {
+  Value cstPadValue = producerOp.getConstantPaddingValue();
+  // Only support constant padding values for now.
+  if (!cstPadValue)
+    return {};
+
+  MLIRContext *context = builder.getContext();
+  Location loc = consumerOpOperand.getOwner()->getLoc();
+  Value shapedOperand = consumerOpOperand.get();
+  int64_t numLoops = producerOp.getResultType().getRank();
+
+  // linalg.pad_tensor cannot expand or shrink the number of dimensions, so we
+  // can get all the loops' ranges from its consumer.
+  SmallVector<Range, 4> loopRanges;
+  for (int i = 0; i < numLoops; ++i) {
+    Range range = getRangeFromOperandShape(builder, loc, shapedOperand, i);
+    LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << range << "\n");
+    loopRanges.push_back(range);
+
+    // Only support stride 1 right now.
+    IntegerAttr one;
+    if (!matchPattern(range.stride, m_Constant(&one)) ||
+        one.getValue().getZExtValue() != 1)
+      return {};
+  }
+
+  auto getValuePaddings = [&](ArrayRef<OpFoldResult> paddings) {
+    SmallVector<Value, 4> valuePaddings;
+    for (OpFoldResult pad : paddings) {
+      if (Value value = pad.dyn_cast<Value>()) {
+        valuePaddings.push_back(value);
+      } else {
+        auto cst = pad.get<Attribute>().cast<IntegerAttr>().getInt();
+        valuePaddings.push_back(builder.create<ConstantIndexOp>(loc, cst));
+      }
+    }
+    return valuePaddings;
+  };
+
+  SmallVector<Value, 4> lowPaddings =
+      getValuePaddings(producerOp.getMixedLowPad());
+  SmallVector<Value, 4> highPaddings =
+      getValuePaddings(producerOp.getMixedHighPad());
+
+  AffineExpr dim0, dim1, dim2;
+  bindDims(context, dim0, dim1, dim2);
+  auto idMap = AffineMap::getMultiDimIdentityMap(2, context);
+  auto add2Map = AffineMap::get(3, 0, {dim0 + dim1 + dim2}, context);
+  auto sub1Map = AffineMap::get(2, 0, {dim0 - dim1}, context);
+  auto sub2Map = AffineMap::get(3, 0, {dim0 - dim1 - dim2}, context);
+
+  SmallVector<Value, 4> newSrcOffsets(numLoops);
+  SmallVector<Value, 4> newSrcSizes(numLoops);
+  SmallVector<Value, 4> newSrcStrides(numLoops);
+  SmallVector<Value, 4> newLowPaddings(numLoops);
+  SmallVector<Value, 4> newHighPaddings(numLoops);
+
+  for (int i = 0; i < numLoops; ++i) {
+    const auto &range = loopRanges[i];
+
+    // Get the original and padded dimension size.
+    Value originalDim =
+        builder.create<memref::DimOp>(loc, producerOp.source(), i);
+    Value paddedDim = builder.create<AffineApplyOp>(
+        loc, add2Map, ValueRange{originalDim, lowPaddings[i], highPaddings[i]});
+
+    // Calculate the low padding amount that is really used in this tile. We
+    // first calculate how many elements aren't used, then subtract that from
+    // the total low padding amount. Affine min ops are needed to keep the
+    // results in bounds.
+    //
+    // For the middle tile out of three tiles, there are three cases (where `L`
+    // marks the first element after the low padding, i.e., index
+    // lowPaddings[i]):
+    //
+    // ||---L----||--------||--------||
+    //   unusedLowPadding = lowPaddings[i]
+    //   potentialUsedLowPadding = 0
+    //   usedLowPadding = 0
+    //   newSrcOffset = range.offset - lowPaddings[i]
+    //
+    // ||--------||---L----||--------||
+    //   unusedLowPadding = range.offset
+    //   potentialUsedLowPadding = lowPaddings[i] - range.offset
+    //   usedLowPadding = lowPaddings[i] - range.offset
+    //   newSrcOffset = 0
+    //
+    // ||--------||--------||---L----||
+    //   unusedLowPadding = range.offset
+    //   potentialUsedLowPadding = lowPaddings[i] - range.offset
+    //   usedLowPadding = range.size
+    //   newSrcOffset = 0
+    auto unusedLowPadding = builder.create<AffineMinOp>(
+        loc, idMap, ValueRange{lowPaddings[i], range.offset});
+    auto potentialUsedLowPadding = builder.create<AffineApplyOp>(
+        loc, sub1Map, ValueRange{lowPaddings[i], unusedLowPadding});
+    auto usedLowPadding = builder.create<AffineMinOp>(
+        loc, idMap, ValueRange{range.size, potentialUsedLowPadding});
+
+    // Similarly for the high padding. We need to use the remaining elements to
+    // compare with the total high padding amount to deduce the unused amount
+    // here.
+    //
+    // For the middle tile out of three tiles, there are three cases (where `H`
+    // marks the first element of the high padding, i.e., index
+    // paddedDim - highPaddings[i]):
+    //
+    // ||--------||--------||---H----||
+    //   unusedHighPadding = highPaddings[i]
+    //   potentialUsedHighPadding = 0
+    //   usedHighPadding = 0
+    //
+    // ||--------||---H----||--------||
+    //   unusedHighPadding = paddedDim - range.offset - range.size
+    //   potentialUsedHighPadding = highPaddings[i] - unusedHighPadding
+    //   usedHighPadding = potentialUsedHighPadding
+    //
+    // ||---H----||--------||--------||
+    //   unusedHighPadding = paddedDim - range.offset - range.size
+    //   potentialUsedHighPadding = highPaddings[i] - unusedHighPadding
+    //   usedHighPadding = range.size
+    auto remainingElements = builder.create<AffineApplyOp>(
+        loc, sub2Map, ValueRange{paddedDim, range.offset, range.size});
+    auto unusedHighPadding = builder.create<AffineMinOp>(
+        loc, idMap, ValueRange{highPaddings[i], remainingElements});
+    auto potentialUsedHighPadding = builder.create<AffineApplyOp>(
+        loc, sub1Map, ValueRange{highPaddings[i], unusedHighPadding});
+    auto usedHighPadding = builder.create<AffineMinOp>(
+        loc, idMap, ValueRange{range.size, potentialUsedHighPadding});
+
+    newSrcOffsets[i] = builder.create<AffineApplyOp>(
+        loc, sub1Map, ValueRange{range.offset, unusedLowPadding});
+    newSrcSizes[i] = builder.create<AffineApplyOp>(
+        loc, sub2Map, ValueRange{range.size, usedLowPadding, usedHighPadding});
+    newSrcStrides[i] = range.stride;
+
+    newLowPaddings[i] = usedLowPadding;
+    newHighPaddings[i] = usedHighPadding;
+  }
+
+  // Create the subtensor for the source tensor and pad it.
+  auto srcSubTensor = builder.create<SubTensorOp>(
+      loc, producerOp.source(), newSrcOffsets, newSrcSizes, newSrcStrides);
+  auto padSubTensor = builder.create<PadTensorOp>(
+      loc, srcSubTensor, newLowPaddings, newHighPaddings);
+
+  // Create the region for the new linalg.pad_tensor op.
+  OpBuilder::InsertionGuard guard(builder);
+  auto &region = padSubTensor.region();
+  SmallVector<Type, 4> blockArgTypes;
+  blockArgTypes.assign(numLoops, builder.getIndexType());
+  builder.createBlock(&region, region.end(), blockArgTypes);
+  builder.create<linalg::YieldOp>(loc, cstPadValue);
+
+  return padSubTensor;
+}
+
 // Encode structural fusion safety preconditions.
 // Some of these will be lifted in the future with better analysis.
 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
@@ -483,11 +644,15 @@
     return;
   while (true) {
-    LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
+    LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor << "\n");
     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
       opResult = tensor.cast<OpResult>();
       return;
     }
+    if (auto padOp = tensor.getDefiningOp<PadTensorOp>()) {
+      opResult = tensor.cast<OpResult>();
+      return;
+    }
     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
       tensor = subTensorOp.source();
       continue;
     }
@@ -508,7 +673,7 @@
   OpResult producerOpResult;
   getProducerOfTensor(inputTensor, producerOpResult);
   if (!producerOpResult) {
-    LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
+    LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer\n");
     return {};
   }
   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
@@ -521,13 +686,11 @@
   if (isa(producerOpResult.getOwner()))
     return llvm::None;

-  auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
-  if (!producerOp)
-    return llvm::None;
-
   LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
-  if (!consumerOp)
+  if (!consumerOp) {
+    LLVM_DEBUG(llvm::dbgs() << "cannot fuse: consumer is not a linalg op\n");
     return llvm::None;
+  }

   Value inputTensor = consumerOpOperand.get();
@@ -548,10 +711,24 @@
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(consumerOp);
   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
-  LinalgOp fusedProducer =
-      fuse(b, producerOp,
-           producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
-           consumerOpOperand);
+
+  Operation *fusedProducer = nullptr;
+  if (auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner())) {
+    fusedProducer = fuse(
+        b, producerOp,
+        producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
+        consumerOpOperand);
+  } else if (auto producerOp =
+                 dyn_cast<PadTensorOp>(producerOpResult.getOwner())) {
+    fusedProducer = fusePadTensor(b, producerOp, consumerOpOperand);
+    if (!fusedProducer) {
+      LLVM_DEBUG(llvm::dbgs() << "failed to fuse linalg.pad_tensor\n");
+      return llvm::None;
+    }
+  } else {
+    LLVM_DEBUG(llvm::dbgs() << "cannot fuse: producer is not a linalg op\n");
+    return llvm::None;
+  }

   // Replace use.
   // Canonicalizations are not guaranteed to have happened before constructing
@@ -561,9 +738,9 @@
   Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
   Type consumerType = consumerOpOperand.get().getType();
   if (consumerType != def.getType())
-    def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
+    def = b.create<tensor::CastOp>(fusedProducer->getLoc(), consumerType, def);
   consumerOpOperand.set(def);
-  return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
+  return FusionInfo{producerOpResult.getOwner(), fusedProducer};
 }

 /// Prune all dimensions that are of reduction iterator type from `map`.
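For reference, the per-tile arithmetic that `fusePadTensor` materializes as `affine.min`/`affine.apply` ops can be written down with plain integers. The sketch below is not part of the patch; it is a minimal standalone C++ model of the same formulas (the names `TileSlice`, `computeTileSlice`, and the `main` driver are illustrative only), and its driver walks dimension 0 of the `@pad_generic_static` test added below.

// Standalone model of the per-tile padding arithmetic emitted by
// fusePadTensor. Illustrative only; not part of the patch.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>

struct TileSlice {
  int64_t srcOffset; // offset into the unpadded source tensor
  int64_t srcSize;   // number of source elements covered by this tile
  int64_t lowPad;    // low padding used inside this tile
  int64_t highPad;   // high padding used inside this tile
};

// srcDim: unpadded size; lowPad/highPad: total padding amounts;
// tileOffset/tileSize: the tile in the *padded* domain (stride 1).
TileSlice computeTileSlice(int64_t srcDim, int64_t lowPad, int64_t highPad,
                           int64_t tileOffset, int64_t tileSize) {
  int64_t paddedDim = srcDim + lowPad + highPad;
  // Low padding: the part of the low pad that lies before this tile is unused.
  int64_t unusedLow = std::min(lowPad, tileOffset);
  int64_t usedLow = std::min(tileSize, lowPad - unusedLow);
  // High padding: the part of the high pad that lies after this tile is unused.
  int64_t remaining = paddedDim - tileOffset - tileSize;
  int64_t unusedHigh = std::min(highPad, remaining);
  int64_t usedHigh = std::min(tileSize, highPad - unusedHigh);
  return {tileOffset - unusedLow, tileSize - usedLow - usedHigh, usedLow,
          usedHigh};
}

int main() {
  // Dimension 0 of @pad_generic_static below: source 58, low 4, high 2,
  // i.e. a padded size of 64, tiled with size 16.
  for (int64_t iv0 : {0, 16, 32, 48}) {
    TileSlice s = computeTileSlice(/*srcDim=*/58, /*lowPad=*/4, /*highPad=*/2,
                                   /*tileOffset=*/iv0, /*tileSize=*/16);
    // Each tile reconstructs exactly 16 padded elements.
    assert(s.lowPad + s.srcSize + s.highPad == 16);
    std::cout << "iv0=" << iv0 << " srcOffset=" << s.srcOffset
              << " srcSize=" << s.srcSize << " low=" << s.lowPad
              << " high=" << s.highPad << "\n";
  }
  return 0;
}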
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir @@ -277,3 +277,205 @@ // CHECK-SAME: outs(%[[ST_ARG]] : tensor) // CHECK: subtensor_insert %[[ST_ADD]] into %[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] // CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]] + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> + +func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tensor<64x128xf32>) -> tensor<64x128xf32> { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c16 = constant 16 : index + %c32 = constant 32 : index + %zero = constant 0.0 : f32 + + %d0 = memref.dim %large_input, %c0 : tensor<64x128xf32> + %d1 = memref.dim %large_input, %c1 : tensor<64x128xf32> + + %pad = linalg.pad_tensor %small_input low[4, 60] high[2, 67] { + ^bb0(%arg0: index, %arg1: index): + linalg.yield %zero : f32 + } : tensor<58x1xf32> to tensor<64x128xf32> + + %fill = linalg.fill(%large_input, %zero) : tensor<64x128xf32>, f32 -> tensor<64x128xf32> + + %for0 = scf.for %iv0 = %c0 to %d0 step %c16 iter_args(%arg0 = %fill) -> tensor<64x128xf32> { + %for1 = scf.for %iv1 = %c0 to %d1 step %c32 iter_args(%arg1 = %arg0) -> tensor<64x128xf32> { + %0 = subtensor %pad[%iv0, %iv1][16, 32][1, 1] : tensor<64x128xf32> to tensor<16x32xf32> + %1 = subtensor %large_input[%iv0, %iv1][16, 32][1, 1] : tensor<64x128xf32> to tensor<16x32xf32> + %2 = subtensor %arg1[%iv0, %iv1][16, 32][1, 1] : tensor<64x128xf32> to tensor<16x32xf32> + + %add = linalg.generic + {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} + ins(%0, %1 : tensor<16x32xf32>, tensor<16x32xf32>) outs(%2 : tensor<16x32xf32>) { + ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): + %result = addf %arg4, %arg5 : f32 + linalg.yield %result : f32 + } -> tensor<16x32xf32> + + %insert = subtensor_insert %add into %arg1[%iv0, %iv1] [16, 32] [1, 1] : tensor<16x32xf32> into tensor<64x128xf32> + scf.yield %insert : tensor<64x128xf32> + } + scf.yield %for1 : tensor<64x128xf32> + } + return %for0 : tensor<64x128xf32> +} + +// CHECK-DAG: #[[DIM0_LOW_UNUSED_MAP:.+]] = affine_map<(d0) -> (4, d0)> +// CHECK-DAG: #[[DIM0_LOW_USED_MAP:.+]] = affine_map<(d0) -> (16, -d0 + 4)> +// CHECK-DAG: #[[DIM0_HIGH_UNUSED_MAP:.+]] = affine_map<(d0) -> (2, -d0 + 48)> +// CHECK-DAG: #[[DIM0_HIGH_USED_MAP:.+]] = affine_map<(d0) -> (16, -d0 + 2)> +// CHECK-DAG: #[[SUB_MAP:.+]] = affine_map<(d0, d1) -> (d0 - d1)> +// CHECK-DAG: #[[DIM0_SIZE_MAP:.+]] = affine_map<(d0, d1) -> (-d0 - d1 + 16)> +// CHECK-DAG: #[[DIM1_LOW_UNUSED_MAP:.+]] = affine_map<(d0) -> (60, d0)> +// CHECK-DAG: #[[DIM1_LOW_USED_MAP:.+]] = affine_map<(d0) -> (32, -d0 + 60)> +// CHECK-DAG: #[[DIM1_HIGH_UNUSED_MAP:.+]] = affine_map<(d0) -> (67, -d0 + 96)> +// CHECK-DAG: #[[DIM1_HIGH_USED_MAP:.+]] = affine_map<(d0) -> (32, -d0 + 67)> +// CHECK-DAG: #[[DIM1_SIZE_MAP:.+]] = affine_map<(d0, d1) -> (-d0 - d1 + 32)> + +// CHECK: func @pad_generic_static +// CHECK-SAME: %[[SMALL_INPUT:.+]]: tensor<58x1xf32> + +// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 +// CHECK: scf.for %[[IV0:[a-z0-9]+]] +// CHECK: %[[DIM0_UNUSED_LOW_PAD:.+]] = affine.min #[[DIM0_LOW_UNUSED_MAP]](%[[IV0]]) +// CHECK: %[[DIM0_USED_LOW_PAD:.+]] = affine.min #[[DIM0_LOW_USED_MAP]](%[[DIM0_UNUSED_LOW_PAD]]) +// CHECK: %[[DIM0_UNUSED_HIGH_PAD:.+]] = affine.min #[[DIM0_HIGH_UNUSED_MAP]](%[[IV0]]) +// CHECK: 
%[[DIM0_USED_HIGH_PAD:.+]] = affine.min #[[DIM0_HIGH_USED_MAP]](%[[DIM0_UNUSED_HIGH_PAD]]) +// CHECK: %[[DIM0_SRC_OFFSET:.+]] = affine.apply #[[SUB_MAP]](%[[IV0]], %[[DIM0_UNUSED_LOW_PAD]]) +// CHECK: %[[DIM0_SRC_SIZE:.+]] = affine.apply #[[DIM0_SIZE_MAP]](%[[DIM0_USED_LOW_PAD]], %[[DIM0_USED_HIGH_PAD]]) +// CHECK: scf.for %[[IV1:[a-z0-9]+]] +// CHECK: %[[DIM1_UNUSED_LOW_PAD:.+]] = affine.min #[[DIM1_LOW_UNUSED_MAP]](%[[IV1]]) +// CHECK: %[[DIM1_USED_LOW_PAD:.+]] = affine.min #[[DIM1_LOW_USED_MAP]](%[[DIM1_UNUSED_LOW_PAD]]) +// CHECK: %[[DIM1_UNUSED_HIGH_PAD:.+]] = affine.min #[[DIM1_HIGH_UNUSED_MAP]](%[[IV1]]) +// CHECK: %[[DIM1_USED_HIGH_PAD:.+]] = affine.min #[[DIM1_HIGH_USED_MAP]](%[[DIM1_UNUSED_HIGH_PAD]]) +// CHECK: %[[DIM1_SRC_OFFSET:.+]] = affine.apply #[[SUB_MAP]](%[[IV1]], %[[DIM1_UNUSED_LOW_PAD]]) +// CHECK: %[[DIM1_SRC_SIZE:.+]] = affine.apply #[[DIM1_SIZE_MAP]](%[[DIM1_USED_LOW_PAD]], %[[DIM1_USED_HIGH_PAD]]) +// CHECK: %[[SRC_SUBTENSOR:.+]] = subtensor %[[SMALL_INPUT]] +// CHECK-SAME: [%[[DIM0_SRC_OFFSET]], %[[DIM1_SRC_OFFSET]]] +// CHECK-SAME: [%[[DIM0_SRC_SIZE]], %[[DIM1_SRC_SIZE]]] +// CHECK: %[[PAD:.+]] = linalg.pad_tensor %[[SRC_SUBTENSOR]] +// CHECK-SAME: low[%[[DIM0_USED_LOW_PAD]], %[[DIM1_USED_LOW_PAD]]] +// CHECK-SAME: high[%[[DIM0_USED_HIGH_PAD]], %[[DIM1_USED_HIGH_PAD]]] +// CHECK: linalg.yield %[[ZERO]] +// CHECK: %[[CAST:.+]] = tensor.cast %[[PAD]] : tensor to tensor<16x32xf32> +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[CAST]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +func @pad_generic_dynamic(%small_input: tensor, %large_input: tensor, %input_low_pad: index, %input_high_pad: index) -> tensor { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c8 = constant 8 : index + %zero = constant 0.0 : f32 + + %d0 = memref.dim %large_input, %c0 : tensor // not to pad, to tile + %d1 = memref.dim %large_input, %c1 : tensor // to pad, to tile + %d2 = memref.dim %large_input, %c2 : tensor // to pad, not to tile + %d3 = memref.dim %large_input, %c3 : tensor // not to pad, not to tile + + %pad = linalg.pad_tensor %small_input low[%c0, %input_low_pad, %input_low_pad, %c0] high[%c0, %input_high_pad, %input_high_pad, %c0] { + ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index): + linalg.yield %zero : f32 + } : tensor to tensor + + %fill = linalg.fill(%large_input, %zero) : tensor, f32 -> tensor + + %for0 = scf.for %iv0 = %c0 to %d0 step %c4 iter_args(%arg0 = %fill) -> tensor { + %for1 = scf.for %iv1 = %c0 to %d1 step %c8 iter_args(%arg1 = %arg0) -> tensor { + %d0_size = affine.min affine_map<(d0)[s0] -> (4, -d0 + s0)>(%iv0)[%d0] + %d1_size = affine.min affine_map<(d0)[s0] -> (8, -d0 + s0)>(%iv1)[%d1] + %0 = subtensor %pad[%iv0, %iv1, 0, 0][%d0_size, %d1_size, %d2, %d3][1, 1, 1, 1] : tensor to tensor + %1 = subtensor %large_input[%iv0, %iv1, 0, 0][%d0_size, %d1_size, %d2, %d3][1, 1, 1, 1] : tensor to tensor + %2 = subtensor %arg1[%iv0, %iv1, 0, 0][%d0_size, %d1_size, %d2, %d3][1, 1, 1, 1] : tensor to tensor + + %add = linalg.generic + {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%0, %1 : tensor, tensor) outs(%2 : tensor) { + ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): + %result = addf %arg4, %arg5 : f32 + linalg.yield %result : f32 + } -> tensor + + %insert = subtensor_insert %add into %arg1[%iv0, %iv1, 0, 0] [%d0_size, %d1_size, %d2, %d3] [1, 1, 1, 1] : tensor into tensor + scf.yield 
%insert : tensor + } + scf.yield %for1 : tensor + } + return %for0 : tensor +} + +// CHECK-DAG: #[[DIM2_UNUSED_LOW_MAP:.+]] = affine_map<()[s0] -> (s0, 0)> +// CHECK-DAG: #[[DIM2_USED_LOW_MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0, s1 - s2)> +// CHECK-DAG: #[[DIM2_UNUSED_HIGH_MAP:.+]] = affine_map<()[s0, s1, s2, s3] -> (s0, s0 - s1 + s2 + s3)> +// CHECK-DAG: #[[DIM2_OFFSET_MAP:.+]] = affine_map<()[s0] -> (-s0)> +// CHECK-DAG: #[[DIM3_SIZE_MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)> +// CHECK-DAG: #[[DIM3_UNUSED_HIGH_MAP:.+]] = affine_map<()[s0, s1] -> (0, -s0 + s1)> +// CHECK-DAG: #[[DIM3_USED_HIGH_MAP:.+]] = affine_map<()[s0, s1] -> (s0, -s1)> +// CHECK-DAG: #[[DIM0_UNUSED_LOW_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> +// CHECK-DAG: #[[DIM0_USED_LOW_MAP:.+]] = affine_map<(d0)[s0] -> (0, 4, -d0 + s0)> +// CHECK-DAG: #[[DIM0_UNUSED_HIGH_MAP:.+]] = affine_map<(d0, d1)[s0] -> (0, -d0 - d1 + s0)> +// CHECK-DAG: #[[DIM0_USED_HIGH_MAP:.+]] = affine_map<(d0, d1)[s0] -> (-d1, 4, -d0 + s0)> +// CHECK-DAG: #[[DIM0_SIZE_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 - d1 - d2)> +// CHECK-DAG: #[[DIM1_TILE_SIZE_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)> +// CHECK-DAG: #[[DIM1_UNUSED_LOW_MAP:.+]] = affine_map<(d0)[s0] -> (s0, d0)> +// CHECK-DAG: #[[DIM1_USED_MAP:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d1 + s1, 8, -d0 + s0)> +// CHECK-DAG: #[[DIM1_UNUSED_HIGH_MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (s0, -d0 - d1 + s0 + s1 + s2)> +// CHECK-DAG: #[[DIM1_OFFSET_MAP:.+]] = affine_map<(d0, d1) -> (d0 - d1)> + +// CHECK: func @pad_generic_dynamic +// CHECK-SAME: %[[SMALL_INPUT:.+]]: tensor, %[[LARGE_INPUT:.+]]: tensor +// CHECK-SAME: %[[INPUT_LOW_PAD:.+]]: index, %[[INPUT_HIGH_PAD:.+]]: index + +// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32 +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[C3:.+]] = constant 3 : index +// CHECK: %[[LARGE_DIM0:.+]] = memref.dim %[[LARGE_INPUT]], %[[C0]] +// CHECK: %[[LARGE_DIM1:.+]] = memref.dim %[[LARGE_INPUT]], %[[C1]] +// CHECK: %[[LARGE_DIM2:.+]] = memref.dim %[[LARGE_INPUT]], %[[C2]] +// CHECK: %[[LARGE_DIM3:.+]] = memref.dim %[[LARGE_INPUT]], %[[C3]] +// CHECK: %[[SMALL_DIM0:.+]] = memref.dim %[[SMALL_INPUT]], %[[C0]] +// CHECK: %[[SMALL_DIM1:.+]] = memref.dim %[[SMALL_INPUT]], %[[C1]] +// CHECK: %[[SMALL_DIM2:.+]] = memref.dim %[[SMALL_INPUT]], %[[C2]] +// CHECK: %[[DIM2_UNUSED_LOW_PAD:.+]] = affine.min #[[DIM2_UNUSED_LOW_MAP]]()[%[[INPUT_LOW_PAD]]] +// CHECK: %[[DIM2_USED_LOW_PAD:.+]] = affine.min #[[DIM2_USED_LOW_MAP]]()[%[[LARGE_DIM2]], %[[INPUT_LOW_PAD]], %[[DIM2_UNUSED_LOW_PAD]]] +// CHECK: %[[DIM2_UNUSED_HIGH_PAD:.+]] = affine.min #[[DIM2_UNUSED_HIGH_MAP]]()[%[[INPUT_HIGH_PAD]], %[[LARGE_DIM2]], %[[SMALL_DIM2]], %[[INPUT_LOW_PAD]]] +// CHECK: %[[DIM2_USED_HIGH_PAD:.+]] = affine.min #[[DIM2_USED_LOW_MAP]]()[%[[LARGE_DIM2]], %[[INPUT_HIGH_PAD]], %[[DIM2_UNUSED_HIGH_PAD]]] +// CHECK: %[[DIM2_SRC_OFFSET:.+]] = affine.apply #[[DIM2_OFFSET_MAP]]()[%[[DIM2_UNUSED_LOW_PAD]]] +// CHECK: %[[DIM2_SRC_SIZE:.+]] = affine.apply #[[DIM3_SIZE_MAP]]()[%[[LARGE_DIM2]], %[[DIM2_USED_LOW_PAD]], %[[DIM2_USED_HIGH_PAD]]] +// CHECK: %[[SMALL_DIM3:.+]] = memref.dim %[[SMALL_INPUT]], %[[C3]] +// CHECK: %[[DIM3_USED_LOW_PAD:.+]] = affine.min #[[DIM2_UNUSED_LOW_MAP]]()[%[[LARGE_DIM3]]] +// CHECK: %[[DIM3_UNUSED_HIGH_PAD:.+]] = affine.min #[[DIM3_UNUSED_HIGH_MAP]]()[%[[LARGE_DIM3]], %[[SMALL_DIM3]]] +// CHECK: %[[DIM3_USED_HIGH_PAD:.+]] = 
affine.min #[[DIM3_USED_HIGH_MAP]]()[%[[LARGE_DIM3]], %[[DIM3_UNUSED_HIGH_PAD]]] +// CHECK: %[[DIM3_SRC_SIZE:.+]] = affine.apply #[[DIM3_SIZE_MAP]]()[%[[LARGE_DIM3]], %[[DIM3_USED_LOW_PAD]], %[[DIM3_USED_HIGH_PAD]]] +// CHECK: scf.for %[[IV0:[a-z0-9]+]] +// CHECK: %[[DIM0_UNUSED_LOW_PAD:.+]] = affine.min #[[DIM0_UNUSED_LOW_MAP]](%[[IV0]])[%[[LARGE_DIM0]]] +// CHECK: %[[DIM0_USED_LOW_PAD:.+]] = affine.min #[[DIM0_USED_LOW_MAP]](%[[IV0]])[%[[LARGE_DIM0]]] +// CHECK: %[[DIM0_UNUSED_HIGH_PAD:.+]] = affine.min #[[DIM0_UNUSED_HIGH_MAP]](%[[IV0]], %[[DIM0_UNUSED_LOW_PAD]])[%[[SMALL_DIM0]]] +// CHECK: %[[DIM0_USED_HIGH_PAD:.+]] = affine.min #[[DIM0_USED_HIGH_MAP]](%[[IV0]], %[[DIM0_UNUSED_HIGH_PAD]])[%[[LARGE_DIM0]]] +// CHECK: %[[DIM0_SRC_SIZE:.+]] = affine.apply #[[DIM0_SIZE_MAP]](%[[DIM0_UNUSED_LOW_PAD]], %[[DIM0_USED_LOW_PAD]], %[[DIM0_USED_HIGH_PAD]]) +// CHECK: scf.for %[[IV1:[a-z0-9]+]] +// CHECK: %[[DIM1_TILE_SIZE:.+]] = affine.min #[[DIM1_TILE_SIZE_MAP]](%[[IV1]])[%[[LARGE_DIM1]]] +// CHECK: %[[DIM1_UNUSED_LOW_PAD:.+]] = affine.min #[[DIM1_UNUSED_LOW_MAP]](%[[IV1]])[%[[INPUT_LOW_PAD]]] +// CHECK: %[[DIM1_USED_LOW_PAD:.+]] = affine.min #[[DIM1_USED_MAP]](%[[IV1]], %[[DIM1_UNUSED_LOW_PAD]])[%[[LARGE_DIM1]], %[[INPUT_LOW_PAD]]] +// CHECK: %[[DIM1_UNUSED_HIGH_PAD:.+]] = affine.min #[[DIM1_UNUSED_HIGH_MAP]](%[[IV1]], %[[DIM1_TILE_SIZE]])[%[[INPUT_HIGH_PAD]], %[[SMALL_DIM1]], %[[INPUT_LOW_PAD]]] +// CHECK: %[[DIM1_USED_HIGH_PAD:.+]] = affine.min #[[DIM1_USED_MAP]](%[[IV1]], %[[DIM1_UNUSED_HIGH_PAD]])[%[[LARGE_DIM1]], %[[INPUT_HIGH_PAD]]] +// CHECK: %[[DIM1_SRC_OFFSET:.+]] = affine.apply #[[DIM1_OFFSET_MAP]](%[[IV1]], %[[DIM1_UNUSED_LOW_PAD]]) +// CHECK: %[[DIM1_SRC_SIZE:.+]] = affine.apply #[[DIM0_SIZE_MAP]](%[[DIM1_TILE_SIZE]], %[[DIM1_USED_LOW_PAD]], %[[DIM1_USED_HIGH_PAD]]) +// CHECK: %[[SRC_SUBTENSOR:.+]] = subtensor %[[SMALL_INPUT]] +// CHECK-SAME: [%[[IV0]], %[[DIM1_SRC_OFFSET]], %[[DIM2_SRC_OFFSET]], 0] +// CHECK-SAME: [%[[DIM0_SRC_SIZE]], %[[DIM1_SRC_SIZE]], %[[DIM2_SRC_SIZE]], %[[DIM3_SRC_SIZE]]] +// CHECK: %[[PAD:.+]] = linalg.pad_tensor %[[SRC_SUBTENSOR]] +// CHECK-SAME: low[%[[DIM0_USED_LOW_PAD]], %[[DIM1_USED_LOW_PAD]], %[[DIM2_USED_LOW_PAD]], %[[DIM3_USED_LOW_PAD]]] +// CHECK-SAME: high[%[[DIM0_USED_HIGH_PAD]], %[[DIM1_USED_HIGH_PAD]], %[[DIM2_USED_HIGH_PAD]], %[[DIM3_USED_HIGH_PAD]]] +// CHECK: linalg.yield %[[ZERO]] +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[PAD]] diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -155,11 +155,13 @@ linalg::Aliases aliases; linalg::LinalgDependenceGraph graph(aliases, linalgOps); if (auto info = fuseProducerOfBuffer(b, opOperand, graph)) { - auto *originalOp = info->originalProducer.getOperation(); + auto *originalOp = info->originalProducer; eraseSet.insert(originalOp); - auto *originalOpInLinalgOpsVector = - std::find(linalgOps.begin(), linalgOps.end(), originalOp); - *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); + if (auto linalgProducer = dyn_cast(originalOp)) { + auto *originalOpInLinalgOpsVector = + std::find(linalgOps.begin(), linalgOps.end(), originalOp); + *originalOpInLinalgOpsVector = info->fusedProducer; + } changed = true; } } else { @@ -168,10 +170,12 @@ if (opOperand.getOperandNumber() >= linalgOp.getNumInputs()) continue; if (auto info = fuseProducerOfTensor(b, opOperand)) { - auto 
*originalOp = info->originalProducer.getOperation(); - auto *originalOpInLinalgOpsVector = - std::find(linalgOps.begin(), linalgOps.end(), originalOp); - *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); + auto *originalOp = info->originalProducer; + if (auto linalgProducer = dyn_cast(originalOp)) { + auto *originalOpInLinalgOpsVector = + std::find(linalgOps.begin(), linalgOps.end(), originalOp); + *originalOpInLinalgOpsVector = info->fusedProducer; + } // Don't mark for erasure in the tensor case, let DCE handle this. changed = true; }
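As a cross-check on the dim0 FileCheck maps in the `@pad_generic_static` test above (`DIM0_LOW_UNUSED_MAP` through `DIM0_SIZE_MAP`), the sketch below evaluates those maps by hand for each tile offset and confirms that the resulting subtensor plus used paddings always rebuilds a full 16-row tile inside the 58-row source. It is a standalone, illustrative C++ snippet, not part of the patch.

// Hand-evaluation of the dim0 affine maps expected by @pad_generic_static.
// Illustrative only; not part of the patch.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  // dim0 of the static test: source 58, low pad 4, high pad 2, padded 64,
  // tile size 16, so the elements remaining after a tile number 48 - iv0.
  for (int64_t iv0 : {0, 16, 32, 48}) {
    int64_t unusedLow = std::min<int64_t>(4, iv0);             // (4, d0)
    int64_t usedLow = std::min<int64_t>(16, 4 - unusedLow);    // (16, -d0 + 4)
    int64_t unusedHigh = std::min<int64_t>(2, 48 - iv0);       // (2, -d0 + 48)
    int64_t usedHigh = std::min<int64_t>(16, 2 - unusedHigh);  // (16, -d0 + 2)
    int64_t srcOffset = iv0 - unusedLow;                       // (d0 - d1)
    int64_t srcSize = 16 - usedLow - usedHigh;                 // (-d0 - d1 + 16)
    // The subtensor plus the used paddings rebuilds a full 16-row tile and
    // stays inside the 58-row source.
    assert(usedLow + srcSize + usedHigh == 16);
    assert(srcOffset >= 0 && srcOffset + srcSize <= 58);
    std::cout << "iv0=" << iv0 << ": subtensor[" << srcOffset << ", "
              << srcOffset + srcSize << ") low=" << usedLow
              << " high=" << usedHigh << "\n";
  }
  return 0;
}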