diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1269,4 +1269,50 @@
   let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
 }
 
+//===----------------------------------------------------------------------===//
+// DelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+def DelinearizeIndexOp : Tensor_Op<"delinearize_index", [NoSideEffect,
+    RangedTypesMatchWith<"result type matches operand", "basis", "multi_index",
+                         "llvm::make_range($_self.begin(), $_self.end())">
+]> {
+  let summary = "delinearize an index";
+  let description = [{
+    The `tensor.delinearize_index` operation takes a single index value and
+    calculates the multi-index according to the given basis. One index result
+    is produced per basis element, outermost dimension first.
+
+    Example:
+    ```mlir
+    %b0 = arith.constant 3 : index
+    %b1 = arith.constant 4 : index
+    %b2 = arith.constant 5 : index
+    %indices:3 = tensor.delinearize_index %linear_index
+        (%b0, %b1, %b2 : index, index, index)
+    ```
+
+    In the above example, `%indices:3` conceptually holds the following:
+    ```
+    %indices#0 = floorDiv(%linear_index, 20)
+    %indices#1 = floorDiv(remainder(%linear_index, 20), 5)
+    %indices#2 = remainder(remainder(%linear_index, 20), 5)
+    ```
+
+    The outermost basis element (`%b0` above) does not enter the computation;
+    `%linear_index` is expected to lie in `[0, 3 * 4 * 5)`.
+  }];
+
+  let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
+  let results = (outs Variadic<Index>:$multi_index);
+
+  let assemblyFormat = [{
+    $linear_index `(` $basis `:` type($basis) `)` attr-dict
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
+  ];
+}
+
 #endif // TENSOR_OPS
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
@@ -15,6 +15,9 @@
 /// Creates an instance of `tensor` dialect bufferization pass.
 std::unique_ptr<Pass> createTensorBufferizePass();
 
+/// Creates a pass that lowers `tensor.delinearize_index` ops to `arith` ops.
+std::unique_ptr<Pass> createTensorLowerDelinearizeIndexPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
@@ -16,4 +16,11 @@
   let constructor = "mlir::createTensorBufferizePass()";
 }
 
+def TensorLowerDelinearizeIndexPass
+    : Pass<"tensor-lower-delinearize-index", ""> {
+  let summary = "Lower `tensor.delinearize_index` ops to `arith` ops.";
+  let constructor = "mlir::createTensorLowerDelinearizeIndexPass()";
+  let dependentDialects = ["::mlir::arith::ArithmeticDialect"];
+}
+
 #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -29,6 +29,9 @@
 FailureOr<Value> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
+/// Populates rewrite patterns that lower `tensor.delinearize_index`.
+void populateLowerDelinearizeIndexPatterns(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -33,6 +33,22 @@
 SmallVector<Value> createDynamicDimValues(OpBuilder &b, Location loc,
                                           Value rankedTensor);
 
+/// Holds the quotient and remainder of an integer division.
+struct DivModValue {
+  Value quotient;
+  Value remainder;
+};
+
+/// Creates IR computing the unsigned quotient and remainder of `lhs / rhs`.
+DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
+
+/// Generates IR that delinearizes `linearIndex` according to `dimSizes`
+/// (outermost dimension first, expected to be non-empty) and returns the
+/// multi-index, one value per entry of `dimSizes`.
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+                                               Value linearIndex,
+                                               ArrayRef<Value> dimSizes);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2364,6 +2364,34 @@
   return SplatElementsAttr::get(getType(), {constOperand});
 }
 
+//===----------------------------------------------------------------------===//
+// DelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+void DelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
+                               Value linearIndex,
+                               ArrayRef<OpFoldResult> basis) {
+  // One `index` result per basis element.
+  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
+  result.addOperands(linearIndex);
+
+  // SSA values in the basis are used as-is; attribute entries are
+  // materialized as `arith.constant` index ops.
+  SmallVector<Value> basisValues;
+  basisValues.reserve(basis.size());
+  for (OpFoldResult ofr : basis) {
+    if (auto value = ofr.dyn_cast<Value>()) {
+      basisValues.push_back(value);
+      continue;
+    }
+    Optional<int64_t> staticDim = getConstantIntValue(ofr);
+    assert(staticDim && "expected an integer attribute in the basis");
+    basisValues.push_back(
+        builder.create<arith::ConstantIndexOp>(result.location, *staticDim));
+  }
+  result.addOperands(basisValues);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  DelinearizeIndex.cpp
   SplitPadding.cpp
   SwapExtractSliceWithProducer.cpp
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/DelinearizeIndex.cpp b/mlir/lib/Dialect/Tensor/Transforms/DelinearizeIndex.cpp
new file mode 100644
--- /dev/null
+++
b/mlir/lib/Dialect/Tensor/Transforms/DelinearizeIndex.cpp
@@ -0,0 +1,68 @@
+//===- DelinearizeIndex.cpp - Implementation of delinearize_index lowering-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transformations and lowerings for the
+// `tensor.delinearize_index` operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+/// Replaces a `tensor.delinearize_index` op with the div/rem computation
+/// emitted by `tensor::delinearizeIndex`.
+struct LowerDelinearizeIndexOps : public OpRewritePattern<DelinearizeIndexOp> {
+  using OpRewritePattern<DelinearizeIndexOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DelinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    // An op with an empty basis has no results, so there is nothing to lower
+    // and no replacement values to provide.
+    if (op.basis().empty())
+      return rewriter.notifyMatchFailure(op, "expected a non-empty basis");
+
+    FailureOr<SmallVector<Value>> multiIndex = delinearizeIndex(
+        rewriter, op->getLoc(), op.linear_index(), llvm::to_vector(op.basis()));
+    if (failed(multiIndex))
+      return failure();
+    rewriter.replaceOp(op, *multiIndex);
+    return success();
+  }
+};
+
+/// Greedily applies the `tensor.delinearize_index` lowering patterns.
+class LowerDelinearizeIndexPass
+    : public TensorLowerDelinearizeIndexPassBase<LowerDelinearizeIndexPass> {
+public:
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    tensor::populateLowerDelinearizeIndexPatterns(patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void tensor::populateLowerDelinearizeIndexPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<LowerDelinearizeIndexOps>(patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::createTensorLowerDelinearizeIndexPass() {
+  return std::make_unique<LowerDelinearizeIndexPass>();
+}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h b/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h
--- a/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h
@@ -26,6 +26,10 @@
 class SCFDialect;
 } // namespace scf
 
+namespace arith {
+class ArithmeticDialect;
+} // namespace arith
+
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
 
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -68,3 +68,44 @@
   }
   return dynamicDims;
 }
+
+DivModValue mlir::tensor::getDivMod(OpBuilder &b, Location loc, Value lhs,
+                                    Value rhs) {
+  DivModValue result;
+  result.quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
+  result.remainder = b.create<arith::RemUIOp>(loc, lhs, rhs);
+  return result;
+}
+
+FailureOr<SmallVector<Value>>
+mlir::tensor::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
+                               ArrayRef<Value> dimSizes) {
+  // An empty basis has no indices to produce; returning `linearIndex` as a
+  // single index would not match the number of dimensions.
+  if (dimSizes.empty())
+    return failure();
+
+  // divisors[i] is the product dimSizes[i + 1] * ... * dimSizes[n - 1].
+  // Build these suffix products innermost-first so each one reuses the
+  // next one instead of recomputing every product from scratch.
+  unsigned numDims = dimSizes.size();
+  SmallVector<Value> divisors(numDims - 1);
+  if (!divisors.empty())
+    divisors.back() = dimSizes.back();
+  for (int i = static_cast<int>(divisors.size()) - 2; i >= 0; --i)
+    divisors[i] =
+        b.create<arith::MulIOp>(loc, dimSizes[i + 1], divisors[i + 1]);
+
+  // Peel off dimensions outermost-first: the quotient is the index along the
+  // current dimension and the remainder is the linear index into the rest.
+  SmallVector<Value> results;
+  results.reserve(numDims);
+  Value residual = linearIndex;
+  for (Value divisor : divisors) {
+    DivModValue divMod = getDivMod(b, loc, residual, divisor);
+    results.push_back(divMod.quotient);
+    residual = divMod.remainder;
+  }
+  results.push_back(residual);
+  return results;
+}
diff --git a/mlir/test/Dialect/Tensor/delinearize-index.mlir b/mlir/test/Dialect/Tensor/delinearize-index.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/delinearize-index.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt -split-input-file -tensor-lower-delinearize-index %s | FileCheck %s
+
+// CHECK-LABEL: @static_basis(
+// CHECK-SAME:    %[[IDX:[a-zA-Z0-9]+]]: index)
+// CHECK-DAG: %[[c50176:.+]] = arith.constant 50176 : index
+// CHECK-DAG: %[[c224:.+]] = arith.constant 224 : index
+// CHECK: %[[N:.+]] = arith.divui %[[IDX]], %[[c50176]] : index
+// CHECK: %[[RES:.+]] = arith.remui %[[IDX]], %[[c50176]] : index
+// CHECK: %[[P:.+]] = arith.divui %[[RES]], %[[c224]] : index
+// CHECK: %[[Q:.+]] = arith.remui %[[RES]], %[[c224]] : index
+// CHECK: return %[[N]], %[[P]], %[[Q]]
+func.func @static_basis(%linear_index: index) -> (index, index, index) {
+  %b0 = arith.constant 16 : index
+  %b1 = arith.constant 224 : index
+  %b2 = arith.constant 224 : index
+  %1:3 = tensor.delinearize_index %linear_index (%b0, %b1, %b2 : index, index, index)
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_basis(
+// CHECK-SAME:    %[[IDX:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %{{[a-zA-Z0-9]+}}: index,
+// CHECK-SAME:    %[[B1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[B2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DIV0:.+]] = arith.muli %[[B1]], %[[B2]] : index
+// CHECK: %[[N:.+]] = arith.divui %[[IDX]], %[[DIV0]] : index
+// CHECK: %[[RES:.+]] = arith.remui %[[IDX]], %[[DIV0]] : index
+// CHECK: %[[P:.+]] = arith.divui %[[RES]], %[[B2]] : index
+// CHECK: %[[Q:.+]] = arith.remui %[[RES]], %[[B2]] : index
+// CHECK: return %[[N]], %[[P]], %[[Q]]
+func.func @dynamic_basis(%linear_index: index, %b0: index, %b1: index,
+                         %b2: index) -> (index, index, index) {
+  %1:3 = tensor.delinearize_index %linear_index (%b0, %b1, %b2 : index, index, index)
+  return %1#0, %1#1, %1#2 : index, index, index
+}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -260,3 +260,14 @@
   %u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @delinearize(
+//  CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]: index, %[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index)
+//       CHECK:   tensor.delinearize_index %[[IDX]]
+//  CHECK-SAME:   %[[B0]], %[[B1]] : index, index
+func.func @delinearize(%idx: index, %basis0: index, %basis1: index) -> (index, index) {
+  %1:2 = tensor.delinearize_index %idx (%basis0, %basis1 : index, index)
+  return %1#0, %1#1 : index, index
+}