diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1061,4 +1061,46 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AffineDelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
+    [NoSideEffect]> {
+  let summary = "delinearize an index";
+  let description = [{
+    The `affine.delinearize_index` operation takes a single index value and
+    calculates the multi-index according to the given basis. The basis must be
+    non-empty, and the op returns one index per basis element.
+
+    Example:
+
+    ```
+    %indices:3 = affine.delinearize_index %linear_index (%1, %2, %3) : index, index, index
+    ```
+
+    In the above example, `%indices:3` conceptually holds the following:
+
+    ```
+    %v1 = arith.muli %2, %3 : index
+    %indices#0 = floorDiv(%linear_index, %v1)
+    %indices#1 = floorDiv(remainder(%linear_index, %v1), %3)
+    %indices#2 = remainder(remainder(%linear_index, %v1), %3)
+    ```
+  }];
+
+  let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
+  let results = (outs Variadic<Index>:$multi_index);
+
+  let assemblyFormat = [{
+    $linear_index `(` $basis `)` attr-dict `:` type($multi_index)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
+  ];
+
+  let hasVerifier = 1;
+}
+
 #endif // AFFINE_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -110,6 +110,14 @@
 /// Overload relying on pass options for initialization.
 std::unique_ptr<OperationPass<func::FuncOp>> createSuperVectorizePass();
 
+/// Populate patterns that expand affine index operations into more fundamental
+/// operations (not necessarily restricted to Affine dialect).
+void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
+
+/// Creates a pass to expand affine index operations into more fundamental
+/// operations (not necessarily restricted to Affine dialect).
+std::unique_ptr<Pass> createAffineExpandIndexOpsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -395,4 +395,9 @@
   let constructor = "mlir::createSimplifyAffineStructuresPass()";
 }
 
+def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
+  let summary = "Lower affine operations operating on indices into more fundamental operations";
+  let constructor = "mlir::createAffineExpandIndexOpsPass()";
+}
+
 #endif // MLIR_DIALECT_AFFINE_PASSES
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -304,6 +304,23 @@
                                                 AffineMap affineMap,
                                                 ValueRange operands);
 
+/// Holds the quotient and remainder produced by `getDivMod`.
+struct DivModValue {
+  Value quotient;
+  Value remainder;
+};
+
+/// Create IR that computes the quotient and remainder of `lhs` divided by
+/// `rhs`. The operands are treated as unsigned (`arith.divui`/`arith.remui`).
+DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
+
+/// Generate the IR to delinearize `linearIndex` given the basis `dimSizes`
+/// (outermost dimension first) and return the multi-index, which holds one
+/// value per basis element.
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+                                               Value linearIndex,
+                                               ArrayRef<Value> dimSizes);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4036,6 +4036,42 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AffineDelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+/// Builds the op from `linearIndex` and a basis whose elements may be static
+/// (attributes) or dynamic (SSA values). Static elements are materialized as
+/// `arith.constant` index ops so that every basis element becomes an operand.
+void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
+                                     Value linearIndex,
+                                     ArrayRef<OpFoldResult> basis) {
+  // The op produces one index result per basis element.
+  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
+  result.addOperands(linearIndex);
+  SmallVector<Value> basisValues =
+      llvm::to_vector(llvm::map_range(basis, [&](OpFoldResult ofr) -> Value {
+        // SSA values are used as-is; only attributes need a constant op.
+        if (auto value = ofr.dyn_cast<Value>())
+          return value;
+        Optional<int64_t> staticDim = getConstantIntValue(ofr);
+        assert(staticDim.has_value() && "expected an integer basis attribute");
+        return builder.create<arith::ConstantIndexOp>(result.location,
+                                                      *staticDim);
+      }));
+  result.addOperands(basisValues);
+}
+
+/// Checks that the basis is non-empty and that there is exactly one result per
+/// basis element.
+LogicalResult AffineDelinearizeIndexOp::verify() {
+  if (getBasis().empty())
+    return emitOpError("basis should not be empty");
+  if (getNumResults() != getBasis().size())
+    return emitOpError("should return an index for each basis element");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -12,6 +12,7 @@
 
   LINK_LIBS PUBLIC
   MLIRArithmeticDialect
+  MLIRDialectUtils
   MLIRIR
   MLIRLoopLikeInterface
   MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -0,0 +1,64 @@
+//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that expands affine index ops into one or more
+// lower-level operations, not necessarily from the Affine dialect.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Lowers `affine.delinearize_index` into a sequence of division and remainder
+/// operations on indices, as generated by `delinearizeIndex`.
+struct LowerDelinearizeIndexOps
+    : public OpRewritePattern<AffineDelinearizeIndexOp> {
+  using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<SmallVector<Value>> multiIndex =
+        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+                         llvm::to_vector(op.getBasis()));
+    if (failed(multiIndex))
+      return failure();
+    rewriter.replaceOp(op, *multiIndex);
+    return success();
+  }
+};
+
+/// Pass that greedily applies the affine index-op expansion patterns.
+class ExpandAffineIndexOpsPass
+    : public AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
+public:
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    populateAffineExpandIndexOpsPatterns(patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns) {
+  patterns.add<LowerDelinearizeIndexOps>(patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::createAffineExpandIndexOpsPass() {
+  return std::make_unique<ExpandAffineIndexOpsPass>();
+}
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRAffineTransforms
   AffineDataCopyGeneration.cpp
+  AffineExpandIndexOps.cpp
   AffineLoopInvariantCodeMotion.cpp
   AffineLoopNormalize.cpp
   AffineParallelize.cpp
diff --git a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
@@ -10,6 +10,7 @@
   MLIRAffineDialect
   MLIRAffineAnalysis
   MLIRAnalysis
+  MLIRArithmeticUtils
   MLIRMemRefDialect
   MLIRTransformUtils
   )
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
---
a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -1816,3 +1817,46 @@
 
   return newMemRefType;
 }
+
+// The quotient and remainder use unsigned division (`arith.divui` and
+// `arith.remui`), so they agree with floordiv/mod only for non-negative
+// operands.
+DivModValue mlir::getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs) {
+  DivModValue result;
+  result.quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
+  result.remainder = b.create<arith::RemUIOp>(loc, lhs, rhs);
+  return result;
+}
+
+FailureOr<SmallVector<Value>> mlir::delinearizeIndex(OpBuilder &b, Location loc,
+                                                     Value linearIndex,
+                                                     ArrayRef<Value> dimSizes) {
+  // An empty basis has no multi-index (the op verifier rejects it as well).
+  if (dimSizes.empty())
+    return failure();
+
+  // divisors[i] is the product dimSizes[i + 1] * ... * dimSizes[numDims - 1].
+  // Build these suffix products back to front so that each one reuses the
+  // next, which takes a linear rather than quadratic number of multiplies.
+  unsigned numDims = dimSizes.size();
+  SmallVector<Value> divisors(numDims - 1);
+  Value stride;
+  for (unsigned i = numDims - 1; i > 0; --i) {
+    stride = stride ? b.createOrFold<arith::MulIOp>(loc, dimSizes[i], stride)
+                    : dimSizes[i];
+    divisors[i - 1] = stride;
+  }
+
+  // Each divisor peels off the next outermost index; the final remainder is
+  // the innermost index.
+  SmallVector<Value> results;
+  results.reserve(numDims);
+  Value residual = linearIndex;
+  for (Value divisor : divisors) {
+    DivModValue divMod = getDivMod(b, loc, residual, divisor);
+    results.push_back(divMod.quotient);
+    residual = divMod.remainder;
+  }
+  results.push_back(residual);
+  return results;
+}
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@
-0,0 +1,33 @@
+// RUN: mlir-opt %s -affine-expand-index-ops -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @static_basis
+// CHECK-SAME: (%[[IDX:.+]]: index)
+// CHECK-DAG: %[[c224:.+]] = arith.constant 224 : index
+// CHECK-DAG: %[[c50176:.+]] = arith.constant 50176 : index
+// CHECK: %[[N:.+]] = arith.divui %[[IDX]], %[[c50176]] : index
+// CHECK: %[[RES:.+]] = arith.remui %[[IDX]], %[[c50176]] : index
+// CHECK: %[[P:.+]] = arith.divui %[[RES]], %[[c224]] : index
+// CHECK: %[[Q:.+]] = arith.remui %[[RES]], %[[c224]] : index
+// CHECK: return %[[N]], %[[P]], %[[Q]]
+func.func @static_basis(%linear_index: index) -> (index, index, index) {
+  %b0 = arith.constant 16 : index
+  %b1 = arith.constant 224 : index
+  %b2 = arith.constant 224 : index
+  %1:3 = affine.delinearize_index %linear_index (%b0, %b1, %b2) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_basis
+// CHECK-SAME: (%[[IDX:.+]]: index, %[[B0:.+]]: index, %[[B1:.+]]: index, %[[B2:.+]]: index)
+// CHECK: %[[STRIDE:.+]] = arith.muli %[[B1]], %[[B2]] : index
+// CHECK: %[[N:.+]] = arith.divui %[[IDX]], %[[STRIDE]] : index
+// CHECK: %[[RES:.+]] = arith.remui %[[IDX]], %[[STRIDE]] : index
+// CHECK: %[[P:.+]] = arith.divui %[[RES]], %[[B2]] : index
+// CHECK: %[[Q:.+]] = arith.remui %[[RES]], %[[B2]] : index
+// CHECK: return %[[N]], %[[P]], %[[Q]]
+func.func @dynamic_basis(%linear_index: index, %b0: index, %b1: index, %b2: index) -> (index, index, index) {
+  %1:3 = affine.delinearize_index %linear_index (%b0, %b1, %b2) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -485,3 +485,19 @@
   }
   return
 }
+
+// -----
+
+func.func @delinearize(%idx: index, %basis0: index, %basis1: index) {
+  // expected-error@+1 {{'affine.delinearize_index' op should return an index for each basis element}}
+  %1 = affine.delinearize_index %idx (%basis0, %basis1) : index
+  return
+}
+
+// -----
+
+func.func @delinearize(%idx: index) {
+  // expected-error@+1 {{'affine.delinearize_index' op basis should not be empty}}
+  affine.delinearize_index %idx () : index
+  return
+}
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -260,3 +260,12 @@
 // CHECK-NEXT: %[[res2:.*]] = arith.addf %{{.*}}, %[[iter_arg2]] : f32
 // CHECK-NEXT: affine.yield %[[res1]], %[[res2]] : f32, f32
 // CHECK-NEXT: }
+
+// -----
+
+// CHECK-LABEL: func @delinearize
+func.func @delinearize(%linear_idx: index, %basis0: index, %basis1: index) -> (index, index) {
+  // CHECK: affine.delinearize_index
+  %1:2 = affine.delinearize_index %linear_idx (%basis0, %basis1) : index, index
+  return %1#0, %1#1 : index, index
+}