|
4 | 4 | from mlir.ir import *
|
5 | 5 | import mlir.dialects.arith as arith
|
6 | 6 | import mlir.dialects.func as func
|
| 7 | +from array import array |
7 | 8 |
|
8 | 9 |
|
9 | 10 | def run(f):
|
@@ -92,3 +93,40 @@ def __str__(self):
|
92 | 93 | b = a * a
|
93 | 94 | # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
|
94 | 95 | print(b)
|
| 96 | + |
| 97 | + |
| 98 | +# CHECK-LABEL: TEST: testArrayConstantConstruction |
| 99 | +@run |
| 100 | +def testArrayConstantConstruction(): |
| 101 | + with Context(), Location.unknown(): |
| 102 | + module = Module.create() |
| 103 | + with InsertionPoint(module.body): |
| 104 | + i32_array = array("i", [1, 2, 3, 4]) |
| 105 | + i32 = IntegerType.get_signless(32) |
| 106 | + vec_i32 = VectorType.get([2, 2], i32) |
| 107 | + arith.constant(vec_i32, i32_array) |
| 108 | + arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32)) |
| 109 | + |
| 110 | + i64_array = array("l", [5, 6, 7, 8]) |
| 111 | + i64 = IntegerType.get_signless(64) |
| 112 | + vec_i64 = VectorType.get([1, 4], i64) |
| 113 | + arith.constant(vec_i64, i64_array) |
| 114 | + arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64)) |
| 115 | + |
| 116 | + f32_array = array("f", [1.0, 2.0, 3.0, 4.0]) |
| 117 | + f32 = F32Type.get() |
| 118 | + vec_f32 = VectorType.get([4, 1], f32) |
| 119 | + arith.constant(vec_f32, f32_array) |
| 120 | + arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32)) |
| 121 | + |
| 122 | + f64_array = array("d", [1.0, 2.0, 3.0, 4.0]) |
| 123 | + f64 = F64Type.get() |
| 124 | + vec_f64 = VectorType.get([2, 1, 2], f64) |
| 125 | + arith.constant(vec_f64, f64_array) |
| 126 | + arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64)) |
| 127 | + |
| 128 | + # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32> |
| 129 | + # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64> |
| 130 | + # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32> |
| 131 | + # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64> |
| 132 | + print(module) |
0 commit comments