diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -843,7 +843,8 @@ DefaultValuedAttr:$use_full_tile_buffers, UnitAttr:$use_full_tiles_by_default, UnitAttr:$use_alloca, - OptionalAttr:$mapping, + ConfinedAttr, [ArrayMaxCount<1>]>:$mapping, + OptionalAttr:$copy_permutation, OptionalAttr:$alignment); let results = (outs PDL_Operation:$transformed); diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -177,8 +177,8 @@ /// dimension. If that is not possible, contains the dynamic size of the /// subview. The call back should return the buffer to use. using AllocBufferCallbackFn = std::function( - OpBuilder &b, memref::SubViewOp subView, - ArrayRef boundingSubViewSize, DataLayout &layout)>; + OpBuilder &b, Operation *subViewOp, ArrayRef boundingSubViewSize, + DataLayout &layout)>; /// Callback function type used to deallocate the buffers used to hold the /// promoted subview. @@ -256,6 +256,13 @@ copyOutFn = copyOut; return *this; } + + ArrayRef copyPermutation; + LinalgPromotionOptions & + setCopyPermutation(ArrayRef permutationArrayRef) { + copyPermutation = permutationArrayRef; + return *this; + } }; /// Split Reduction options. @@ -507,16 +514,19 @@ /// Create a new buffer using the `allocationFn` provided. The size of this /// buffer is the smallest constant bounding size along each dimension that /// can be computed for the size of the result of `subView`. Returns the -/// allocated buffer as `fullLocalView` and the view that matches the size of -/// the result of subview operation as `partialLocalView`. 
+/// allocated buffer as `fullLocalView`, the view that matches the size of +/// the result of subview operation as `partialLocalView`, and the input or +/// resulting operation from the promotion as `referenceOp` that can be used +/// in subsequent steps. struct PromotionInfo { Value fullLocalView; Value partialLocalView; + Value referenceOp; }; FailureOr promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView, const AllocBufferCallbackFn &allocationFn, - DataLayout &layout); + DataLayout &layout, ArrayRef permutation); /// Promote the `subViews` into a new buffer allocated at the insertion point /// `b`. Promotion occurs in 3 steps: @@ -531,10 +541,9 @@ const LinalgPromotionOptions &options); /// Allocate the subview in the GPU workgroup memory. -std::optional allocateWorkgroupMemory(OpBuilder &builder, - memref::SubViewOp subview, - ArrayRef sizeBounds, - DataLayout &); +Optional allocateWorkgroupMemory(OpBuilder &builder, Operation *subview, + ArrayRef sizeBounds, + DataLayout &); /// In case of GPU group memory there is no need to deallocate. LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value /*buffer*/); @@ -544,10 +553,9 @@ LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst); /// Allocate the subview in the GPU private memory. -std::optional allocateGPUPrivateMemory(OpBuilder &builder, - memref::SubViewOp subview, - ArrayRef sizeBounds, - DataLayout &); +Optional allocateGPUPrivateMemory(OpBuilder &builder, Operation *subview, + ArrayRef sizeBounds, + DataLayout &); /// Normal copy to between src and dst. 
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1694,6 +1694,10 @@ CPred<"$_self.cast<::mlir::ArrayAttr>().size() >= " # n>, "with at least " # n # " elements">; +class ArrayMaxCount : AttrConstraint< + CPred<"$_self.cast<::mlir::ArrayAttr>().size() <= " # n>, + "with at most " # n # " elements">; + class ArrayCount : AttrConstraint< CPred<"$_self.cast<::mlir::ArrayAttr>().size() == " #n>, "with exactly " # n # " elements">; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1870,6 +1870,9 @@ } else { return emitDefaultDefiniteFailure(target); } + if (getCopyPermutation().has_value()) { + promotionOptions.setCopyPermutation(*getCopyPermutation()); + } } if (failed(promoteSubviewsPrecondition(target, promotionOptions))) @@ -1880,7 +1883,7 @@ FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) return emitDefaultDefiniteFailure(target); - results.push_back(target); + results.push_back(*res); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -80,11 +80,17 @@ /// memref<..xi8> and return a view to get a memref type of shape /// boundingSubViewSize. 
static std::optional defaultAllocBufferCallBack( - const LinalgPromotionOptions &options, OpBuilder &builder, - memref::SubViewOp subView, ArrayRef boundingSubViewSize, - std::optional alignment, DataLayout &layout) { - ShapedType viewType = subView.getType(); - ImplicitLocOpBuilder b(subView.getLoc(), builder); + const LinalgPromotionOptions &options, OpBuilder &builder, Operation *op, + ArrayRef boundingSubViewSize, std::optional alignment, + DataLayout &layout) { + auto viewType = llvm::TypeSwitch(op) + .Case( + [](auto casted) { return casted.getType(); }) + .Default([](Operation *) { return nullptr; }); + if (!viewType) + return std::nullopt; + + ImplicitLocOpBuilder b(op->getLoc(), builder); auto zero = b.createOrFold(0); auto one = b.createOrFold(1); @@ -138,6 +144,7 @@ /// Alignment of promoted buffer. std::optional alignment; + ArrayRef copyPermutation; }; } // namespace @@ -165,7 +172,7 @@ if (options.allocationFn) { allocationFn = *options.allocationFn; } else { - allocationFn = [&](OpBuilder &b, memref::SubViewOp subViewOp, + allocationFn = [&](OpBuilder &b, Operation *subViewOp, ArrayRef boundingSubViewSize, DataLayout &layout) -> std::optional { return defaultAllocBufferCallBack(options, b, subViewOp, @@ -190,6 +197,7 @@ }; copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack); copyOutFn = (options.copyOutFn ? *(options.copyOutFn) : defaultCopyCallBack); + copyPermutation = options.copyPermutation; } // Performs promotion of a `subView` into a local buffer of the size of the @@ -197,12 +205,15 @@ // than the actual size of the `subView` at the boundaries. // This is related to the full/partial tile problem. // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and -// `partialLocalView` such that: +// `partialLocalView` and `referenceOp` such that: // * `buffer` is always the size of the full tile. // * `fullLocalView` is a dense contiguous view into that buffer. 
// * `partialLocalView` is a dense non-contiguous slice of `fullLocalView` // that corresponds to the size of `subView` and accounting for boundary // effects. +// * `referenceOp` which can be the original subviewOp or the created +// memref::TransposeOp +// // The point of the full tile buffer is that constant static tile sizes are // folded and result in a buffer type with statically known size and alignment // properties. @@ -211,13 +222,22 @@ // by a partial `copy` op. FailureOr mlir::linalg::promoteSubviewAsNewBuffer( OpBuilder &b, Location loc, memref::SubViewOp subView, - const AllocBufferCallbackFn &allocationFn, DataLayout &layout) { + const AllocBufferCallbackFn &allocationFn, DataLayout &layout, + ArrayRef permutation) { auto viewType = subView.getType(); auto rank = viewType.getRank(); - SmallVector fullSizes; + SmallVector fullSizes; SmallVector partialSizes; fullSizes.reserve(rank); partialSizes.reserve(rank); + // Get identity map. + AffineMap permutationMap = + AffineMap::getMultiDimIdentityMap(rank, b.getContext()); + // If permutation is given update the permutation map. + if (permutation.size() > 0) { + permutationMap = AffineMap::getPermutationMap( + llvm::to_vector_of(permutation), b.getContext()); + } llvm::SmallBitVector droppedDims = subView.getDroppedDims(); int64_t resultDimIdx = 0; for (const auto &en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) { @@ -227,16 +247,12 @@ // Try to extract a tight constant. If the size is known statically, no need // to look for the bound. 
LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n"); - Value size; - if (auto attr = rangeValue.size.dyn_cast()) { - size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); - } else { - Value materializedSize = - getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); - FailureOr upperBound = - getConstantUpperBoundForIndex(materializedSize); + Value size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); + auto attr = rangeValue.size.dyn_cast(); + if (!attr) { + FailureOr upperBound = getConstantUpperBoundForIndex(size); size = failed(upperBound) - ? materializedSize + ? size : b.create(loc, *upperBound); } LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); @@ -244,18 +260,29 @@ partialSizes.push_back( b.createOrFold(loc, subView, resultDimIdx++)); } - SmallVector dynSizes(fullSizes.size(), ShapedType::kDynamic); // If a callback is not specified, then use the default implementation for // allocating the promoted buffer. - std::optional fullLocalView = - allocationFn(b, subView, fullSizes, layout); + std::optional fullLocalView; + Value referenceOp = subView; + if (permutationMap.isIdentity()) { + fullLocalView = allocationFn(b, subView, fullSizes, layout); + } else { + auto transposeOp = b.create( + loc, subView, AffineMapAttr::get(permutationMap)); + referenceOp = transposeOp; + fullLocalView = allocationFn( + b, transposeOp, + applyPermutationMap(permutationMap, ArrayRef(fullSizes)), layout); + } + if (!fullLocalView) return failure(); SmallVector zeros(fullSizes.size(), b.getIndexAttr(0)); SmallVector ones(fullSizes.size(), b.getIndexAttr(1)); auto partialLocalView = b.createOrFold( - loc, *fullLocalView, zeros, partialSizes, ones); - return PromotionInfo{*fullLocalView, partialLocalView}; + loc, *fullLocalView, zeros, + applyPermutationMap(permutationMap, ArrayRef(partialSizes)), ones); + return PromotionInfo{*fullLocalView, partialLocalView, referenceOp}; } static FailureOr> @@ -269,8 +296,9 @@ for 
(auto v : options.subViews) { memref::SubViewOp subView = cast(v.second.getDefiningOp()); - auto promotionInfo = promoteSubviewAsNewBuffer( - b, b.getLoc(), subView, options.allocationFn, layout); + auto promotionInfo = + promoteSubviewAsNewBuffer(b, b.getLoc(), subView, options.allocationFn, + layout, options.copyPermutation); if (failed(promotionInfo)) return failure(); promotionInfoMap[v.first] = *promotionInfo; @@ -306,9 +334,8 @@ auto info = promotionInfoMap.find(v.first); if (info == promotionInfoMap.end()) continue; - if (failed(options.copyInFn( - b, cast(v.second.getDefiningOp()), - info->second.partialLocalView))) + if (failed(options.copyInFn(b, info->second.referenceOp, + info->second.partialLocalView))) return failure(); } return promotionInfoMap; @@ -332,6 +359,7 @@ opViews.reserve(op->getNumOperands()); SmallVector, 8> writebackViews; writebackViews.reserve(promotedBuffersAndViews->size()); + SmallVector indexingMaps = op.getIndexingMapsArray(); for (OpOperand &opOperand : op->getOpOperands()) { int64_t operandNumber = opOperand.getOperandNumber(); if (options.subViews.count(operandNumber) != 0) { @@ -343,13 +371,41 @@ (*promotedBuffersAndViews)[operandNumber].partialLocalView); if (operandNumber >= op.getNumDpsInputs()) writebackViews.emplace_back(std::make_pair( - opOperand.get(), + (*promotedBuffersAndViews)[operandNumber].referenceOp, (*promotedBuffersAndViews)[operandNumber].partialLocalView)); + + // 2.1 Get the identity affine map. + AffineMap permutationMap = AffineMap::getMultiDimIdentityMap( + opOperand.get().getType().cast().getRank(), + b.getContext()); + + // 2.2 Apply the given copy permutation to the original indexing map. 
+ if (options.copyPermutation.size() > 0) { + permutationMap = AffineMap::getPermutationMap( + llvm::to_vector_of(options.copyPermutation), + b.getContext()); + } + AffineMap transposedMap = + permutationMap.compose(op.getMatchingIndexingMap(&opOperand)); + indexingMaps[op.getIndexingMapIndex(&opOperand)] = transposedMap; } else { opViews.push_back(opOperand.get()); } } op->setOperands(0, opViews.size(), opViews); + linalg::GenericOp transposedGenericOp; + // 2.3 If the copy permutation is given replace the current Linalg op with a + // linalg.generic. + if (options.copyPermutation.size() > 0) { + ValueRange operandsRef(op->getOperands()); + transposedGenericOp = b.create( + /*location=*/op->getLoc(), + /*inputs=*/operandsRef.take_front(op.getNumDpsInputs()), + /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()), + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/op.getIteratorTypesArray()); + transposedGenericOp.getRegion().takeBody(op->getRegion(0)); + } OpBuilder::InsertionGuard guard(b); b.setInsertionPointAfter(op); @@ -363,6 +419,14 @@ // 4. Dealloc all local buffers. for (const auto &pi : *promotedBuffersAndViews) (void)options.deallocationFn(b, pi.second.fullLocalView); + + // 5. If the copy permutation is given replace the current Linalg op with + // created linalg.generic. + if (options.copyPermutation.size() > 0) { + IRRewriter rewriter(b); + rewriter.replaceOp(op, transposedGenericOp->getResults()); + return cast(*transposedGenericOp); + } return op; } @@ -403,13 +467,24 @@ /// Allocate the given subview to a memory address space in GPU by creating a /// allocation operation and setting the memref type address space to desired /// address space. 
-static std::optional allocateSubviewGPUMemoryInAddressSpace( - OpBuilder &builder, memref::SubViewOp subview, ArrayRef sizeBounds, - gpu::AddressSpace addressSpace) { +static Optional +allocateSubviewGPUMemoryInAddressSpace(OpBuilder &builder, Operation *op, + ArrayRef sizeBounds, + gpu::AddressSpace addressSpace) { OpBuilder::InsertionGuard guard(builder); - - func::FuncOp funcOp = subview->getParentOfType(); - if (!funcOp) + auto subview = llvm::TypeSwitch(op) + .Case( + [](auto casted) { return casted.getType(); }) + .Default([](Operation *) { return nullptr; }); + + auto funcOp = + llvm::TypeSwitch(op) + .Case([](auto casted) { + return casted->template getParentOfType(); + }) + .Default([](Operation *) { return nullptr; }); + + if (!funcOp || !subview) return std::nullopt; // The subview size bounds are expected to be constant; they specify the shape @@ -424,7 +499,7 @@ builder.setInsertionPoint(&funcOp.front(), funcOp.front().begin()); auto type = MemRefType::get( - shape, subview.getType().getElementType(), MemRefLayoutAttrInterface{}, + shape, subview.getElementType(), MemRefLayoutAttrInterface{}, gpu::AddressSpaceAttr::get(builder.getContext(), addressSpace)); Value buffer; if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) { @@ -438,9 +513,10 @@ } /// Allocate the subview in the GPU workgroup memory. -std::optional mlir::linalg::allocateWorkgroupMemory( - OpBuilder &builder, memref::SubViewOp subview, ArrayRef sizeBounds, - DataLayout &) { +Optional +mlir::linalg::allocateWorkgroupMemory(OpBuilder &builder, Operation *subview, + ArrayRef sizeBounds, + DataLayout &) { return allocateSubviewGPUMemoryInAddressSpace( builder, subview, sizeBounds, gpu::GPUDialect::getWorkgroupAddressSpace()); @@ -463,9 +539,10 @@ } /// Allocate the subview in the GPU private memory. 
-std::optional mlir::linalg::allocateGPUPrivateMemory( - OpBuilder &builder, memref::SubViewOp subview, ArrayRef sizeBounds, - DataLayout &) { +Optional +mlir::linalg::allocateGPUPrivateMemory(OpBuilder &builder, Operation *subview, + ArrayRef sizeBounds, + DataLayout &) { return allocateSubviewGPUMemoryInAddressSpace( builder, subview, sizeBounds, gpu::GPUDialect::getPrivateAddressSpace()); } diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -275,3 +275,52 @@ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!pdl.operation) -> !pdl.operation %1 = transform.structured.promote %0 } + +// ----- +func.func @gemm_transposed(%a : memref, %b : memref, %c : memref) +{ + linalg.matmul ins(%a, %b: memref, memref) + outs(%c: memref) + return +} + +// CHECK-LABEL: func @gemm_transposed +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK: %[[alloc_A:.*]] = memref.alloc() : memref<16x4xf32, #gpu.address_space> +// CHECK: %[[alloc_B:.*]] = memref.alloc() : memref<8x4xf32, #gpu.address_space> +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 +// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: %[[subview_A:.*]] = memref.subview {{.*}} +// CHECK: %[[subview_B:.*]] = memref.subview {{.*}} +// CHECK: %[[subview_C:.*]] = memref.subview {{.*}} + +// CHECK: %[[transposed_B:.*]] = memref.transpose %[[subview_B]] (d0, d1) -> (d1, d0) : memref> to memref> +// CHECK: %[[shared_B:.*]] = memref.subview %[[alloc_B]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<8x4xf32, #gpu.address_space> to memref, #gpu.address_space> + 
+// CHECK-NEXT: gpu.barrier +// CHECK-NEXT: memref.copy %[[transposed_B]], %[[shared_B]] : memref> to memref, #gpu.address_space> +// CHECK-NEXT: gpu.barrier + +// CHECK: %[[shared_A:.*]] = memref.subview %[[alloc_A]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x4xf32, #gpu.address_space> to memref, #gpu.address_space> + + +// CHECK-NEXT: gpu.barrier +// CHECK-NEXT: memref.copy %[[subview_A]], %[[shared_A]] : memref> to memref, #gpu.address_space> +// CHECK-NEXT: gpu.barrier + +// CHECK: linalg.generic {{.+}} ins(%[[shared_A]], %[[shared_B]]{{.*}} outs(%[[subview_C]] + + +transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1, %loops:3 = transform.structured.tile %0 [16, 8, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) + %2 = transform.structured.promote %1 { operands_to_promote = [1], mapping = [#gpu.memory_space], copy_permutation = array } + %3 = transform.structured.promote %2 { operands_to_promote = [0], mapping = [#gpu.memory_space]} +}