diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -28,9 +28,16 @@
 
 /// Detects the `values` produced by a ConstantIndexOp and places the new
 /// constant in place of the corresponding sentinel value.
+/// TODO(pifon2a): Remove this function and use foldDynamicIndexList.
 void canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> &values,
                              function_ref<bool(int64_t)> isDynamic);
 
+/// Replaces each Value in `ofrs` that is produced by an arith::ConstantIndexOp
+/// with the corresponding index attribute. Returns `success` if at least one
+/// element was replaced and `failure` if nothing was folded.
+LogicalResult foldDynamicIndexList(Builder &b,
+                                   SmallVectorImpl<OpFoldResult> &ofrs);
+
 llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
                                             ArrayRef<int64_t> shape);
 
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -23,8 +23,8 @@
   return detail::op_matcher<arith::ConstantIndexOp>();
 }
 
-/// Detects the `values` produced by a ConstantIndexOp and places the new
-/// constant in place of the corresponding sentinel value.
+// Detects the `values` produced by a ConstantIndexOp and places the new
+// constant in place of the corresponding sentinel value.
 void mlir::canonicalizeSubViewPart(
     SmallVectorImpl<OpFoldResult> &values,
     llvm::function_ref<bool(int64_t)> isDynamic) {
@@ -38,6 +38,25 @@
   }
 }
 
+// Replaces each Value in `ofrs` that is produced by an arith::ConstantIndexOp
+// with the corresponding index attribute. Returns `success` if at least one
+// element was replaced and `failure` if nothing was folded.
+LogicalResult mlir::foldDynamicIndexList(Builder &b,
+                                         SmallVectorImpl<OpFoldResult> &ofrs) {
+  bool valuesChanged = false;
+  for (OpFoldResult &ofr : ofrs) {
+    if (ofr.is<Attribute>())
+      continue;
+    // Newly static, move from Value to constant.
+    if (auto cstOp =
+            ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>()) {
+      ofr = b.getIndexAttr(cstOp.value());
+      valuesChanged = true;
+    }
+  }
+  return success(valuesChanged);
+}
+
 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
                                                   ArrayRef<int64_t> shape) {
   llvm::SmallBitVector dimsToProject(shape.size());
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1430,11 +1430,54 @@
     return success();
   }
 };
+
+/// Folds the control operands (lower bounds, upper bounds and steps) of an
+/// `scf.forall` op that are defined by `arith.constant` index ops into the
+/// op's static attributes.
+class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
+public:
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForallOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
+    SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
+    SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
+    // Fold all three lists unconditionally. Chaining `failed(...)` with `&&`
+    // would short-circuit and skip the remaining lists once one of them folds.
+    bool changed = succeeded(foldDynamicIndexList(rewriter, mixedLowerBound));
+    changed |= succeeded(foldDynamicIndexList(rewriter, mixedUpperBound));
+    changed |= succeeded(foldDynamicIndexList(rewriter, mixedStep));
+    if (!changed)
+      return failure();
+
+    SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
+    SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
+    dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
+                               staticLowerBound);
+    dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
+                               staticUpperBound);
+    dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
+
+    // The op is updated in place, so the change must go through the rewriter
+    // to notify the pattern driver.
+    rewriter.updateRootInPlace(op, [&]() {
+      op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
+      op.setStaticLowerBound(staticLowerBound);
+      op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
+      op.setStaticUpperBound(staticUpperBound);
+      op.getDynamicStepMutable().assign(dynamicStep);
+      op.setStaticStep(staticStep);
+    });
+    return success();
+  }
+};
+
 } // namespace
 
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<DimOfForallOp>(context);
+  results.add<DimOfForallOp, ForallOpControlOperandsFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1497,3 +1497,30 @@
   // CHECK: return %[[dim]]
   return %dim : index
 }
+
+// -----
+
+// Constant lower bounds, upper bounds and steps of scf.forall are folded into
+// the op's static attributes.
+// CHECK-LABEL: func @forall_fold_control_operands
+func.func @forall_fold_control_operands(
+    %arg0 : tensor<?x10xf32>, %arg1: tensor<?x10xf32>) -> tensor<?x10xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x10xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x10xf32>
+
+  %result = scf.forall (%i, %j) = (%c0, %c0) to (%dim0, %dim1)
+      step (%c1, %c1) shared_outs(%o = %arg1) -> (tensor<?x10xf32>) {
+    %slice = tensor.extract_slice %arg1[%i, %j] [1, 1] [1, 1]
+      : tensor<?x10xf32> to tensor<1x1xf32>
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %o[%i, %j] [1, 1] [1, 1]
+        : tensor<1x1xf32> into tensor<?x10xf32>
+    }
+  }
+
+  return %result : tensor<?x10xf32>
+}
+// CHECK: forall (%{{.*}}, %{{.*}}) in (%{{.*}}, 10)
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -658,7 +658,7 @@
   // CHECK: %[[t_copy:.*]] = memref.alloc() {{.*}} : memref<10xf32>
   // CHECK: memref.copy %[[t]], %[[t_copy]]
 
-  // CHECK: scf.forall (%{{.*}}) in (%{{.*}}) {
+  // CHECK: scf.forall (%{{.*}}) in (2) {
 
   // Load from the copy and store into the shared output.
   // CHECK: %[[subview:.*]] = memref.subview %[[t]]