diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -79,10 +79,14 @@ bool operator!() const { return expr == nullptr; } - template bool isa() const; - template U dyn_cast() const; - template U dyn_cast_or_null() const; - template U cast() const; + template + bool isa() const; + template + U dyn_cast() const; + template + U dyn_cast_or_null() const; + template + U cast() const; MLIRContext *getContext() const; @@ -251,7 +255,8 @@ raw_ostream &operator<<(raw_ostream &os, AffineExpr expr); -template bool AffineExpr::isa() const { +template +bool AffineExpr::isa() const { if (std::is_same::value) return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP; if (std::is_same::value) @@ -261,15 +266,18 @@ if (std::is_same::value) return getKind() == AffineExprKind::Constant; } -template U AffineExpr::dyn_cast() const { +template +U AffineExpr::dyn_cast() const { if (isa()) return U(expr); return U(nullptr); } -template U AffineExpr::dyn_cast_or_null() const { +template +U AffineExpr::dyn_cast_or_null() const { return (!*this || !isa()) ? U(nullptr) : U(expr); } -template U AffineExpr::cast() const { +template +U AffineExpr::cast() const { assert(isa()); return U(expr); } @@ -282,28 +290,46 @@ unsigned numSymbols); namespace detail { -template void bindDims(MLIRContext *ctx) {} +template +void bindDims(MLIRContext *ctx) {} template -void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &... exprs) { +void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) { e = getAffineDimExpr(N, ctx); bindDims(ctx, exprs...); } + +template +void bindSymbols(MLIRContext *ctx) {} + +template +void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) { + e = getAffineSymbolExpr(N, ctx); + bindSymbols(ctx, exprs...); +} } // namespace detail /// Bind a list of AffineExpr references to DimExpr at positions: /// [0 .. sizeof...(exprs)] template -void bindDims(MLIRContext *ctx, AffineExprTy &... exprs) { +void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) { detail::bindDims<0>(ctx, exprs...); } +/// Bind a list of AffineExpr references to SymbolExpr at positions: +/// [0 .. sizeof...(exprs)] +template +void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) { + detail::bindSymbols<0>(ctx, exprs...); +} + } // namespace mlir namespace llvm { // AffineExpr hash just like pointers -template <> struct DenseMapInfo { +template <> +struct DenseMapInfo { static mlir::AffineExpr getEmptyKey() { auto pointer = llvm::DenseMapInfo::getEmptyKey(); return mlir::AffineExpr(static_cast(pointer)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -417,16 +417,28 @@ return success(); } -static Value buildLoopTripCount(OpBuilder &b, Operation *op) { - MLIRContext *ctx = op->getContext(); - AffineExpr lb, ub, step = getAffineSymbolExpr(0, ctx); +/// Return the number of iterations in the loop (ub - lb).ceilDiv(step). +static Value buildLoopTripCount(OpBuilder &b, scf::ForOp forOp) { + MLIRContext *ctx = forOp->getContext(); + AffineExpr lb, ub, step; bindDims(ctx, lb, ub); - scf::ForOp forOp = cast(op); + bindSymbols(ctx, step); return b.create( - op->getLoc(), AffineMap::get(2, 1, {(ub - lb).ceilDiv(step)}, ctx), + forOp->getLoc(), AffineMap::get(2, 1, {(ub - lb).ceilDiv(step)}, ctx), ValueRange{forOp.lowerBound(), forOp.upperBound(), forOp.step()}); } +/// Return the current iteration number in the loop (iv - lb).ceilDiv(step). +static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp forOp) { + MLIRContext *ctx = forOp->getContext(); + AffineExpr iv, lb, step; + bindDims(ctx, iv, lb); + bindSymbols(ctx, step); + return b.create( + forOp->getLoc(), AffineMap::get(2, 1, {(iv - lb).ceilDiv(step)}, ctx), + ValueRange{forOp.getInductionVar(), forOp.lowerBound(), forOp.step()}); +} + LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp, unsigned nLoops) { llvm::SetVector backwardSlice, packingLoops; @@ -455,8 +467,10 @@ llvm::append_range(packedShape, paddedTensorType.getShape()); auto packedTensorType = RankedTensorType::get(packedShape, paddedTensorType.getElementType()); - auto dynamicSizes = llvm::to_vector<4>(llvm::map_range( - packingLoops, [&](Operation *op) { return buildLoopTripCount(b, op); })); + auto dynamicSizes = + llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) { + return buildLoopTripCount(b, cast(op)); + })); Value packedTensor = b.create( loc, dynamicSizes, packedTensorType.getShape(), packedTensorType.getElementType()); @@ -469,8 +483,9 @@ // 2. Create a SubTensorInsert at the top of the stack. // 3. Iteratively pop and yield the result of the SubTensorInsertOp across // the cloned loops. - SmallVector clonedLoopIvs; + SmallVector clonedLoopIvs, leadingPackedTensorIndexings; clonedLoopIvs.reserve(nLoops); + leadingPackedTensorIndexings.reserve(nLoops); BlockAndValueMapping bvm; // Stack step 1. iteratively clone loops and push `packedTensor`. // Insert `padTensorOp` into the backwardSlice so we clone it too. @@ -492,13 +507,16 @@ assert(clonedForOp->getNumRegions() == 1); clonedLoopIvs.push_back(clonedForOp.getInductionVar()); b.setInsertionPointToStart(&clonedForOp->getRegion(0).front()); + leadingPackedTensorIndexings.push_back( + buildLoopIterationCount(b, clonedForOp)); bvm.map(forOp.getInductionVar(), clonedLoopIvs.back()); packedTensor = clonedForOp.getRegionIterArgs().front(); } // Stack step 2. create SubTensorInsertOp at the top of the stack. // offsets = [clonedLoopIvs, 0 .. 0]. - SmallVector offsets(clonedLoopIvs.begin(), clonedLoopIvs.end()); + SmallVector offsets(leadingPackedTensorIndexings.begin(), + leadingPackedTensorIndexings.end()); offsets.append(paddedRank, b.getIndexAttr(0)); // sizes = [1 .. 1, paddedShape]. SmallVector sizes(nLoops, b.getIndexAttr(1)); @@ -527,12 +545,12 @@ // Now the packed tensor is ready, replace the original padding op by a // 1x..x1 SubTensor [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1]. b.setInsertionPoint(padTensorOp); - SmallVector originalLoopIvs = - llvm::to_vector<4>(llvm::map_range(packingLoops, [](Operation *loop) { - return cast(loop).getInductionVar(); + SmallVector loopIterationCounts = + llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) { + return buildLoopIterationCount(b, cast(loop)); })); // offsets = [originalLoopIvs, 0 .. 0]. - offsets.assign(originalLoopIvs.begin(), originalLoopIvs.end()); + offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end()); offsets.append(paddedRank, b.getIndexAttr(0)); // sizes = [1 .. 1, paddedShape] (definedabove). // strides = [1 .. 1] (defined above) diff --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir --- a/mlir/test/Dialect/Linalg/hoist-padding.mlir +++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir @@ -6,7 +6,15 @@ #map3 = affine_map<(d0, d1) -> (2, d0 - d1)> #map4 = affine_map<(d0, d1) -> (3, d0 - d1)> +// CHECK-DAG: #[[$DIV3:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 3)> +// CHECK-DAG: #[[$DIV4:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 4)> +// CHECK-DAG: #[[$DIVS3:[0-9a-z]+]] = affine_map<()[s0] -> (s0 ceildiv 3)> +// CHECK-DAG: #[[$DIVS4:[0-9a-z]+]] = affine_map<()[s0] -> (s0 ceildiv 4)> + // CHECK-LABEL: func @matmul_tensors +// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor +// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor +// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor func @matmul_tensors( %arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor @@ -15,39 +23,60 @@ %c3 = constant 3 : index %c4 = constant 4 : index %cst = constant 0.000000e+00 : f32 + + // CHECK-DAG: %[[C0:.*]] = constant 0 : index + // CHECK-DAG: %[[C1:.*]] = constant 1 : index %c0 = constant 0 : index %c1 = constant 1 : index + + // CHECK-DAG: %[[dM:.*]] = dim %[[TA]], %[[C0]] : tensor + // CHECK-DAG: %[[dK:.*]] = dim %[[TA]], %[[C1]] : tensor + // CHECK-DAG: %[[dN:.*]] = dim %[[TB]], %[[C1]] : tensor %0 = dim %arg0, %c0 : tensor %1 = dim %arg0, %c1 : tensor %2 = dim %arg1, %c1 : tensor - // CHECK: scf.for - // CHECK: linalg.init_tensor [%{{.*}}, 2, 4] : tensor + // CHECK: scf.for %[[I:[0-9a-z]+]] = + // First padded tensor is MxKx2x4 under loop M so Kx2x4 + // CHECK: %[[SZpad0_K:[0-9]+]] = affine.apply #[[$DIVS4]]()[%[[dK]]] + // CHECK: linalg.init_tensor [%[[SZpad0_K]], 2, 4] : tensor // 1-D loop - // CHECK: %[[A:.*]] = scf.for - // CHECK-NOT: scf.for + // CHECK: %[[A:.*]] = scf.for %[[J1:[0-9a-z]+]] = + // Iteration count along J1 + // CHECK: %[[IDXpad0_K:[0-9]+]] = affine.apply #[[$DIV4]](%[[J1]]) // CHECK: subtensor %{{.*}} [1, 1] : tensor to tensor // CHECK: linalg.pad_tensor %{{.*}} // CHECK: : tensor to tensor<2x4xf32> - // CHECK: subtensor_insert %{{.*}} into %{{.*}}[%{{.*}}, 0, 0] + // CHECK: subtensor_insert %{{.*}} into %{{.*}}[%[[IDXpad0_K]], 0, 0] // CHECK-SAME: [1, 2, 4] [1, 1, 1] : tensor<2x4xf32> into tensor + // Second padded tensor is KxNx2x4 + // CHECK: %[[SZpad1_K:[0-9]+]] = affine.apply #[[$DIVS4]]()[%[[dK]]] + // CHECK: %[[SZpad1_N:[0-9]+]] = affine.apply #[[$DIVS3]]()[%[[dN]]] + // CHECK: linalg.init_tensor [%[[SZpad1_K]], %[[SZpad1_N]], 4, 3] : tensor // 2-D loop - // CHECK: linalg.init_tensor [%{{.*}}, %{{.*}}, 4, 3] : tensor - // CHECK: %[[B:.*]] = scf.for - // CHECK: scf.for - // CHECK-NOT: scf.for + // CHECK: %[[B:.*]] = scf.for %[[K2:[0-9a-z]+]] = + // Iteration count along K2 + // CHECK: %[[IDXpad1_K:[0-9]+]] = affine.apply #[[$DIV3]](%[[K2]]) + // CHECK: scf.for %[[J2:[0-9a-z]+]] = + // Iteration count along J2 + // CHECK: %[[IDXpad1_N:[0-9]+]] = affine.apply #[[$DIV4]](%[[J2]]) // CHECK: subtensor %{{.*}} [1, 1] : tensor to tensor // CHECK: linalg.pad_tensor %{{.*}} // CHECK: : tensor to tensor<4x3xf32> - // CHECK: subtensor_insert %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] + // CHECK: subtensor_insert %{{.*}} into %{{.*}}[%[[IDXpad1_K]], %[[IDXpad1_N]], 0, 0] // CHECK-SAME: [1, 1, 4, 3] [1, 1, 1, 1] : tensor<4x3xf32> into tensor // 2-D loop // CHECK: scf.for %[[J:[0-9a-zA-Z]+]] // CHECK: scf.for %[[K:[0-9a-zA-Z]+]] - // CHECK-NOT: scf.for - // CHECK: %[[stA:.*]] = subtensor %[[A]][%[[K]], 0, 0] [1, 2, 4] [1, 1, 1] : + // Iteration count along K + // CHECK: %[[IDXpad0_K:[0-9]+]] = affine.apply #[[$DIV4]](%[[K]]) + // CHECK: %[[stA:.*]] = subtensor %[[A]][%[[IDXpad0_K]], 0, 0] [1, 2, 4] [1, 1, 1] : // CHECK-SAME: tensor to tensor<2x4xf32> - // CHECK: %[[stB:.*]] = subtensor %[[B]][%[[K]], %[[J]], 0, 0] [1, 1, 4, 3] [1, 1, 1, 1] : + // Iteration count along K + // CHECK: %[[IDXpad1_K:[0-9]+]] = affine.apply #[[$DIV4]](%[[K]]) + // Iteration count along J + // CHECK: %[[IDXpad1_N:[0-9]+]] = affine.apply #[[$DIV3]](%[[J]]) + // CHECK: %[[stB:.*]] = subtensor %[[B]][%[[IDXpad1_K]], %[[IDXpad1_N]], 0, 0] [1, 1, 4, 3] [1, 1, 1, 1] : // CHECK-SAME: tensor to tensor<4x3xf32> // CHECK: %[[stC:.*]] = linalg.pad_tensor %{{.*}} // CHECK: : tensor to tensor<2x3xf32>