diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -65,6 +65,9 @@ /// work on primitive types, if possible. std::unique_ptr<Pass> createLinalgDetensorizePass(); +/// Create a pass to update address space of Linalg operations if possible. +std::unique_ptr<Pass> createLinalgUpdateAddressSpacePass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -145,4 +145,27 @@ ]; } +def LinalgUpdateAddressSpace : Pass<"linalg-update-address-space", "func::FuncOp"> { + let summary = "Update linalg ops inputs/output from one address space to another"; + let description = [{ + Insert MemRef copy operations before and after linalg generic ops. + Each input/output in `srcAddrSpace` will be copied into `destAddrSpace`. + + This pass allows moving operations to run on faster memory. + + This pass supports linalg::generic operations only. Only input/output from + MemRef type with static dimensions are supported. 
+ }]; + let constructor = "mlir::createLinalgUpdateAddressSpacePass()"; + let dependentDialects = [ + "linalg::LinalgDialect", "memref::MemRefDialect" + ]; + let options = [ + Option<"srcAddrSpace", "src-addr-space", "unsigned", /*default=*/"0", + "Source address space to copy inputs/outputs from.">, + Option<"destAddrSpace", "dest-addr-space", "unsigned", /*default=*/"0", + "Destination address space to copy inputs/outputs to."> + ]; +} + #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ DataLayoutPropagation.cpp DecomposeLinalgOps.cpp Detensorize.cpp + UpdateAddressSpace.cpp DropUnitDims.cpp ElementwiseOpFusion.cpp ElementwiseToLinalg.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/UpdateAddressSpace.cpp b/mlir/lib/Dialect/Linalg/Transforms/UpdateAddressSpace.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/UpdateAddressSpace.cpp @@ -0,0 +1,166 @@ +//===- UpdateAddressSpace.cpp - update address space of linalg ops --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Passes.h" + +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Transforms/DialectConversion.h" + +#define DEBUG_TYPE "linalg-update-address-space" + +namespace mlir { +#define GEN_PASS_DEF_LINALGUPDATEADDRESSSPACE +#include "mlir/Dialect/Linalg/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::linalg; + +namespace { +/// A permissive pass that tries to update linalg operations that accept +/// MemRefs in one address (memory) space by copying them to another address +/// (memory) space. Currently only linalg generic operations are supported. This +/// is necessary for HWs where operations are allowed to run (or can benefit +/// from running) on a faster address (memory) space. +class LinalgUpdateAddressSpacePass + : public impl::LinalgUpdateAddressSpaceBase<LinalgUpdateAddressSpacePass> { +public: + LinalgUpdateAddressSpacePass() = default; + + void runOnOperation() final; + +private: + // Tries to update the linalg generic operands/results to the destination + // address (memory) space. + void tryUpdateAddressSpaceGenericOp(linalg::GenericOp genericOp); + // Checks if a given operand can be updated to the destination address (memory) + // space. + bool canUpdateAddressSpaceOperand(Value operand); + // For inputs, allocates a buffer before the linalg generic operation and adds + // MemRef copy operations for the given operand. For outputs, MemRef copy + // operations will be inserted after the linalg generic operation. 
+ Value updateAddressSpaceOperand(linalg::GenericOp genericOp, Value operand, + bool isInput); +}; +} // end anonymous namespace + +void LinalgUpdateAddressSpacePass::runOnOperation() { + func::FuncOp funcOp = getOperation(); + + if (funcOp.isDeclaration()) { + LLVM_DEBUG(llvm::dbgs() + << "Skipping declaration function " << funcOp.getName() << "\n"); + return; + } + + if (srcAddrSpace == destAddrSpace) { + funcOp.emitError("Source and destination address spaces must be different"); + return signalPassFailure(); + } + + funcOp->walk([&](linalg::GenericOp genericOp) { + tryUpdateAddressSpaceGenericOp(genericOp); + }); +} + +void LinalgUpdateAddressSpacePass::tryUpdateAddressSpaceGenericOp( + linalg::GenericOp genericOp) { + LLVM_DEBUG(llvm::dbgs() << "Converting operands address spaces of " + << genericOp << "\n"); + + SmallVector<Value> newInputs; + for (auto operand : genericOp.getInputs()) { + Value newOperand; + if (canUpdateAddressSpaceOperand(operand)) + newOperand = + updateAddressSpaceOperand(genericOp, operand, true /*isInput*/); + else + newOperand = operand; + + newInputs.push_back(newOperand); + } + genericOp.getInputsMutable().assign(newInputs); + + SmallVector<Value> newOutputs; + for (auto operand : genericOp.getOutputs()) { + Value newOperand; + if (canUpdateAddressSpaceOperand(operand)) + newOperand = + updateAddressSpaceOperand(genericOp, operand, false /*isInput*/); + else + newOperand = operand; + + newOutputs.push_back(newOperand); + } + genericOp.getOutputsMutable().assign(newOutputs); +} + +bool LinalgUpdateAddressSpacePass::canUpdateAddressSpaceOperand(Value operand) { + auto memRefType = dyn_cast<MemRefType>(operand.getType()); + if (!memRefType) { + LLVM_DEBUG(llvm::dbgs() << "Only MemRef operands are supported for operand " + << operand << "\n"); + return false; + } + + unsigned addrSpace = 0 /*default*/; + Attribute spaceAttr = memRefType.getMemorySpace(); + if (spaceAttr) { + addrSpace = spaceAttr.cast<IntegerAttr>().getInt(); + } + + if (addrSpace != srcAddrSpace) { + 
LLVM_DEBUG(llvm::dbgs() + << "Operand " << operand + << " address space doesn't match source address space " + << srcAddrSpace << "\n"); + return false; + } + if (!memRefType.hasStaticShape()) { + LLVM_DEBUG(llvm::dbgs() + << "Operand " << operand << " has dynamic dimensions\n"); + return false; + } + + return true; +} + +Value LinalgUpdateAddressSpacePass::updateAddressSpaceOperand( + linalg::GenericOp genericOp, Value operand, bool isInput) { + OpBuilder builder(genericOp.getOperation()); + + builder.setInsertionPoint(genericOp); + auto loc = genericOp.getLoc(); + + // Create a buffer in destination address space for the MemRef Copy. + auto srcMemRefType = cast<MemRefType>(operand.getType()); + auto destMemRefType = + MemRefType::get(srcMemRefType.getShape(), srcMemRefType.getElementType(), + AffineMap{}, builder.getI64IntegerAttr(destAddrSpace)); + auto destBufferAlloc = builder.create<memref::AllocOp>(loc, destMemRefType); + + auto srcMemRef = isInput ? operand : destBufferAlloc.getResult(); + auto destMemRef = isInput ? destBufferAlloc.getResult() : operand; + + if (!isInput) + // For outputs only, the MemRef copy should be after linalg generic + builder.setInsertionPointAfter(genericOp); + + // Create MemRef copy right before input / right after output. 
+ builder.create<memref::CopyOp>(loc, srcMemRef, destMemRef); + + return destBufferAlloc.getResult(); +} + +std::unique_ptr<Pass> mlir::createLinalgUpdateAddressSpacePass() { + return std::make_unique<LinalgUpdateAddressSpacePass>(); +} diff --git a/mlir/test/Dialect/Linalg/update_address_space.mlir b/mlir/test/Dialect/Linalg/update_address_space.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/update_address_space.mlir @@ -0,0 +1,162 @@ +// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(linalg-update-address-space{src-addr-space=1 dest-addr-space=3}))" -split-input-file | FileCheck %s + + +// CHECK-LABEL: func.func @linalg_generic_update_all_function_inputs_outputs( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x2x3x4xf32, 1>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x2x3x4xf32, 1>) -> memref<1x2x3x4xf32, 1> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 1> +// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: memref.copy %[[VAL_0]], %[[VAL_3]] : memref<1x2x3x4xf32, 1> to memref<1x2x3x4xf32, 3> +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: memref.copy %[[VAL_1]], %[[VAL_4]] : memref<1x2x3x4xf32, 1> to memref<1x2x3x4xf32, 3> +// CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x2x3x4xf32, 3>, memref<1x2x3x4xf32, 3>) outs(%[[VAL_5]] : memref<1x2x3x4xf32, 3>) { +// CHECK: ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32 +// CHECK: linalg.yield %[[VAL_9]] : f32 +// CHECK: } +// CHECK: memref.copy %[[VAL_5]], %[[VAL_2]] : memref<1x2x3x4xf32, 3> to memref<1x2x3x4xf32, 1> +// CHECK: return %[[VAL_2]] : memref<1x2x3x4xf32, 1> +// CHECK: } +#map = affine_map<(d0, d1, d2, 
d3)> +func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<1x2x3x4xf32, 1>, %arg1: memref<1x2x3x4xf32, 1>) -> memref<1x2x3x4xf32, 1> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 1> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<1x2x3x4xf32, 1>, memref<1x2x3x4xf32, 1>) outs(%alloc : memref<1x2x3x4xf32, 1>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %1 = arith.addf %in, %in_1 : f32 + linalg.yield %1 : f32 + } + return %alloc : memref<1x2x3x4xf32, 1> +} + +// ----- + +// CHECK-LABEL: func.func @linalg_generic_update_ignore_inputs_addr_space( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x2x3x4xf32, 2>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x2x3x4xf32, 2>) -> memref<1x2x3x4xf32, 1> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 1> +// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%[[VAL_0]], %[[VAL_1]] : memref<1x2x3x4xf32, 2>, memref<1x2x3x4xf32, 2>) outs(%[[VAL_3]] : memref<1x2x3x4xf32, 3>) { +// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32): +// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32 +// CHECK: linalg.yield %[[VAL_7]] : f32 +// CHECK: } +// CHECK: memref.copy %[[VAL_3]], %[[VAL_2]] : memref<1x2x3x4xf32, 3> to memref<1x2x3x4xf32, 1> +// CHECK: return %[[VAL_2]] : memref<1x2x3x4xf32, 1> +// CHECK: } +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @linalg_generic_update_ignore_inputs_addr_space(%arg0: memref<1x2x3x4xf32, 2>, %arg1: memref<1x2x3x4xf32, 2>) -> memref<1x2x3x4xf32, 1> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 1> + linalg.generic {doc = "", indexing_maps = [#map, #map, 
#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<1x2x3x4xf32, 2>, memref<1x2x3x4xf32, 2>) outs(%alloc : memref<1x2x3x4xf32, 1>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %1 = arith.addf %in, %in_1 : f32 + linalg.yield %1 : f32 + } + return %alloc : memref<1x2x3x4xf32, 1> +} + +// ----- + +// CHECK-LABEL: func.func @linalg_generic_update_ignore_outputs_addr_space( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x2x3x4xf32, 1>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x2x3x4xf32, 1>) -> memref<1x2x3x4xf32, 3> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 3> +// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: memref.copy %[[VAL_0]], %[[VAL_3]] : memref<1x2x3x4xf32, 1> to memref<1x2x3x4xf32, 3> +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<1x2x3x4xf32, 3> +// CHECK: memref.copy %[[VAL_1]], %[[VAL_4]] : memref<1x2x3x4xf32, 1> to memref<1x2x3x4xf32, 3> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x2x3x4xf32, 3>, memref<1x2x3x4xf32, 3>) outs(%[[VAL_2]] : memref<1x2x3x4xf32, 3>) { +// CHECK: ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32): +// CHECK: %[[VAL_8:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32 +// CHECK: linalg.yield %[[VAL_8]] : f32 +// CHECK: } +// CHECK: return %[[VAL_2]] : memref<1x2x3x4xf32, 3> +// CHECK: } +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @linalg_generic_update_ignore_outputs_addr_space(%arg0: memref<1x2x3x4xf32, 1>, %arg1: memref<1x2x3x4xf32, 1>) -> memref<1x2x3x4xf32, 3> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 3> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : 
memref<1x2x3x4xf32, 1>, memref<1x2x3x4xf32, 1>) outs(%alloc : memref<1x2x3x4xf32, 3>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %1 = arith.addf %in, %in_1 : f32 + linalg.yield %1 : f32 + } + return %alloc : memref<1x2x3x4xf32, 3> +} + +// ----- + +// CHECK-LABEL: func.func @linalg_generic_update_ignore_all( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x2x3x4xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x2x3x4xf32>) -> memref<1x2x3x4xf32, 3> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 3> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%[[VAL_0]], %[[VAL_1]] : memref<1x2x3x4xf32>, memref<1x2x3x4xf32>) outs(%[[VAL_2]] : memref<1x2x3x4xf32, 3>) { +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32): +// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: linalg.yield %[[VAL_6]] : f32 +// CHECK: } +// CHECK: return %[[VAL_2]] : memref<1x2x3x4xf32, 3> +// CHECK: } +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @linalg_generic_update_ignore_all(%arg0: memref<1x2x3x4xf32>, %arg1: memref<1x2x3x4xf32>) -> memref<1x2x3x4xf32, 3> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 3> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<1x2x3x4xf32>, memref<1x2x3x4xf32>) outs(%alloc : memref<1x2x3x4xf32, 3>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %1 = arith.addf %in, %in_1 : f32 + linalg.yield %1 : f32 + } + return %alloc : memref<1x2x3x4xf32, 3> +} + +// ----- + +// CHECK-LABEL: func.func @linalg_generic_update_only_one_input( +// CHECK-SAME: %[[VAL_0:.*]]: memref<3x4xi32, 1>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<3x4xi32>) -> memref<3x4xi32, 3> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} 
: memref<3x4xi32, 3> +// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<3x4xi32, 3> +// CHECK: memref.copy %[[VAL_0]], %[[VAL_3]] : memref<3x4xi32, 1> to memref<3x4xi32, 3> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_3]], %[[VAL_1]] : memref<3x4xi32, 3>, memref<3x4xi32>) outs(%[[VAL_2]] : memref<3x4xi32, 3>) { +// CHECK: ^bb0(%[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32): +// CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_4]], %[[VAL_5]] : i32 +// CHECK: linalg.yield %[[VAL_7]] : i32 +// CHECK: } +// CHECK: return %[[VAL_2]] : memref<3x4xi32, 3> +// CHECK: } +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @linalg_generic_update_only_one_input(%arg0: memref<3x4xi32, 1>, %arg1: memref<3x4xi32>) -> memref<3x4xi32, 3> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<3x4xi32, 3> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<3x4xi32, 1>, memref<3x4xi32>) outs(%alloc : memref<3x4xi32, 3>) { + ^bb0(%in: i32, %in_1: i32, %out: i32): + %1 = arith.addi %in, %in_1 : i32 + linalg.yield %1 : i32 + } + return %alloc : memref<3x4xi32, 3> +} + +// ----- + +// CHECK-LABEL: func.func @linalg_generic_ignore_dyanmic_shape( +// CHECK-SAME: %[[VAL_0:.*]]: memref<3x4xi32, 1>, +// CHECK-SAME: %[[VAL_1:.*]]: memref) -> memref<3x4xi32, 1> { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x4xi32, 1> +// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<3x4xi32, 3> +// CHECK: memref.copy %[[VAL_0]], %[[VAL_3]] : memref<3x4xi32, 1> to memref<3x4xi32, 3> +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<3x4xi32, 3> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_3]], %[[VAL_1]] : memref<3x4xi32, 3>, memref) outs(%[[VAL_4]] : memref<3x4xi32, 3>) { 
+// CHECK: ^bb0(%[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32): +// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32 +// CHECK: linalg.yield %[[VAL_8]] : i32 +// CHECK: } +// CHECK: memref.copy %[[VAL_4]], %[[VAL_2]] : memref<3x4xi32, 3> to memref<3x4xi32, 1> +// CHECK: return %[[VAL_2]] : memref<3x4xi32, 1> +// CHECK: } +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @linalg_generic_ignore_dyanmic_shape(%arg0: memref<3x4xi32, 1>, %arg1: memref) -> memref<3x4xi32, 1> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<3x4xi32, 1> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<3x4xi32, 1>, memref) outs(%alloc : memref<3x4xi32, 1>) { + ^bb0(%in: i32, %in_1: i32, %out: i32): + %1 = arith.addi %in, %in_1 : i32 + linalg.yield %1 : i32 + } + return %alloc : memref<3x4xi32, 1> +} diff --git a/mlir/test/Dialect/Linalg/update_address_space_same_addr_space.mlir b/mlir/test/Dialect/Linalg/update_address_space_same_addr_space.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/update_address_space_same_addr_space.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(linalg-update-address-space{src-addr-space=3 dest-addr-space=3}))" -verify-diagnostics + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// expected-error @+1 {{Source and destination address spaces must be different}} +func.func @linalg_generic_update_ignore_all(%arg0: memref<1x2x3x4xf32>, %arg1: memref<1x2x3x4xf32>) -> memref<1x2x3x4xf32, 3> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3x4xf32, 3> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"], library_call = ""} ins(%arg0, %arg1 : memref<1x2x3x4xf32>, memref<1x2x3x4xf32>) outs(%alloc : memref<1x2x3x4xf32, 3>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %1 = 
arith.addf %in, %in_1 : f32 + linalg.yield %1 : f32 + } + return %alloc : memref<1x2x3x4xf32, 3> +}