diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md --- a/mlir/docs/Traits.md +++ b/mlir/docs/Traits.md @@ -247,6 +247,18 @@ This trait is an important structural property of the IR, and enables operations to have [passes](PassManagement.md) scheduled under them. +### MemRefsNormalizable + +* `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable` + +This trait is used to flag operations that can accommodate `MemRefs` with +non-identity memory-layout specifications. This trait indicates that the +normalization of memory layout can be performed for such operations. +`MemRefs` normalization consists of replacing an original memory reference +with layout specifications to an equivalent memory reference where +the specified memory layout is applied by rewriting accesses and types +associated with that memory reference. + ### Single Block with Implicit Terminator * `OpTrait::SingleBlockImplicitTerminator` : diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -80,8 +80,9 @@ // multiple stride levels (possibly using AffineMaps to specify multiple levels // of striding). // TODO: Consider replacing src/dst memref indices with view memrefs. -class AffineDmaStartOp : public Op { +class AffineDmaStartOp + : public Op { public: using Op::Op; @@ -268,8 +269,9 @@ // ... 
// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> // -class AffineDmaWaitOp : public Op { +class AffineDmaWaitOp + : public Op { public: using Op::Op; diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -405,7 +405,8 @@ class AffineLoadOpBase traits = []> : Affine_Op])> { + [DeclareOpInterfaceMethods, + MemRefsNormalizable])> { let arguments = (ins Arg:$memref, Variadic:$indices); @@ -732,7 +733,8 @@ class AffineStoreOpBase traits = []> : Affine_Op])> { + [DeclareOpInterfaceMethods, + MemRefsNormalizable])> { code extraClassDeclarationBase = [{ /// Returns the operand index of the value to be stored. unsigned getStoredValOperandIndex() { return 0; } diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -658,7 +658,7 @@ // CallOp //===----------------------------------------------------------------------===// -def CallOp : Std_Op<"call", [CallOpInterface]> { +def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> { let summary = "call operation"; let description = [{ The `call` operation represents a direct call to a function that is within @@ -1388,7 +1388,8 @@ // DeallocOp //===----------------------------------------------------------------------===// -def DeallocOp : Std_Op<"dealloc", [MemoryEffects<[MemFree]>]> { +def DeallocOp : Std_Op<"dealloc", + [MemoryEffects<[MemFree]>, MemRefsNormalizable]> { let summary = "memory deallocation operation"; let description = [{ The `dealloc` operation frees the region of memory referenced by a memref @@ -2144,8 +2145,8 @@ // ReturnOp //===----------------------------------------------------------------------===// -def ReturnOp : Std_Op<"return", 
[NoSideEffect, HasParent<"FuncOp">, ReturnLike, - Terminator]> { +def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, + MemRefsNormalizable, ReturnLike, Terminator]> { let summary = "return operation"; let description = [{ The `return` operation represents a return operation within a function. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1698,6 +1698,9 @@ NativeOpTrait<"SameOperandsAndResultElementType">; // Op is a terminator. def Terminator : NativeOpTrait<"IsTerminator">; +// Op can be safely normalized in the presence of MemRefs with +// non-identity maps. +def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">; // Op's regions have a single block with the specified terminator. class SingleBlockImplicitTerminator diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1212,6 +1212,20 @@ } }; +/// This trait is used to flag operations that can accommodate MemRefs with +/// non-identity memory-layout specifications. This trait indicates that the +/// normalization of memory layout can be performed for such operations. +/// MemRefs normalization consists of replacing an original memory reference +/// with layout specifications to an equivalent memory reference where the +/// specified memory layout is applied by rewriting accesses and types +/// associated with that memory reference. +// TODO: Right now, the operands of an operation are either all normalizable, +// or not. In the future, we may want to allow some of the operands to be +// normalizable. 
+template +struct MemRefsNormalizable + : public TraitBase {}; + } // end namespace OpTrait //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp @@ -106,23 +106,15 @@ normalizeFuncOpMemRefs(funcOp, moduleOp); } -/// Return true if this operation dereferences one or more memref's. -/// TODO: Temporary utility, will be replaced when this is modeled through -/// side-effects/op traits. -static bool isMemRefDereferencingOp(Operation &op) { - return isa(op); -} - /// Check whether all the uses of oldMemRef are either dereferencing uses or the /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints /// are satisfied will the value become a candidate for replacement. /// TODO: Extend this for DimOps. static bool isMemRefNormalizable(Value::user_range opUsers) { if (llvm::any_of(opUsers, [](Operation *op) { - if (isMemRefDereferencingOp(*op)) + if (op->hasTrait()) return false; - return !isa(*op); + return true; })) return false; return true; diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -279,7 +279,7 @@ // Currently we support the following non-dereferencing ops to be a // candidate for replacement: Dealloc, CallOp and ReturnOp. // TODO: Add support for other kinds of ops. - if (!isa(*op)) + if (!op->hasTrait()) return failure(); } diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir @@ -0,0 +1,57 @@ +// RUN: mlir-opt -normalize-memrefs %s | FileCheck %s + +// For all these cases, we test if MemRefs Normalization works with the test +// operations. 
+// * test.op_norm: this operation has the MemRefsNormalizable trait. The tests +// that include this operation are constructed so that the normalization should +// happen. +// * test.op_nonnorm: this operation does not have the MemRefsNormalizable +// trait. The tests that include this operation are constructed so that the +// normalization should not happen. + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)> + +// Test with op_norm and maps in arguments and in the operations in the function. + +// CHECK-LABEL: test_norm +// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>) +func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () { + %0 = alloc() : memref<1x16x14x14xf32, #map0> + "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> () + dealloc %0 : memref<1x16x14x14xf32, #map0> + + // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32> + // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> () + // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32> + return +} + +// Same test with op_nonnorm, with maps in the arguments and the operations in the function. + +// CHECK-LABEL: test_nonnorm +// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32, #map0>) +func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () { + %0 = alloc() : memref<1x16x14x14xf32, #map0> + "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> () + dealloc %0 : memref<1x16x14x14xf32, #map0> + + // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32, #map0> + // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> () + // CHECK: dealloc %[[v0]] : memref<1x16x14x14xf32, #map0> + return +} + +// Test with op_norm, with maps in the operations in the function. 
+ +// CHECK-LABEL: test_norm_mix +// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32> +func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () { + %0 = alloc() : memref<1x16x14x14xf32, #map0> + "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> () + dealloc %0 : memref<1x16x14x14xf32, #map0> + + // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32> + // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> () + // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32> + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -618,6 +618,16 @@ let arguments = (ins I32, OptionalAttr:$optional_attr); let results = (outs I32); } + +// Test for memrefs normalization of an op with normalizable memrefs. +def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> { + let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y); +} +// Test for memrefs normalization of an op without normalizable memrefs. +def OpNonNorm : TEST_Op<"op_nonnorm"> { + let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y); +} + // Pattern add the argument plus a increasing static number hidden in // OpMTest function. That value is set into the optional argument. // That way, we will know if operations is called once or twice.