diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -19,6 +19,25 @@
 // ConvertToLLVMPattern
 //===----------------------------------------------------------------------===//
 
+/// Perform a binary tree reduction of `values` in-place, combining pairs of
+/// values with OpTy.
+/// At the end of the function, `values[0]` contains the desired result.
+template <typename OpTy>
+void binaryTreeReduce(RewriterBase &rewriter, Location loc,
+                      SmallVectorImpl<Value> &values) {
+  // Invariant: each pass computes bin_op(A[i], A[i + step]), i += 2 * step.
+  // step == 1: {0, 1}, {2, 3} .. {2k, 2k + 1}
+  // step == 2: {0, 2}, {4, 6} .. {4k, 4k + 2}
+  // step == 4: {0, 4}, {8, 12} .. {8k, 8k + 4}
+  if (values.size() > 1) {
+    for (int64_t step = 1, e = values.size(); step < e; step *= 2) {
+      for (int64_t i = 0; i + step < e; i += 2 * step) {
+        values[i] = rewriter.create<OpTy>(loc, values[i], values[i + step]);
+      }
+    }
+  }
+}
+
 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                            MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,
@@ -74,30 +93,43 @@
   MemRefDescriptor memRefDescriptor(memRefDesc);
   Value base = memRefDescriptor.alignedPtr(rewriter, loc);
 
+  SmallVector<Value> ptrSubIndices;
+  ptrSubIndices.reserve(indices.size());
+
   Value index;
-  if (offset != 0) // Skip if offset is zero.
+  // Skip if offset is zero.
+  if (offset != 0) {
     index = ShapedType::isDynamic(offset)
                 ? memRefDescriptor.offset(rewriter, loc)
                 : createIndexConstant(rewriter, loc, offset);
+    ptrSubIndices.push_back(index);
+  }
 
-  for (int i = 0, e = indices.size(); i < e; ++i) {
+  for (int64_t i = 0, e = indices.size(); i < e; ++i) {
     Value increment = indices[i];
     if (strides[i] != 1) { // Skip if stride is 1.
       Value stride = ShapedType::isDynamic(strides[i])
                          ? memRefDescriptor.stride(rewriter, loc, i)
                          : createIndexConstant(rewriter, loc, strides[i]);
+      // TODO: scale by element type bitwidth / byte bitwidth if going through
+      // bitcast to i8.
       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
     }
-    index =
-        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+    ptrSubIndices.push_back(increment);
   }
 
+  // Use a binary tree reduction to avoid the sequential dependence chain that
+  // would otherwise inhibit address computation hoisting.
+  binaryTreeReduce<LLVM::AddOp>(rewriter, loc, ptrSubIndices);
+
+  // TODO: bitcast to i8 to extract out the scale.
   Type elementPtrType = memRefDescriptor.getElementPtrType();
-  return index ? rewriter.create<LLVM::GEPOp>(
-                     loc, elementPtrType,
-                     getTypeConverter()->convertType(type.getElementType()),
-                     base, index)
-               : base;
+  return !ptrSubIndices.empty()
+             ? rewriter.create<LLVM::GEPOp>(
+                   loc, elementPtrType,
+                   getTypeConverter()->convertType(type.getElementType()), base,
+                   ptrSubIndices[0])
+             : base;
 }
 
 // Check if the MemRefType `type` is supported by the lowering. We currently
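For reference, below is a minimal standalone C++ sketch of the reduction schedule that `binaryTreeReduce` implements, with plain integers and a callable standing in for MLIR `Value`s and `OpTy`; the `treeReduce` name and the test values are illustrative only, not part of the patch. It shows how the pairwise schedule cuts the dependence depth from n - 1 sequential adds down to about log2(n) levels, which is what lets later passes hoist independent subtrees of the address computation.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// In-place binary tree reduction mirroring the schedule above: on each pass,
// element i absorbs element i + step, and i advances by 2 * step. After the
// loops, values[0] holds the result. The reduction tree is ceil(log2(n))
// levels deep instead of an n - 1 long serial chain.
template <typename T, typename BinOp>
void treeReduce(std::vector<T> &values, BinOp op) {
  for (size_t step = 1, e = values.size(); step < e; step *= 2)
    for (size_t i = 0; i + step < e; i += 2 * step)
      values[i] = op(values[i], values[i + step]);
}

int main() {
  // Five "partial indices", as might come from the offset plus per-dimension
  // index * stride products.
  std::vector<int64_t> indices = {3, 8, 2, 7, 5};
  treeReduce(indices, std::plus<int64_t>());
  assert(indices[0] == 3 + 8 + 2 + 7 + 5);
  std::cout << indices[0] << "\n"; // 25
}
```

With five partial indices the adds form three levels, {0,1} {2,3}, then {0,2}, then {0,4}, so no add depends on more than three predecessors, matching the hoisting rationale in the patch comment.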