diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -44,22 +44,19 @@ llvm::DenseMap tileToPointMapping; // The permutation of outer dims (on domain). SmallVector outerDimsOnDomainPerm; - std::optional paddingValue; }; -static PackInfo getPackingInfoFromConsumer( - AffineMap indexingMap, ArrayRef innerTileSizes, - ArrayRef innerDimsPos, ArrayRef outerDimsPerm, - std::optional paddingValue = std::nullopt) { +static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap, + tensor::PackOp packOp) { LLVM_DEBUG( { llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; }); PackInfo packInfo; - packInfo.paddingValue = paddingValue; int64_t origNumDims = indexingMap.getNumDims(); SmallVector exprs(indexingMap.getResults()); + ArrayRef innerDimsPos = packOp.getInnerDimsPos(); for (auto [index, innerDimPos, tileSize] : llvm::zip_equal(llvm::seq(0, innerDimsPos.size()), - innerDimsPos, innerTileSizes)) { + innerDimsPos, packOp.getMixedTiles())) { int64_t domainDimPos = exprs[innerDimPos].cast().getPosition(); packInfo.tiledDimsPos.push_back(domainDimPos); @@ -74,7 +71,7 @@ }); } - for (auto dim : outerDimsPerm) + for (auto dim : packOp.getOuterDimsPerm()) packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim)); if (!packInfo.outerDimsOnDomainPerm.empty()) { LLVM_DEBUG({ @@ -208,7 +205,7 @@ b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); auto packedOperand = b.create( loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, - packInfo.paddingValue, outerDimsPerm); + /*padding=*/std::nullopt, outerDimsPerm); return std::make_tuple(packedOperand, indexingMap); } @@ -279,9 +276,7 @@ OpOperand *opOperand = genericOp.getDpsInitOperand(0); auto packInfo = getPackingInfoFromConsumer( - genericOp.getMatchingIndexingMap(opOperand), packOp.getMixedTiles(), - packOp.getInnerDimsPos(), packOp.getOuterDimsPerm(), - packOp.getPaddingValue()); + genericOp.getMatchingIndexingMap(opOperand), packOp); Location loc = packOp.getLoc(); SmallVector inputOperands;