// Review notes (not part of the patch; patch(1) and git apply skip text
// that precedes the first "diff --git" header).
//
// Purpose: this patch reworks FoldConstantBase in
// mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp so the
// per-element constant-folding loop no longer re-creates its scratch
// containers on every iteration:
//   * adds struct APIntOrFloat (Optional fields apInt and apFloat) as the
//     single-element result of the region computation;
//   * changes RegionComputationFn: the call site now passes an
//     APIntOrFloatArray (the identity fn takes it by const reference) and
//     receives an APIntOrFloat;
//   * hoists indices, dstIndices, srcIndices, computeFnInputs, inputShapes
//     and srcLinearIndices out of the loop and overwrites them in place;
//   * linearizes each source index against that input's shape and reads
//     inputs with getFlatValue instead of getValue on a multi-dim index;
//   * caches the element-type check in isFloat;
//   * makes the identity ("no computation") fn return the first int or
//     float input.
//
// NOTE(review): this copy is damaged. Newlines and indentation were
// collapsed, and tag-like template argument lists were stripped (e.g.
// "Optional apInt;", "SmallVector indices", "elementType.isa()",
// "getFlatValue(srcLinearIndices[i])", "std::function;"), while
// "llvm::to_vector<4>" survived. It will neither apply nor compile as-is;
// regenerate it from the original commit instead of hand-repairing.
// NOTE(review): dstIndices.front() and srcIndices[i].front() assume at
// least one loop dimension; confirm the pattern rejects zero-loop
// (rank-0) generic ops before reaching this code.
// NOTE(review): source indices are linearized over
// dim < outputType.getRank() using each input's own shape; this assumes
// every input has the output's rank, presumably guaranteed by the
// indexing-map checks elsewhere -- confirm.
// NOTE(review): the identity fn calls apInts.front() / apFloats.front(),
// so it requires at least one input.
// NOTE(review): the loop counters (linearIndex, totalCount) are int;
// confirm numElements always fits in int.
// NOTE(review): "orginal" in the unchanged context line of the last hunk
// is a typo; fix it in the source file, not in this patch, or the hunk
// context will no longer match.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1286,12 +1286,16 @@ template class FoldConstantBase : public OpRewritePattern { public: + struct APIntOrFloat { + Optional apInt; + Optional apFloat; + }; struct APIntOrFloatArray { SmallVector apInts; SmallVector apFloats; }; using RegionComputationFn = - std::function; + std::function; FoldConstantBase(MLIRContext *context, const ControlElementwiseOpsFusionFn &controlFn, @@ -1403,57 +1407,82 @@ auto outputDims = getDimPositions(genericOp.getIndexingMaps().back()); auto outputShape = outputType.getShape(); + // Allocate small vectors for index delinearization. Initial values do not + // matter here as they will be overwritten later. + SmallVector indices(loopBounds.size(), 0); + SmallVector dstIndices(loopBounds.size(), 0); + SmallVector> srcIndices( + numInputs, SmallVector(loopBounds.size(), 0)); + + bool isFloat = elementType.isa(); + + // Allocate spaces for compute function inputs. Initial values do not matter + // here as they will be overwritten later. + APIntOrFloatArray computeFnInputs; + if (isFloat) { + computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); + } else { + computeFnInputs.apInts.resize(numInputs); + } + + auto inputShapes = llvm::to_vector<4>( + llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) { + return operand->get().getType().cast().getShape(); + })); + + SmallVector srcLinearIndices(numInputs, 0); + // Transpose the input constant. Because we don't know its rank in advance, // we need to loop over the range [0, element count) and delinearize the // index.
- for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) { - SmallVector indices(loopBounds.size(), 0); - int totalCount = linearIndex0; + for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { + int totalCount = linearIndex; for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { indices[dim] = totalCount % loopBounds[dim]; totalCount /= loopBounds[dim]; } - SmallVector> srcIndices; - for (int i = 0; i < numInputs; ++i) - srcIndices.emplace_back(loopBounds.size(), 0); - SmallVector dstIndices(loopBounds.size(), 0); - for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { for (int i = 0; i < numInputs; ++i) srcIndices[i][dim] = indices[inputDims[i][dim]]; dstIndices[dim] = indices[outputDims[dim]]; } - uint64_t linearIndex1 = dstIndices.front(); - for (int dim = 1; dim < outputType.getRank(); ++dim) - linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim]; + uint64_t dstLinearIndex = dstIndices.front(); + for (int i = 0; i < numInputs; ++i) + srcLinearIndices[i] = srcIndices[i].front(); + + for (int dim = 1; dim < outputType.getRank(); ++dim) { + dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; + for (int i = 0; i < numInputs; ++i) + srcLinearIndices[i] = + srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim]; + } // Collect constant elements for all inputs at this loop iteration. - SmallVector intValues; - SmallVector fpValues; - if (elementType.isa()) { + if (isFloat) { for (int i = 0; i < numInputs; ++i) - fpValues.push_back(inputValues[i].getValue(srcIndices[i])); + computeFnInputs.apFloats[i] = + inputValues[i].getFlatValue(srcLinearIndices[i]); } else { for (int i = 0; i < numInputs; ++i) - intValues.push_back(inputValues[i].getValue(srcIndices[i])); + computeFnInputs.apInts[i] = + inputValues[i].getFlatValue(srcLinearIndices[i]); } // Invoke the computation to get the corresponding constant output // element.
- APIntOrFloatArray inputs = {intValues, fpValues}; - APIntOrFloatArray outputs = computeFn(inputs); + APIntOrFloat outputs = computeFn(computeFnInputs); - if (elementType.isa()) { - fpOutputValues[linearIndex1] = outputs.apFloats.front(); + if (isFloat) { + fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue(); } else { - intOutputValues[linearIndex1] = outputs.apInts.front(); + intOutputValues[dstLinearIndex] = outputs.apInt.getValue(); } } DenseIntOrFPElementsAttr outputAttr; - if (elementType.isa()) { + if (isFloat) { outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues); } else { outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues); @@ -1494,7 +1523,11 @@ } // No computation; just return the orginal value. - return [](APIntOrFloatArray inputs) { return inputs; }; + return [](const APIntOrFloatArray &inputs) { + if (inputs.apFloats.empty()) + return APIntOrFloat{inputs.apInts.front(), llvm::None}; + return APIntOrFloat{llvm::None, inputs.apFloats.front()}; + }; } ControlElementwiseOpsFusionFn controlFn;