Index: mlir/docs/Traits.md =================================================================== --- mlir/docs/Traits.md +++ mlir/docs/Traits.md @@ -275,3 +275,15 @@ This trait provides verification and functionality for operations that are known to be [terminators](LangRef.md#terminator-operations). + +### MemRefsNormalizable + +* `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable` + +This trait is used to flag operations that can accommodate `MemRef`s with +nontrivial memory-layout specifications. It indicates that memory-layout +normalization can proceed for such operations. `MemRef` normalization +consists of replacing a `MemRef` that has a nontrivial layout specification +with an equivalent `MemRef` that has an identity layout, where the specified +memory layout is instead applied to the access pattern and to the type +associated with the memory reference. Index: mlir/include/mlir/Dialect/Affine/IR/AffineOps.h =================================================================== --- mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -80,8 +80,9 @@ // multiple stride levels (possibly using AffineMaps to specify multiple levels // of striding). // TODO: Consider replacing src/dst memref indices with view memrefs. -class AffineDmaStartOp : public Op { +class AffineDmaStartOp : public Op { public: using Op::Op; @@ -268,8 +269,9 @@ // ...
// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> // -class AffineDmaWaitOp : public Op { +class AffineDmaWaitOp : public Op { public: using Op::Op; Index: mlir/include/mlir/Dialect/Affine/IR/AffineOps.td =================================================================== --- mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -405,7 +405,8 @@ class AffineLoadOpBase traits = []> : Affine_Op])> { + [DeclareOpInterfaceMethods, + MemRefsNormalizable])> { let arguments = (ins Arg:$memref, Variadic:$indices); @@ -723,7 +724,8 @@ class AffineStoreOpBase traits = []> : Affine_Op])> { + [DeclareOpInterfaceMethods, + MemRefsNormalizable])> { code extraClassDeclarationBase = [{ /// Returns the operand index of the value to be stored. unsigned getStoredValOperandIndex() { return 0; } Index: mlir/include/mlir/Dialect/StandardOps/IR/Ops.td =================================================================== --- mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -658,7 +658,7 @@ // CallOp //===----------------------------------------------------------------------===// -def CallOp : Std_Op<"call", [CallOpInterface]> { +def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> { let summary = "call operation"; let description = [{ The `call` operation represents a direct call to a function that is within @@ -1388,7 +1388,8 @@ // DeallocOp //===----------------------------------------------------------------------===// -def DeallocOp : Std_Op<"dealloc", [MemoryEffects<[MemFree]>]> { +def DeallocOp : Std_Op<"dealloc", + [MemoryEffects<[MemFree]>, MemRefsNormalizable]> { let summary = "memory deallocation operation"; let description = [{ The `dealloc` operation frees the region of memory referenced by a memref @@ -2124,8 +2125,8 @@ // ReturnOp //===----------------------------------------------------------------------===// -def ReturnOp : Std_Op<"return", 
[NoSideEffect, HasParent<"FuncOp">, ReturnLike, - Terminator]> { +def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, + MemRefsNormalizable, ReturnLike, Terminator]> { let summary = "return operation"; let description = [{ The `return` operation represents a return operation within a function. Index: mlir/include/mlir/IR/OpBase.td =================================================================== --- mlir/include/mlir/IR/OpBase.td +++ mlir/include/mlir/IR/OpBase.td @@ -1698,6 +1698,8 @@ NativeOpTrait<"SameOperandsAndResultElementType">; // Op is a terminator. def Terminator : NativeOpTrait<"IsTerminator">; +// Op can be safely normalized in presence of MemRefs with maps. +def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">; // Op's regions have a single block with the specified terminator. class SingleBlockImplicitTerminator Index: mlir/include/mlir/IR/OpDefinition.h =================================================================== --- mlir/include/mlir/IR/OpDefinition.h +++ mlir/include/mlir/IR/OpDefinition.h @@ -1212,6 +1212,13 @@ } }; +/// This trait indicates that memref normalization pass can normalize nontrivial +/// memory layout specification for operations of this type. +template +struct MemRefsNormalizable + : public TraitBase { +}; + } // end namespace OpTrait //===----------------------------------------------------------------------===// Index: mlir/lib/Transforms/NormalizeMemRefs.cpp =================================================================== --- mlir/lib/Transforms/NormalizeMemRefs.cpp +++ mlir/lib/Transforms/NormalizeMemRefs.cpp @@ -106,23 +106,15 @@ normalizeFuncOpMemRefs(funcOp, moduleOp); } -/// Return true if this operation dereferences one or more memref's. -/// TODO: Temporary utility, will be replaced when this is modeled through -/// side-effects/op traits. 
-static bool isMemRefDereferencingOp(Operation &op) { - return isa(op); -} - -/// Check whether all the uses of oldMemRef are either dereferencing uses or the -/// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints -/// are satisfied will the value become a candidate for replacement. +/// Check whether all the uses of oldMemRef are by ops carrying the +/// MemRefsNormalizable trait (e.g. affine loads/stores, DeallocOp, CallOp, +/// ReturnOp). Only then will the value become a candidate for replacement. /// TODO: Extend this for DimOps. static bool isMemRefNormalizable(Value::user_range opUsers) { if (llvm::any_of(opUsers, [](Operation *op) { - if (isMemRefDereferencingOp(*op)) + if (op->hasTrait()) return false; - return !isa(*op); + return true; })) return false; return true; Index: mlir/lib/Transforms/Utils/Utils.cpp =================================================================== --- mlir/lib/Transforms/Utils/Utils.cpp +++ mlir/lib/Transforms/Utils/Utils.cpp @@ -279,7 +279,7 @@ - // Currently we support the following non-dereferencing ops to be a - // candidate for replacement: Dealloc, CallOp and ReturnOp. + // Non-dereferencing uses are only supported when the user op carries the + // MemRefsNormalizable trait (e.g. DeallocOp, CallOp, ReturnOp). // TODO: Add support for other kinds of ops. - if (!isa(*op)) + if (!op->hasTrait()) return failure(); } Index: mlir/test/lib/Dialect/Test/TestOps.td =================================================================== --- mlir/test/lib/Dialect/Test/TestOps.td +++ mlir/test/lib/Dialect/Test/TestOps.td @@ -618,6 +618,16 @@ let arguments = (ins I32, OptionalAttr:$optional_attr); let results = (outs I32); } + +// Test for memrefs normalization of an op with normalizable memrefs. +def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> { + let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y); +} +// Same operands as OpNorm but without the MemRefsNormalizable trait. +def OpNonNorm : TEST_Op<"op_nonnorm"> { + let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y); +} + // Pattern add the argument plus a increasing static number hidden in // OpMTest function. That value is set into the optional argument. // That way, we will know if operations is called once or twice.
Index: mlir/test/Transforms/normalize-memrefs-ops.mlir =================================================================== --- /dev/null +++ mlir/test/Transforms/normalize-memrefs-ops.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt -normalize-memrefs %s | FileCheck %s + +// Checks that -normalize-memrefs rewrites memrefs whose users all carry the +// MemRefsNormalizable trait, and leaves them untouched otherwise. + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)> + +// CHECK-LABEL: func @test_norm( +func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () { + %0 = alloc() : memref<1x16x14x14xf32, #map0> + "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> () + dealloc %0 : memref<1x16x14x14xf32, #map0> + + // CHECK: %0 = alloc() : memref<1x16x1x1x32x64xf32> + // CHECK: "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> () + // CHECK: dealloc %0 : memref<1x16x1x1x32x64xf32> + return +} + +// CHECK-LABEL: func @test_nonnorm( +func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () { + %0 = alloc() : memref<1x16x14x14xf32, #map0> + "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> () + dealloc %0 : memref<1x16x14x14xf32, #map0> + + // CHECK: %0 = alloc() : memref<1x16x14x14xf32, #map0> + // CHECK: "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> () + // CHECK: dealloc %0 : memref<1x16x14x14xf32, #map0> + return +} + +// CHECK-LABEL: func @test_norm_mix( +func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () { + %0 = alloc() : memref<1x16x14x14xf32, #map0> + "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> () + dealloc %0 : memref<1x16x14x14xf32, #map0> + + // CHECK: %0 = alloc() : memref<1x16x1x1x32x64xf32> + // CHECK: "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> () + // CHECK: dealloc %0 : memref<1x16x1x1x32x64xf32> + return +}