diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -47,6 +47,15 @@ LLVM_OpBase { } +//===----------------------------------------------------------------------===// +// NVVM attribute definitions +//===----------------------------------------------------------------------===// + +class NVVM_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + //===----------------------------------------------------------------------===// // NVVM intrinsic operations //===----------------------------------------------------------------------===// @@ -460,12 +469,10 @@ } /// Attribute to hold the MMA shape -def NVVM_MMAShapeAttr : StructAttr<"MMAShapeAttr", NVVM_Dialect, [ - StructFieldAttr<"m", I32Attr>, - StructFieldAttr<"n", I32Attr>, - StructFieldAttr<"k", I32Attr> - ]> { +def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> { let summary = "Attribute for MMA operation shape."; + let parameters = (ins "int":$m, "int":$n, "int":$k); + let assemblyFormat = "`<` struct(params) `>`"; } // Returns true if this combination of layout/satf for MMA ops is supported; @@ -983,7 +990,7 @@ string llvmBuilder = [{ auto operands = moduleTranslation.lookupValues(opInst.getOperands()); auto intId = mlir::NVVM::MmaOp::getIntrinsicID( - $shape.m().getInt(), $shape.n().getInt(), $shape.k().getInt(), + $shape.getM(), $shape.getN(), $shape.getK(), $b1Op, $intOverflowBehavior, $layoutA, $layoutB, $multiplicandAPtxType.getValue(), diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -78,10 +78,17 @@ TupleType getTupleType(TypeRange elementTypes); NoneType getNoneType(); - /// Get or construct an instance of the type 'ty' with provided arguments. + /// Get or construct an instance of the type `Ty` with provided arguments. template - Ty getType(Args... args) { - return Ty::get(context, args...); + Ty getType(Args &&...args) { + return Ty::get(context, std::forward(args)...); + } + + /// Get or construct an instance of the attribute `Attr` with provided + /// arguments. + template + Attr getAttr(Args &&...args) { + return Attr::get(context, std::forward(args)...); } // Attributes. @@ -510,8 +517,7 @@ Operation *cloneWithoutRegions(Operation &op) { return insert(op.cloneWithoutRegions()); } - template - OpT cloneWithoutRegions(OpT op) { + template OpT cloneWithoutRegions(OpT op) { return cast(cloneWithoutRegions(*op.getOperation())); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -196,11 +196,8 @@ assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); MLIRContext *ctx = builder.getContext(); - Type i32 = builder.getIntegerType(32); result.addAttribute( - "shape", MMAShapeAttr::get(builder.getIntegerAttr(i32, shape[0]), - builder.getIntegerAttr(i32, shape[1]), - builder.getIntegerAttr(i32, shape[2]), ctx)); + "shape", builder.getAttr(shape[0], shape[1], shape[2])); result.addOperands(operandA); result.addOperands(operandB); @@ -358,9 +355,8 @@ auto s32x2StructTy = LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty}); - std::array mmaShape{shapeAttr().m().getInt(), - shapeAttr().n().getInt(), - shapeAttr().k().getInt()}; + std::array mmaShape{shapeAttr().getM(), shapeAttr().getN(), + shapeAttr().getK()}; // These variables define the set of allowed data types for matrices A, B, C, // and result. diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -12,7 +12,7 @@ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>> // CHECK-NOT llvm.extractvalue // CHECK: [[d:%.+]] = nvvm.mma.sync - // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32} + // CHECK-SAME: shape = #nvvm.shape %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> @@ -30,7 +30,7 @@ func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> { // We just need to check the mma instruction and the manipulatin of the result. // CHECK: [[d:%.+]] = nvvm.mma.sync - // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32} + // CHECK-SAME: shape = #nvvm.shape // CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32> // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32> @@ -61,7 +61,7 @@ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>> // CHECK-NOT llvm.extractvalue // CHECK: [[d:%.+]] = nvvm.mma.sync - // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32} + // CHECK-SAME: shape = #nvvm.shape %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> @@ -95,7 +95,7 @@ // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type - // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32} + // CHECK-SAME: shape = #nvvm.shape %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> return %d : vector<2x2xi32> } @@ -116,7 +116,7 @@ // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type - // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32} + // CHECK-SAME: shape = #nvvm.shape %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32> return %d : vector<2x2xi32> } @@ -143,7 +143,7 @@ // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type - // CHECK-SAME: shape = {k = 64 : i32, m = 16 : i32, n = 8 : i32} + // CHECK-SAME: shape = #nvvm.shape %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32> return %d : vector<2x2xi32> } @@ -156,7 +156,7 @@ // CHECK: llvm.extractvalue // CHECK: llvm.extractvalue // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}] - // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32} + // CHECK-SAME: shape = #nvvm.shape %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64> // CHECK: llvm.mlir.undef : vector<2xf64> // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)> @@ -217,7 +217,7 @@ // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}, {{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}, {{%.+}}, {{%.+}}] // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type - // CHECK-SAME: shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32} + // CHECK-SAME: shape = #nvvm.shape // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)> %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32> // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0] diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -541,7 +541,7 @@ %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) { // expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}} %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] - {layoutA=#nvvm.mma_layout, layoutB=#nvvm.mma_layout, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + {layoutA=#nvvm.mma_layout, layoutB=#nvvm.mma_layout, shape = #nvvm.shape} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } @@ -553,7 +553,7 @@ %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) { // expected-error@+1 {{Could not match allowed types for the result; expected one of !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> but got !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>}} %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] - {layoutA=#nvvm.mma_layout, layoutB=#nvvm.mma_layout, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)> + {layoutA=#nvvm.mma_layout, layoutB=#nvvm.mma_layout, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)> } @@ -565,7 +565,7 @@ %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) { // expected-error@+1 {{op requires attribute 'layoutA'}} %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] - {shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + {shape = #nvvm.shape}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } @@ -575,7 +575,7 @@ %b0 : vector<2xf16>, %b1 : vector<2xf16>, %c0 : vector<2xf16>, %c1 : vector<2xf16>) { // expected-error@+1 {{unimplemented variant for MMA shape <8, 8, 16>}} - %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout, layoutB=#nvvm.mma_layout, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout, layoutB=#nvvm.mma_layout, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> } @@ -588,7 +588,7 @@ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, - shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -85,7 +85,7 @@ // CHECK: nvvm.mma.sync %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } @@ -96,19 +96,19 @@ // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + shape = #nvvm.shape} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> } // CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8 func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32, %c0 : i32, %c1 : i32) { - // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> + // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, intOverflowBehavior=#nvvm.mma_int_overflow, - shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> + shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> llvm.return %0 : !llvm.struct<(i32, i32)> } @@ -119,7 +119,7 @@ // CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> } @@ -128,10 +128,10 @@ %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, %c0 : vector<2xf16>, %c1 : vector<2xf16>) { - // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> } @@ -140,10 +140,10 @@ %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, %c0 : vector<2xf16>, %c1 : vector<2xf16>) { - // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> + shape = #nvvm.shape} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> } @@ -152,10 +152,10 @@ %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) { - // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> } @@ -164,10 +164,10 @@ %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) { - // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> } @@ -175,23 +175,23 @@ func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) { - // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, - shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> } // CHECK-LABEL: @nvvm_mma_m16n8k16_s8_s8 func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { - // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, intOverflowBehavior=#nvvm.mma_int_overflow, - shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } @@ -199,12 +199,12 @@ func.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { - // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, intOverflowBehavior=#nvvm.mma_int_overflow, - shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } @@ -212,11 +212,11 @@ func.func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, %b0 : i32, %b1 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { - // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, - b1Op = #nvvm.mma_b1op, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + b1Op = #nvvm.mma_b1op, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } @@ -224,12 +224,12 @@ func.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { - // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, b1Op = #nvvm.mma_b1op, - shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } @@ -237,11 +237,11 @@ func.func @nvvm_mma_m8n8k128_b1_b1(%a0 : i32, %b0 : i32, %c0 : i32, %c1 : i32) { - // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> + // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, - b1Op = #nvvm.mma_b1op, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32)> + b1Op = #nvvm.mma_b1op, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32)> } @@ -249,12 +249,12 @@ func.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { - // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, intOverflowBehavior=#nvvm.mma_int_overflow, - shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -107,7 +107,7 @@ %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> { // CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32 %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] - {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } @@ -118,7 +118,7 @@ %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> { // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f16 %0 = nvvm.mma.sync A[ %a0, %a1, %a2, %a3 ] B[ %b0, %b1 ] C[ %c0, %c1 ] - {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> } @@ -132,7 +132,7 @@ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16 %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> + shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> } @@ -145,7 +145,7 @@ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32 %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> } @@ -158,7 +158,7 @@ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f32 %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> } @@ -171,7 +171,7 @@ {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, intOverflowBehavior=#nvvm.mma_int_overflow, - shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } @@ -184,7 +184,7 @@ {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, intOverflowBehavior=#nvvm.mma_int_overflow, - shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } @@ -196,7 +196,7 @@ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, - b1Op = #nvvm.mma_b1op, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + b1Op = #nvvm.mma_b1op, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } @@ -209,7 +209,7 @@ {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, intOverflowBehavior=#nvvm.mma_int_overflow, - shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } @@ -220,7 +220,7 @@ // CHECK: call { double, double } @llvm.nvvm.mma.m8n8k4.row.col.f64 %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (f64, f64, f64) -> !llvm.struct<(f64, f64)> + shape = #nvvm.shape} : (f64, f64, f64) -> !llvm.struct<(f64, f64)> llvm.return %0 : !llvm.struct<(f64, f64)> } @@ -232,7 +232,7 @@ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, - shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> }