diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -56,6 +56,25 @@
 /// Otherwise return nullptr.
 IntegerAttr getSmallestBoundingIndex(Value size);
 
+/// Create an ExtractSliceOp and, if `source` is defined by an ExtractSliceOp,
+/// fold it by adding the offsets.
+///
+/// Example:
+/// ```
+/// %0 = tensor.extract_slice %arg0[3, 4][3, 32][1, 1] : tensor<64x64xf32> to
+/// tensor<3x32xf32>
+/// %1 = tensor.extract_slice %0[0, 5][3, 4][1, 1] : tensor<3x32xf32> to
+/// tensor<3x4xf32>
+/// ```
+/// folds into:
+/// ```
+/// %1 = tensor.extract_slice %arg0[3, 9][3, 4][1, 1] : tensor<64x64xf32> to
+/// tensor<3x4xf32>
+/// ```
+tensor::ExtractSliceOp makeComposedExtractSliceOp(
+    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides);
+
 //===----------------------------------------------------------------------===//
 // Fusion utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
@@ -72,6 +72,12 @@
   }
 };
 
+/// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
+/// a Value or creates a ConstantIndexOp if it casts to an IntegerAttribute.
+/// Other attribute types are not supported.
+Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
+                                      OpFoldResult ofr);
+
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "linalg-utils"
@@ -194,6 +195,48 @@
   return nullptr;
 }
 
+tensor::ExtractSliceOp makeComposedExtractSliceOp(
+    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+  assert(source && "expect source to be nonzero");
+
+  // Do not fold if the producer is not an ExtractSliceOp.
+  auto producerOp = source.getDefiningOp<tensor::ExtractSliceOp>();
+  if (!producerOp)
+    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                            strides);
+
+  // Do not fold if the producer is rank reducing or if there are any non-unit
+  // strides. Supporting non-unit strides complicates the offset computation
+  // since the consumer offsets need to be multiplied by the producer strides.
+  // TODO: support non-unit strides once there are use cases.
+  SmallVector<OpFoldResult> allStrides = producerOp.getMixedStrides();
+  allStrides.append(strides.begin(), strides.end());
+  bool hasNonUnitStride = any_of(allStrides, [](OpFoldResult ofr) {
+    return getConstantIntValue(ofr) != static_cast<int64_t>(1);
+  });
+  if (hasNonUnitStride ||
+      producerOp.getSourceType().getRank() !=
+          producerOp.getResult().getType().cast<ShapedType>().getRank())
+    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                            strides);
+
+  // Fold the producer by adding the offsets and extracting the slice directly
+  // from the producer source tensor.
+  SmallVector<OpFoldResult> foldedOffsets(offsets.begin(), offsets.end());
+  AffineExpr dim1, dim2;
+  bindDims(b.getContext(), dim1, dim2);
+  for (auto en : enumerate(producerOp.getMixedOffsets())) {
+    SmallVector<Value> offsetValues = {
+        getValueOrCreateConstantIndexOp(b, loc, foldedOffsets[en.index()]),
+        getValueOrCreateConstantIndexOp(b, loc, en.value())};
+    foldedOffsets[en.index()] =
+        makeComposedAffineApply(b, loc, dim1 + dim2, offsetValues).getResult();
+  }
+  return b.create<tensor::ExtractSliceOp>(loc, producerOp.source(),
+                                          foldedOffsets, sizes, strides);
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
@@ -603,15 +646,18 @@
     strides.push_back(builder.getIndexAttr(1));
   }
 
-  Operation *sliceOp = shapedType.isa<MemRefType>()
-                           ? builder
-                                 .create<memref::SubViewOp>(
-                                     loc, valueToTile, offsets, sizes, strides)
-                                 .getOperation()
-                           : builder
-                                 .create<tensor::ExtractSliceOp>(
-                                     loc, valueToTile, offsets, sizes, strides)
-                                 .getOperation();
+  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
+                      .Case([&](MemRefType) {
+                        return builder.create<memref::SubViewOp>(
+                            loc, valueToTile, offsets, sizes, strides);
+                      })
+                      .Case([&](RankedTensorType) {
+                        return makeComposedExtractSliceOp(
+                            builder, loc, valueToTile, offsets, sizes, strides);
+                      })
+                      .Default([](ShapedType) -> Operation * {
+                        llvm_unreachable("Unexpected shaped type");
+                      });
   return sliceOp->getResult(0);
 }
 
diff --git a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
--- a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
@@ -49,6 +49,15 @@
   }
 }
 
+Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
+                                            OpFoldResult ofr) {
+  if (auto value = ofr.dyn_cast<Value>())
+    return value;
+  auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
+  assert(attr && "expect the op fold result casts to an integer attribute");
+  return b.create<ConstantIndexOp>(loc, attr.getValue().getSExtValue());
+}
+
 Value ArithBuilder::_and(Value lhs, Value rhs) {
   return b.create<AndOp>(loc, lhs, rhs);
 }
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -130,3 +130,49 @@
 // TLOOP-SAME: ins (%{{.*}} = %[[ARG_0]]: [[TY]], %{{.*}} = %[[ARG_1]]: [[TY]])
 // TLOOP-SAME: outs (%{{.*}} = %[[INIT]]: [[TY]])
 // TLOOP-SAME: distribution["block_x", "block_y", "none"] {
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 + 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 + 4)>
+
+// CHECK: fold_extract_slice
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor
+func @fold_extract_slice(
+  %arg0 : tensor<?x?xf32>, %arg1 : tensor<?x42xf32>, %arg2 : tensor<?x42x?xf32>) -> tensor<?x42xf32> {
+
+  // CHECK: %[[C0:.*]] = constant 0
+  %c0 = constant 0 : index
+
+  // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+  %0 = tensor.dim %arg1, %c0 : tensor<?x42xf32>
+  %1 = tensor.extract_slice %arg0[3, 4] [%0, 42] [1, 1] : tensor<?x?xf32> to tensor<?x42xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+
+  // Fold the existing extract slice op into the one created by the tiling.
+  // CHECK: %[[SIZE0:.*]] = affine.min #[[MAP0]](%[[IV0]])[%[[DIM]]
+  // CHECK: %[[OFF0:.*]] = affine.apply #[[MAP1]](%[[IV0]]
+  // CHECK: %[[OFF1:.*]] = affine.apply #[[MAP2]](%[[IV1]]
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME: %[[OFF0]], %[[OFF1]]
+  // CHECK-SAME: %[[SIZE0]], 3
+  // CHECK-SAME: 1, 1
+  // CHECK: {{.*}} = linalg.generic {{.*}} ins(%[[T0]]
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%1, %arg2 : tensor<?x42xf32>, tensor<?x42x?xf32>)
+    outs(%arg1 : tensor<?x42xf32>) {
+  ^bb0(%arg3 : f32, %arg4: f32, %arg5: f32):
+    %5 = addf %arg3, %arg5 : f32
+    linalg.yield %5 : f32
+  } -> tensor<?x42xf32>
+  return %2 : tensor<?x42xf32>
+}
+