diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -65,6 +65,9 @@
 /// work on primitive types, if possible.
 std::unique_ptr<Pass> createLinalgDetensorizePass();
 
+/// Create a pass that DMAs linalg op operands between address spaces.
+std::unique_ptr<Pass> createLinalgDMAAddressSpacePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -145,4 +145,26 @@
   ];
 }
 
+def LinalgDMAAddressSpace : Pass<"linalg-dma-address-space", "func::FuncOp"> {
+  let summary = "DMA linalg ops inputs/output from one address space to another";
+  let description = [{
+    Insert MemRef DMA start and wait operations before and after linalg
+    generic ops. Each input/output in `srcAddrSpace` is DMA imported/exported
+    into `destAddrSpace`, allowing operations to run on faster memory.
+
+    Only linalg::generic operations with statically shaped MemRef
+    inputs/outputs are supported.
+  }];
+  let constructor = "mlir::createLinalgDMAAddressSpacePass()";
+  // Arith is needed because the pass materializes arith.constant indices.
+  let dependentDialects = [
+    "arith::ArithDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
+  ];
+  let options = [
+    Option<"srcAddrSpace", "src-addr-space", "unsigned", /*default=*/"0",
+           "Source inputs/output address space to DMA import/export.">,
+    Option<"destAddrSpace", "dest-addr-space", "unsigned", /*default=*/"0",
+           "Destination inputs/output address space to DMA import/export.">
+  ];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@
   DataLayoutPropagation.cpp
   DecomposeLinalgOps.cpp
   Detensorize.cpp
+  DMAAddressSpace.cpp
   DropUnitDims.cpp
   ElementwiseOpFusion.cpp
   ElementwiseToLinalg.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DMAAddressSpace.cpp b/mlir/lib/Dialect/Linalg/Transforms/DMAAddressSpace.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DMAAddressSpace.cpp
@@ -0,0 +1,178 @@
+//===- DMAAddressSpace.cpp - DMA linalg operands between address spaces ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#define DEBUG_TYPE "linalg-dma-address-space"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGDMAADDRESSSPACE
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// A permissive pass that DMAs MemRef operands of linalg operations from one
+/// address (memory) space to another. Currently only linalg generic
+/// operations are supported. This is useful for HWs where operations can
+/// benefit from running on a faster address (memory) space.
+class LinalgDMAAddressSpacePass
+    : public impl::LinalgDMAAddressSpaceBase<LinalgDMAAddressSpacePass> {
+public:
+  LinalgDMAAddressSpacePass() = default;
+
+  void runOnOperation() final;
+
+private:
+  // Tries to DMA the linalg generic operands to the destination address
+  // (memory) space.
+  void tryDmaGenericOp(linalg::GenericOp genericOp);
+  // Checks whether a given operand can be DMAed to the destination address
+  // (memory) space: it must be a statically shaped MemRef whose memory space
+  // matches `srcAddrSpace`.
+  bool canDMAOperand(Value operand);
+  // For inputs, allocates a buffer before the linalg generic operation and
+  // adds MemRef DMA Start and MemRef DMA Wait operations for the given
+  // operand. For outputs, DMA Start and DMA Wait will be inserted after the
+  // linalg generic operation.
+  Value dmaOperand(linalg::GenericOp genericOp, Value operand, bool isInput);
+};
+} // end anonymous namespace
+
+void LinalgDMAAddressSpacePass::runOnOperation() {
+  func::FuncOp funcOp = getOperation();
+
+  if (funcOp.isDeclaration()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Skipping declaration function " << funcOp.getName() << "\n");
+    return;
+  }
+
+  // A src == dest DMA would be a no-op copy; treat it as a misconfiguration.
+  if (srcAddrSpace == destAddrSpace) {
+    funcOp.emitError("Source and destination address spaces must be different");
+    return signalPassFailure();
+  }
+
+  funcOp->walk(
+      [&](linalg::GenericOp genericOp) { tryDmaGenericOp(genericOp); });
+}
+
+void LinalgDMAAddressSpacePass::tryDmaGenericOp(linalg::GenericOp genericOp) {
+  LLVM_DEBUG(llvm::dbgs() << "Converting operands address spaces of "
+                          << genericOp << "\n");
+
+  // Replace a DMA-able operand with its fast-memory counterpart, and keep
+  // all other operands as-is.
+  auto dmaIfPossible = [&](Value operand, bool isInput) -> Value {
+    return canDMAOperand(operand) ? dmaOperand(genericOp, operand, isInput)
+                                  : operand;
+  };
+
+  SmallVector<Value> newInputs;
+  for (Value operand : genericOp.getInputs())
+    newInputs.push_back(dmaIfPossible(operand, /*isInput=*/true));
+  genericOp.getInputsMutable().assign(newInputs);
+
+  SmallVector<Value> newOutputs;
+  for (Value operand : genericOp.getOutputs())
+    newOutputs.push_back(dmaIfPossible(operand, /*isInput=*/false));
+  genericOp.getOutputsMutable().assign(newOutputs);
+}
+
+bool LinalgDMAAddressSpacePass::canDMAOperand(Value operand) {
+  auto memRefType = dyn_cast<MemRefType>(operand.getType());
+  if (!memRefType) {
+    LLVM_DEBUG(llvm::dbgs() << "Only MemRef operands are supported for operand "
+                            << operand << "\n");
+    return false;
+  }
+
+  // An absent memory space attribute denotes the default space 0.
+  unsigned addrSpace = 0;
+  if (Attribute spaceAttr = memRefType.getMemorySpace()) {
+    // Non-integer memory spaces (e.g. dialect-specific attributes) are not
+    // supported; bail out instead of asserting on the cast.
+    auto intSpaceAttr = dyn_cast<IntegerAttr>(spaceAttr);
+    if (!intSpaceAttr) {
+      LLVM_DEBUG(llvm::dbgs() << "Operand " << operand
+                              << " has a non-integer memory space\n");
+      return false;
+    }
+    addrSpace = intSpaceAttr.getInt();
+  }
+
+  if (addrSpace != srcAddrSpace) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Operand " << operand
+               << " address space doesn't match source address space "
+               << srcAddrSpace << "\n");
+    return false;
+  }
+  if (!memRefType.hasStaticShape()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Operand " << operand << " has dynamic dimensions\n");
+    return false;
+  }
+
+  return true;
+}
+
+Value LinalgDMAAddressSpacePass::dmaOperand(linalg::GenericOp genericOp,
+                                            Value operand, bool isInput) {
+  OpBuilder builder(genericOp.getOperation());
+
+  builder.setInsertionPoint(genericOp);
+  auto loc = genericOp.getLoc();
+
+  // Create a tag (single element 1-d memref) for the DMA.
+  auto tagMemRefType = MemRefType::get({1}, builder.getIntegerType(32));
+  auto tagBuffer = builder.create<memref::AllocOp>(loc, tagMemRefType);
+
+  // Create a buffer in the destination address space for the DMA. Both
+  // allocations stay in front of the generic op for inputs and outputs alike.
+  auto srcMemRefType = cast<MemRefType>(operand.getType());
+  auto destMemRefType =
+      MemRefType::get(srcMemRefType.getShape(), srcMemRefType.getElementType(),
+                      AffineMap{}, builder.getI64IntegerAttr(destAddrSpace));
+  auto destBufferAlloc = builder.create<memref::AllocOp>(loc, destMemRefType);
+
+  // Inputs are copied src -> fast buffer; outputs fast buffer -> src.
+  // NOTE(review): for outputs, the original buffer's contents are never
+  // copied *into* the fast buffer before the generic op, so this assumes the
+  // generic does not read its outs operand (pure parallel element-wise ops) —
+  // confirm before using with reductions.
+  auto dmaSource = isInput ? operand : destBufferAlloc.getResult();
+  auto dmaDest = isInput ? destBufferAlloc.getResult() : operand;
+
+  if (!isInput)
+    // For outputs only, the DMA must run after the linalg generic.
+    builder.setInsertionPointAfter(genericOp);
+
+  // DMA the whole source buffer.
+  auto numElements = builder.create<arith::ConstantIndexOp>(
+      loc, destMemRefType.getNumElements());
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto rank = srcMemRefType.getRank();
+  SmallVector<Value> sourceIndices(rank, zero);
+  SmallVector<Value> destIndices(rank, zero);
+  SmallVector<Value> tagIndices(1, zero);
+
+  // Create async DMA and wait right before input / right after output.
+ builder.create(loc, dmaSource, sourceIndices, dmaDest, + destIndices, numElements, tagBuffer, + tagIndices); + builder.create(loc, tagBuffer, tagIndices, numElements); + + return destBufferAlloc.getResult(); +} + +std::unique_ptr mlir::createLinalgDMAAddressSpacePass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/dma_address_space.mlir b/mlir/test/Dialect/Linalg/dma_address_space.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/dma_address_space.mlir @@ -0,0 +1,198 @@ +// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(linalg-dma-address-space{src-addr-space=1 dest-addr-space=3}))" -split-input-file | FileCheck %s + + +// CHECK-LABEL: func.func @linalg_generic_dma_all_function_inputs_outputs( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x2x3x4xf32, 1>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x2x3x4xf32, 1>) -> memref<1x2x3x4xf32, 1> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 1> +// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32> +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: %[[VAL_5:.*]] = arith.constant 24 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK: memref.dma_start %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]]], %[[VAL_4]]{{\[}}%[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]]], %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<1x2x3x4xf32, 1>, memref<1x2x3x4xf32, 3>, memref<1xi32> +// CHECK: memref.dma_wait %[[VAL_3]]{{\[}}%[[VAL_6]]], %[[VAL_5]] : memref<1xi32> +// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<1xi32> +// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: %[[VAL_9:.*]] = arith.constant 24 : index +// CHECK: %[[VAL_10:.*]] = arith.constant 0 : index +// CHECK: memref.dma_start %[[VAL_1]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]]], %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]]], %[[VAL_9]], 
%[[VAL_7]]{{\[}}%[[VAL_10]]] : memref<1x2x3x4xf32, 1>, memref<1x2x3x4xf32, 3>, memref<1xi32> +// CHECK: memref.dma_wait %[[VAL_7]]{{\[}}%[[VAL_10]]], %[[VAL_9]] : memref<1xi32> +// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1xi32> +// CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%[[VAL_4]], %[[VAL_8]] : memref<1x2x3x4xf32, 3>, memref<1x2x3x4xf32, 3>) outs(%[[VAL_12]] : memref<1x2x3x4xf32, 3>) { +// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 +// CHECK: linalg.yield %[[VAL_16]] : f32 +// CHECK: } +// CHECK: %[[VAL_17:.*]] = arith.constant 24 : index +// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index +// CHECK: memref.dma_start %[[VAL_12]]{{\[}}%[[VAL_18]], %[[VAL_18]], %[[VAL_18]], %[[VAL_18]]], %[[VAL_2]]{{\[}}%[[VAL_18]], %[[VAL_18]], %[[VAL_18]], %[[VAL_18]]], %[[VAL_17]], %[[VAL_11]]{{\[}}%[[VAL_18]]] : memref<1x2x3x4xf32, 3>, memref<1x2x3x4xf32, 1>, memref<1xi32> +// CHECK: memref.dma_wait %[[VAL_11]]{{\[}}%[[VAL_18]]], %[[VAL_17]] : memref<1xi32> +// CHECK: return %[[VAL_2]] : memref<1x2x3x4xf32, 1> +// CHECK: } +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @linalg_generic_dma_all_function_inputs_outputs(%arg0: memref<1x2x3x4xf32, 1>, %arg1: memref<1x2x3x4xf32, 1>) -> memref<1x2x3x4xf32, 1> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 1> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<1x2x3x4xf32, 1>, memref<1x2x3x4xf32, 1>) outs(%alloc : memref<1x2x3x4xf32, 1>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %1 = arith.addf %in, %in_1 : f32 + linalg.yield %1 : f32 + } + return %alloc : 
memref<1x2x3x4xf32, 1> +} + +// ----- + +// CHECK-LABEL: func.func @linalg_generic_dma_ignore_inputs_addr_space( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x2x3x4xf32, 2>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x2x3x4xf32, 2>) -> memref<1x2x3x4xf32, 1> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 1> +// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32> +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%[[VAL_0]], %[[VAL_1]] : memref<1x2x3x4xf32, 2>, memref<1x2x3x4xf32, 2>) outs(%[[VAL_4]] : memref<1x2x3x4xf32, 3>) { +// CHECK: ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32): +// CHECK: %[[VAL_8:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32 +// CHECK: linalg.yield %[[VAL_8]] : f32 +// CHECK: } +// CHECK: %[[VAL_9:.*]] = arith.constant 24 : index +// CHECK: %[[VAL_10:.*]] = arith.constant 0 : index +// CHECK: memref.dma_start %[[VAL_4]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]]], %[[VAL_2]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]]], %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_10]]] : memref<1x2x3x4xf32, 3>, memref<1x2x3x4xf32, 1>, memref<1xi32> +// CHECK: memref.dma_wait %[[VAL_3]]{{\[}}%[[VAL_10]]], %[[VAL_9]] : memref<1xi32> +// CHECK: return %[[VAL_2]] : memref<1x2x3x4xf32, 1> +// CHECK: } +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @linalg_generic_dma_ignore_inputs_addr_space(%arg0: memref<1x2x3x4xf32, 2>, %arg1: memref<1x2x3x4xf32, 2>) -> memref<1x2x3x4xf32, 1> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 1> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<1x2x3x4xf32, 2>, memref<1x2x3x4xf32, 2>) outs(%alloc : memref<1x2x3x4xf32, 
1>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %1 = arith.addf %in, %in_1 : f32 + linalg.yield %1 : f32 + } + return %alloc : memref<1x2x3x4xf32, 1> +} + +// ----- + +// CHECK-LABEL: func.func @linalg_generic_dma_ignore_outputs_addr_space( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x2x3x4xf32, 1>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x2x3x4xf32, 1>) -> memref<1x2x3x4xf32, 3> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 3> +// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32> +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: %[[VAL_5:.*]] = arith.constant 24 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK: memref.dma_start %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]]], %[[VAL_4]]{{\[}}%[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]]], %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<1x2x3x4xf32, 1>, memref<1x2x3x4xf32, 3>, memref<1xi32> +// CHECK: memref.dma_wait %[[VAL_3]]{{\[}}%[[VAL_6]]], %[[VAL_5]] : memref<1xi32> +// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<1xi32> +// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: %[[VAL_9:.*]] = arith.constant 24 : index +// CHECK: %[[VAL_10:.*]] = arith.constant 0 : index +// CHECK: memref.dma_start %[[VAL_1]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]]], %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]]], %[[VAL_9]], %[[VAL_7]]{{\[}}%[[VAL_10]]] : memref<1x2x3x4xf32, 1>, memref<1x2x3x4xf32, 3>, memref<1xi32> +// CHECK: memref.dma_wait %[[VAL_7]]{{\[}}%[[VAL_10]]], %[[VAL_9]] : memref<1xi32> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%[[VAL_4]], %[[VAL_8]] : memref<1x2x3x4xf32, 3>, memref<1x2x3x4xf32, 3>) outs(%[[VAL_2]] : memref<1x2x3x4xf32, 3>) { +// CHECK: ^bb0(%[[VAL_11:.*]]: f32, %[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32): +// 
CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_11]], %[[VAL_12]] : f32 +// CHECK: linalg.yield %[[VAL_14]] : f32 +// CHECK: } +// CHECK: return %[[VAL_2]] : memref<1x2x3x4xf32, 3> +// CHECK: } +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @linalg_generic_dma_ignore_outputs_addr_space(%arg0: memref<1x2x3x4xf32, 1>, %arg1: memref<1x2x3x4xf32, 1>) -> memref<1x2x3x4xf32, 3> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 3> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<1x2x3x4xf32, 1>, memref<1x2x3x4xf32, 1>) outs(%alloc : memref<1x2x3x4xf32, 3>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %1 = arith.addf %in, %in_1 : f32 + linalg.yield %1 : f32 + } + return %alloc : memref<1x2x3x4xf32, 3> +} + +// ----- + +// CHECK-LABEL: func.func @linalg_generic_dma_ignore_all( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x2x3x4xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x2x3x4xf32>) -> memref<1x2x3x4xf32, 3> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 3> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%[[VAL_0]], %[[VAL_1]] : memref<1x2x3x4xf32>, memref<1x2x3x4xf32>) outs(%[[VAL_2]] : memref<1x2x3x4xf32, 3>) { +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32): +// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: linalg.yield %[[VAL_6]] : f32 +// CHECK: } +// CHECK: return %[[VAL_2]] : memref<1x2x3x4xf32, 3> +// CHECK: } +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @linalg_generic_dma_ignore_all(%arg0: memref<1x2x3x4xf32>, %arg1: memref<1x2x3x4xf32>) -> memref<1x2x3x4xf32, 3> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 3> + linalg.generic {doc = "", 
indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<1x2x3x4xf32>, memref<1x2x3x4xf32>) outs(%alloc : memref<1x2x3x4xf32, 3>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %1 = arith.addf %in, %in_1 : f32 + linalg.yield %1 : f32 + } + return %alloc : memref<1x2x3x4xf32, 3> +} + +// ----- + +// CHECK-LABEL: func.func @linalg_generic_dma_only_one_input( +// CHECK-SAME: %[[VAL_0:.*]]: memref<3x4xi32, 1>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<3x4xi32>) -> memref<3x4xi32, 3> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x4xi32, 3> +// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32> +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<3x4xi32, 3> +// CHECK: %[[VAL_5:.*]] = arith.constant 12 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK: memref.dma_start %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_4]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<3x4xi32, 1>, memref<3x4xi32, 3>, memref<1xi32> +// CHECK: memref.dma_wait %[[VAL_3]]{{\[}}%[[VAL_6]]], %[[VAL_5]] : memref<1xi32> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_4]], %[[VAL_1]] : memref<3x4xi32, 3>, memref<3x4xi32>) outs(%[[VAL_2]] : memref<3x4xi32, 3>) { +// CHECK: ^bb0(%[[VAL_7:.*]]: i32, %[[VAL_8:.*]]: i32, %[[VAL_9:.*]]: i32): +// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_7]], %[[VAL_8]] : i32 +// CHECK: linalg.yield %[[VAL_10]] : i32 +// CHECK: } +// CHECK: return %[[VAL_2]] : memref<3x4xi32, 3> +// CHECK: } +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @linalg_generic_dma_only_one_input(%arg0: memref<3x4xi32, 1>, %arg1: memref<3x4xi32>) -> memref<3x4xi32, 3> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<3x4xi32, 3> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = 
["parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<3x4xi32, 1>, memref<3x4xi32>) outs(%alloc : memref<3x4xi32, 3>) { + ^bb0(%in: i32, %in_1: i32, %out: i32): + %1 = arith.addi %in, %in_1 : i32 + linalg.yield %1 : i32 + } + return %alloc : memref<3x4xi32, 3> +} + +// ----- + +// CHECK-LABEL: func.func @linalg_generic_ignore_dyanmic_shape( +// CHECK-SAME: %[[VAL_0:.*]]: memref<3x4xi32, 1>, +// CHECK-SAME: %[[VAL_1:.*]]: memref) -> memref<3x4xi32, 1> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x4xi32, 1> +// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32> +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<3x4xi32, 3> +// CHECK: %[[VAL_5:.*]] = arith.constant 12 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK: memref.dma_start %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_4]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<3x4xi32, 1>, memref<3x4xi32, 3>, memref<1xi32> +// CHECK: memref.dma_wait %[[VAL_3]]{{\[}}%[[VAL_6]]], %[[VAL_5]] : memref<1xi32> +// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<1xi32> +// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<3x4xi32, 3> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_4]], %[[VAL_1]] : memref<3x4xi32, 3>, memref) outs(%[[VAL_8]] : memref<3x4xi32, 3>) { +// CHECK: ^bb0(%[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: i32, %[[VAL_11:.*]]: i32): +// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_9]], %[[VAL_10]] : i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } +// CHECK: %[[VAL_13:.*]] = arith.constant 12 : index +// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index +// CHECK: memref.dma_start %[[VAL_8]]{{\[}}%[[VAL_14]], %[[VAL_14]]], %[[VAL_2]]{{\[}}%[[VAL_14]], %[[VAL_14]]], %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref<3x4xi32, 3>, memref<3x4xi32, 1>, memref<1xi32> +// CHECK: memref.dma_wait 
%[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_13]] : memref<1xi32> +// CHECK: return %[[VAL_2]] : memref<3x4xi32, 1> +// CHECK: } +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @linalg_generic_ignore_dyanmic_shape(%arg0: memref<3x4xi32, 1>, %arg1: memref) -> memref<3x4xi32, 1> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<3x4xi32, 1> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<3x4xi32, 1>, memref) outs(%alloc : memref<3x4xi32, 1>) { + ^bb0(%in: i32, %in_1: i32, %out: i32): + %1 = arith.addi %in, %in_1 : i32 + linalg.yield %1 : i32 + } + return %alloc : memref<3x4xi32, 1> +} diff --git a/mlir/test/Dialect/Linalg/dma_address_space_same_addr_space.mlir b/mlir/test/Dialect/Linalg/dma_address_space_same_addr_space.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/dma_address_space_same_addr_space.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(linalg-dma-address-space{src-addr-space=3 dest-addr-space=3}))" -verify-diagnostics + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// expected-error @+1 {{Source and destination address spaces must be different}} +func.func @linalg_generic_dma_ignore_all(%arg0: memref<1x2x3x4xf32>, %arg1: memref<1x2x3x4xf32>) -> memref<1x2x3x4xf32, 3> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 3> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<1x2x3x4xf32>, memref<1x2x3x4xf32>) outs(%alloc : memref<1x2x3x4xf32, 3>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %1 = arith.addf %in, %in_1 : f32 + linalg.yield %1 : f32 + } + return %alloc : memref<1x2x3x4xf32, 3> +}