diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -56,6 +56,25 @@
 /// Otherwise return nullptr.
 IntegerAttr getSmallestBoundingIndex(Value size);
 
+/// Create an ExtractSliceOp and, if `source` is defined by an ExtractSliceOp
+/// chain, fold it by adding the offsets and multiplying the strides.
+///
+/// Example:
+/// ```
+/// %0 = tensor.extract_slice %arg0[3, 0][3, 32][1, 2] : tensor<64x64xf32> to
+/// tensor<3x32xf32>
+/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 3] : tensor<3x32xf32> to
+/// tensor<3x4xf32>
+/// ```
+/// folds into:
+/// ```
+/// %1 = tensor.extract_slice %arg0[3, 0] [3, 4] [1, 6] : tensor<64x64xf32> to
+/// tensor<3x4xf32>
+/// ```
+tensor::ExtractSliceOp makeComposedExtractSliceOp(
+    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides);
+
 //===----------------------------------------------------------------------===//
 // Fusion utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "linalg-utils"
@@ -194,6 +195,58 @@
   return nullptr;
 }
 
+tensor::ExtractSliceOp makeComposedExtractSliceOp(
+    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+  // Expressions used to compose the folded offsets (dims are added) and the
+  // folded strides (symbols are multiplied).
+  AffineExpr dim1, dim2;
+  AffineExpr sym1, sym2;
+  bindDims(b.getContext(), dim1, dim2);
+  bindSymbols(b.getContext(), sym1, sym2);
+
+  // Helper function to convert fold results into values.
+  auto getOpFoldResultAsValue = [&](OpFoldResult ofr) -> Value {
+    if (auto result = ofr.dyn_cast<Value>())
+      return result;
+    assert(ofr.dyn_cast<Attribute>().isa<IntegerAttr>() &&
+           "expect only integer op fold results");
+    APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
+    return b.create<ConstantIndexOp>(loc, apInt.getSExtValue());
+  };
+
+  // Accumulate the `source` offsets and strides.
+  SmallVector<OpFoldResult> foldedOffsets(offsets.begin(), offsets.end());
+  SmallVector<OpFoldResult> foldedStrides(strides.begin(), strides.end());
+  assert(source && "expect source to be nonzero");
+  while (auto producerOp = source.getDefiningOp<tensor::ExtractSliceOp>()) {
+    // Ensure the producer is not rank reducing.
+    if (producerOp.getSourceType().getRank() != offsets.size())
+      break;
+    // Add the producer offsets.
+    for (auto en : enumerate(producerOp.getMixedOffsets())) {
+      SmallVector<Value> offsetValues = {
+          getOpFoldResultAsValue(en.value()),
+          getOpFoldResultAsValue(foldedOffsets[en.index()])};
+      foldedOffsets[en.index()] =
+          makeComposedAffineApply(b, loc, dim1 + dim2, offsetValues)
+              .getResult();
+    }
+    // Multiply the producer strides.
+    for (auto en : enumerate(producerOp.getMixedStrides())) {
+      SmallVector<Value> strideValues = {
+          getOpFoldResultAsValue(en.value()),
+          getOpFoldResultAsValue(foldedStrides[en.index()])};
+      foldedStrides[en.index()] =
+          makeComposedAffineApply(b, loc, sym1 * sym2, strideValues)
+              .getResult();
+    }
+    source = producerOp.source();
+  }
+
+  // Create the ExtractSliceOp using the folded offsets and strides.
+  return b.create<tensor::ExtractSliceOp>(loc, source, foldedOffsets, sizes,
+                                          foldedStrides);
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
@@ -588,15 +641,18 @@
     strides.push_back(builder.getIndexAttr(1));
   }
 
-  Operation *sliceOp = shapedType.isa<MemRefType>()
-                           ?
-                               builder
-                                   .create<memref::SubViewOp>(
-                                       loc, valueToTile, offsets, sizes, strides)
-                                   .getOperation()
-                           : builder
-                                 .create<tensor::ExtractSliceOp>(
-                                     loc, valueToTile, offsets, sizes, strides)
-                                 .getOperation();
+  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
+                      .Case([&](MemRefType) {
+                        return builder.create<memref::SubViewOp>(
+                            loc, valueToTile, offsets, sizes, strides);
+                      })
+                      .Case([&](RankedTensorType) {
+                        return makeComposedExtractSliceOp(
+                            builder, loc, valueToTile, offsets, sizes, strides);
+                      })
+                      .Default([](ShapedType) -> Operation * {
+                        llvm_unreachable("Unexpected shaped type");
+                      });
   return sliceOp->getResult(0);
 }
 
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -130,3 +130,49 @@
 // TLOOP-SAME: ins (%{{.*}} = %[[ARG_0]]: [[TY]], %{{.*}} = %[[ARG_1]]: [[TY]])
 // TLOOP-SAME: outs (%{{.*}} = %[[INIT]]: [[TY]])
 // TLOOP-SAME: distribution["block_x", "block_y", "none"] {
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 + 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 + 4)>
+
+// CHECK: fold_extract_slice
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<?x42xf32>
+func @fold_extract_slice(
+  %arg0 : tensor<?x?xf32>, %arg1 : tensor<?x42xf32>, %arg2 : tensor<?x42x?xf32>) -> tensor<?x42xf32> {
+
+  // CHECK: %[[C0:.*]] = constant 0
+  %c0 = constant 0 : index
+
+  // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+  %0 = tensor.dim %arg1, %c0 : tensor<?x42xf32>
+  %1 = tensor.extract_slice %arg0[3, 4] [%0, 42] [%0, 2] : tensor<?x?xf32> to tensor<?x42xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+
+  // Fold the existing extract slice op into the one created by the tiling.
+  // CHECK: %[[SIZE0:.*]] = affine.min #[[MAP0]](%[[IV0]])[%[[DIM]]
+  // CHECK: %[[OFF0:.*]] = affine.apply #[[MAP1]](%[[IV0]]
+  // CHECK: %[[OFF1:.*]] = affine.apply #[[MAP2]](%[[IV1]]
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME: %[[OFF0]], %[[OFF1]]
+  // CHECK-SAME: %[[SIZE0]], 3
+  // CHECK-SAME: %[[DIM]], 2
+  // CHECK: {{.*}} = linalg.generic {{.*}} ins(%[[T0]]
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%1, %arg2 : tensor<?x42xf32>, tensor<?x42x?xf32>)
+    outs(%arg1 : tensor<?x42xf32>) {
+    ^bb0(%arg3 : f32, %arg4: f32, %arg5: f32):
+      %5 = addf %arg3, %arg5 : f32
+      linalg.yield %5 : f32
+  } -> tensor<?x42xf32>
+  return %2 : tensor<?x42xf32>
+}