diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -1219,4 +1219,47 @@
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// DelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+def DelinearizeIndexOp : Op {
+  let summary = "delinearize an index";
+  let description = [{
+    The `arith.delinearize_index` operation takes a single index value and
+    calculates the multi-index according to the given basis. The operation
+    returns one index per basis element, outermost first.
+
+    Example:
+
+    ```
+    %indices:3 = arith.delinearize_index %linear_index (%1, %2, %3) : index, index, index
+    ```
+
+    In the above example, `%indices:3` conceptually holds the following:
+
+    ```
+    %v1 = arith.muli %2, %3 : index
+    %indices#0 = floorDiv(%linear_index, %v1)
+    %indices#1 = floorDiv(remainder(%linear_index, %v1), %3)
+    %indices#2 = remainder(remainder(%linear_index, %v1), %3)
+    ```
+
+    The outermost basis element (`%1` above) does not take part in the
+    computation. When the op is expanded into arithmetic, the division and
+    remainder are unsigned (`arith.divui` / `arith.remui`). These match the
+    floor-division formulation above only when the linear index and the basis
+    values are non-negative.
+  }];
+
+  let arguments = (ins Index:$linear_index, Variadic:$basis);
+  let results = (outs Variadic:$multi_index);
+
+  let assemblyFormat = [{
+    $linear_index `(` $basis `)` attr-dict `:` type($multi_index)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef":$basis)>
+  ];
+
+  let hasVerifier = 1;
+}
+
 #endif // ARITHMETIC_OPS
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
@@ -108,6 +108,22 @@
   OpBuilder &b;
   Location loc;
 };
+
+/// Holds the quotient and remainder values produced by `getDivMod`.
+struct DivModValue {
+  Value quotient;
+  Value remainder;
+};
+
+/// Create IR computing the unsigned quotient (`arith.divui`) and remainder
+/// (`arith.remui`) of `lhs` divided by `rhs`.
+DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
+
+/// Generate the IR to delinearize `linearIndex` over the basis `dimSizes`
+/// (outermost dimension first) and return the multi-index. For a non-empty
+/// `dimSizes`, the result holds one index per basis element, outermost first.
+/// `dimSizes[0]` is not read: the outermost index is the quotient of
+/// `linearIndex` by the product of the remaining sizes. Division and
+/// remainder are unsigned (see `getDivMod`).
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+                                               Value linearIndex,
+                                               ArrayRef<Value> dimSizes);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_ARITHMETIC_UTILS_UTILS_H
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
@@ -2040,6 +2041,35 @@
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// DelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+/// Build a delinearize_index op with one index result per basis element.
+/// Basis entries that are SSA values are used as-is. Entries given as integer
+/// attributes are materialized as `arith.constant ... : index` ops.
+void arith::DelinearizeIndexOp::build(OpBuilder &builder,
+                                      OperationState &result,
+                                      Value linearIndex,
+                                      ArrayRef<OpFoldResult> basis) {
+  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
+  result.addOperands(linearIndex);
+  for (OpFoldResult ofr : basis) {
+    // Reuse SSA values directly, even when they are defined by a constant.
+    // Re-materializing them would only add a duplicate constant op.
+    if (auto value = ofr.dyn_cast<Value>()) {
+      result.addOperands(value);
+      continue;
+    }
+    Optional<int64_t> staticDim = getConstantIntValue(ofr);
+    // A non-integer attribute cannot describe a basis size. Without this
+    // check, a null operand would be added silently.
+    assert(staticDim && "expected basis attribute to be an integer attribute");
+    Value cst =
+        builder.create<arith::ConstantIndexOp>(result.location, *staticDim);
+    result.addOperands(cst);
+  }
+}
+
+/// The basis must be non-empty and there must be exactly one index result per
+/// basis element.
+LogicalResult arith::DelinearizeIndexOp::verify() {
+  if (getBasis().empty())
+    return emitOpError("basis should not be empty");
+  if (getNumResults() != getBasis().size())
+    return emitOpError("should return an index for each basis element");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt @@ -15,6 +15,7 @@ LINK_LIBS PUBLIC MLIRDialect + MLIRDialectUtils MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRAnalysis MLIRArithmeticDialect + MLIRArithmeticUtils MLIRBufferizationDialect MLIRBufferizationTransforms MLIRInferIntRangeInterface diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp @@ -9,6 +9,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" @@ -188,6 +189,23 @@ } }; +/// Lowers `arith.delinearize_index` into a sequence of division and remainder +/// operations. 
+struct LowerDelinearizeIndexOps
+    : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+  /// Replace `op` with the multi-index computed by `delinearizeIndex`. The
+  /// basis operands are copied into a vector so they can be passed as the
+  /// `dimSizes` ArrayRef.
+  LogicalResult matchAndRewrite(arith::DelinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr> multiIndex =
+        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+                         llvm::to_vector(op.getBasis()));
+    // If the IR cannot be generated, leave the op untouched. It is registered
+    // as illegal in the pass below, so the partial conversion then fails.
+    if (failed(multiIndex))
+      return failure();
+    rewriter.replaceOp(op, *multiIndex);
+    return success();
+  }
+};
+
 struct ArithmeticExpandOpsPass
     : public ArithmeticExpandOpsBase {
   void runOnOperation() override {
@@ -207,7 +225,8 @@
       arith::MaxUIOp,
       arith::MinFOp,
       arith::MinSIOp,
-      arith::MinUIOp
+      arith::MinUIOp,
+      arith::DelinearizeIndexOp
     >();
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
@@ -230,7 +249,8 @@
     MaxMinIOpConverter,
     MaxMinIOpConverter,
     MaxMinIOpConverter,
-    MaxMinIOpConverter
+    MaxMinIOpConverter,
+    LowerDelinearizeIndexOps
   >(patterns.getContext());
   // clang-format on
 }
diff --git a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
@@ -115,3 +116,47 @@
 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
   return b.create(loc, cmp, lhs, rhs);
 }
+
+/// Emit the unsigned quotient (`arith.divui`) and remainder (`arith.remui`)
+/// of `lhs` divided by `rhs`. Both ops are always created with `b.create`, so
+/// no folding is attempted here.
+DivModValue mlir::getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs) {
+  DivModValue result;
+  result.quotient = b.create(loc, lhs, rhs);
+  result.remainder = b.create(loc, lhs, rhs);
+  return result;
+}
+
+/// Create IR that computes the product of all values in `set`, multiplying
+/// left to right. Each step uses `createOrFold`, so constant factors can fold
+/// into a single constant. Returns failure if `set` is empty.
+static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
+                                               ArrayRef<Value> set) {
+  if (set.empty())
+    return failure();
+  OpFoldResult result = set[0];
+  for (unsigned i = 1; i < set.size(); i++)
+    result = b.createOrFold<arith::MulIOp>(
+        loc, getValueOrCreateConstantIndexOp(b, loc, result), set[i]);
+  return result;
+}
+
+/// Delinearize `linearIndex` over the basis `dimSizes` (outermost first) into
+/// one index per basis element, using unsigned division and remainder (see
+/// `getDivMod`). `dimSizes[0]` is never read. Returns failure for an empty
+/// basis, because there is no multi-index to produce.
+FailureOr<SmallVector<Value>> mlir::delinearizeIndex(OpBuilder &b, Location loc,
+                                                     Value linearIndex,
+                                                     ArrayRef<Value> dimSizes) {
+  // Do not return `linearIndex` as a one-element "multi-index" for a
+  // zero-element basis.
+  if (dimSizes.empty())
+    return failure();
+  unsigned numDims = dimSizes.size();
+
+  // divisors[i] = dimSizes[i + 1] * ... * dimSizes[numDims - 1] is the stride
+  // of result dimension i. Compute the suffix products innermost-first so each
+  // one reuses the previous. This emits at most one multiplication per divisor,
+  // instead of rebuilding every suffix product from scratch (quadratic in the
+  // rank).
+  SmallVector<Value> divisors(numDims - 1);
+  for (unsigned i = numDims - 1; i > 0; --i) {
+    SmallVector<Value, 2> factors = {dimSizes[i]};
+    if (i + 1 < numDims)
+      factors.push_back(divisors[i]);
+    FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, factors);
+    if (failed(prod))
+      return failure();
+    divisors[i - 1] = getValueOrCreateConstantIndexOp(b, loc, *prod);
+  }
+
+  // Peel off one index per divisor, outermost first. The final remainder is
+  // the innermost index.
+  SmallVector<Value> results;
+  results.reserve(numDims);
+  Value residual = linearIndex;
+  for (Value divisor : divisors) {
+    DivModValue divMod = getDivMod(b, loc, residual, divisor);
+    results.push_back(divMod.quotient);
+    residual = divMod.remainder;
+  }
+  results.push_back(residual);
+  return results;
+}
diff --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
--- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
@@ -230,3 +230,44 @@
 }
 // CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
 // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
+
+// -----
+
+// CHECK-LABEL: @static_basis
+// CHECK-SAME: (%[[IDX:.+]]: index)
+// CHECK: arith.constant
+// CHECK: arith.constant
+// CHECK-DAG: %[[c224:.+]] = arith.constant 224 : index
+// CHECK-DAG: %[[c50176:.+]] = arith.constant 50176 : index
+// CHECK: %[[N:.+]] = arith.divui %[[IDX]], %[[c50176]] : index
+// CHECK: %[[RES:.+]] = arith.remui %[[IDX]], %[[c50176]] : index
+// CHECK: %[[P:.+]] = arith.divui %[[RES]], %[[c224]] : index
+// CHECK: %[[Q:.+]] = arith.remui %[[RES]], %[[c224]] : index
+// CHECK: return %[[N]], %[[P]], %[[Q]]
+func.func @static_basis(%linear_index: index) -> (index, index, index) {
+  %b0 = arith.constant 16 : index
+  %b1 = arith.constant 224 : index
+  %b2 = arith.constant 224 : index
+  %1:3 = arith.delinearize_index %linear_index (%b0, %b1, %b2) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// Suffix products are shared: b2 * b3 is computed once and reused for b1.
+// CHECK-LABEL: @dynamic_basis
+// CHECK-SAME: (%[[IDX:.+]]: index, %[[B0:.+]]: index, %[[B1:.+]]: index, %[[B2:.+]]: index, %[[B3:.+]]: index)
+// CHECK: %[[S23:.+]] = arith.muli %[[B2]], %[[B3]] : index
+// CHECK: %[[S123:.+]] = arith.muli %[[B1]], %[[S23]] : index
+// CHECK-NOT: arith.muli
+// CHECK: %[[I0:.+]] = arith.divui %[[IDX]], %[[S123]] : index
+// CHECK: %[[R0:.+]] = arith.remui %[[IDX]], %[[S123]] : index
+// CHECK: %[[I1:.+]] = arith.divui %[[R0]], %[[S23]] : index
+// CHECK: %[[R1:.+]] = arith.remui %[[R0]], %[[S23]] : index
+// CHECK: %[[I2:.+]] = arith.divui %[[R1]], %[[B3]] : index
+// CHECK: %[[I3:.+]] = arith.remui %[[R1]], %[[B3]] : index
+// CHECK: return %[[I0]], %[[I1]], %[[I2]], %[[I3]]
+func.func @dynamic_basis(%linear_index: index, %b0: index, %b1: index, %b2: index, %b3: index) -> (index, index, index, index) {
+  %1:4 = arith.delinearize_index %linear_index (%b0, %b1, %b2, %b3) : index, index, index, index
+  return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
diff --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -721,3 +721,19 @@
   %x = arith.constant 1 : i32
 }
 
+
+// -----
+
+func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
+  // expected-error@+1 {{'arith.delinearize_index' op should return an index for each basis element}}
+  %1 = arith.delinearize_index %idx (%basis0, %basis1) : index
+  return
+}
+
+// -----
+
+func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
+  // expected-error@+1 {{'arith.delinearize_index' op basis should not be empty}}
+  arith.delinearize_index %idx () : index
+  return
+}
diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -952,3 +952,9 @@
   %min_unsigned = arith.minui %i1, %i2 : i32
   return
 }
+
+// CHECK-LABEL: func @delinearize
+func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) -> (index, index) {
+  %1:2 = arith.delinearize_index %idx (%basis0, %basis1) : index, index
+  return %1#0, %1#1 : index, index
+}